# Spaces:
# Sleeping
# Sleeping
from typing import Optional

import warp as wp
from warp.fem import cache
from warp.fem.geometry import Geometry
from warp.fem.quadrature import Quadrature
from warp.fem.types import ElementIndex, Coords, make_free_sample

from .shape import ShapeFunction
from .topology import SpaceTopology, DiscontinuousSpaceTopology
class BasisSpace:
    """Interface class for defining a scalar-valued basis over a geometry.

    A basis space makes it easy to define multiple function spaces sharing the same basis (and thus nodes)
    but with different valuation functions; however, it is not a required ingredient of a function space.

    Subclasses provide the ``make_*`` factories below; each factory returns a device function
    that generated kernels call with the geometry argument, the basis argument, an element index
    and a node index within that element.

    See also: :func:`make_polynomial_basis_space`, :func:`make_collocated_function_space`
    """

    # Must be a Warp struct: it is used as a kernel argument type in `node_positions`
    # and instances of it are passed as launch inputs.
    @wp.struct
    class BasisArg:
        """Argument structure to be passed to device functions"""

        pass

    def __init__(self, topology: SpaceTopology):
        """Builds a basis over the nodes defined by `topology`."""
        self._topology = topology
        # Mirrored from the topology; read as a loop bound inside generated kernels
        self.NODES_PER_ELEMENT = self._topology.NODES_PER_ELEMENT

    # `topology` and `geometry` are read attribute-style everywhere in this module
    # (e.g. `self.geometry.dimension`, `self.topology.node_count()`), hence properties.
    @property
    def topology(self) -> SpaceTopology:
        """Underlying topology of the basis space"""
        return self._topology

    @property
    def geometry(self) -> Geometry:
        """Underlying geometry of the basis space"""
        return self._topology.geometry

    def basis_arg_value(self, device) -> "BasisArg":
        """Value for the argument structure to be passed to device functions"""
        return BasisSpace.BasisArg()

    # Helpers for generating node positions

    def node_positions(self, out: Optional[wp.array] = None) -> wp.array:
        """Returns an array containing the world position for each node.

        Args:
            out: optional destination array. When omitted, a new array is allocated.

        Returns:
            The array holding one position per node (``out`` itself when it was provided).

        Raises:
            ValueError: if ``out`` does not have shape ``(node_count,)`` or a vector
                data type matching the geometry dimension.
        """

        NODES_PER_ELEMENT = self.NODES_PER_ELEMENT

        pos_type = cache.cached_vec_type(length=self.geometry.dimension, dtype=float)

        node_coords_in_element = self.make_node_coords_in_element()

        # `wp.launch` needs a compiled kernel, not a plain Python function
        @cache.dynamic_kernel(suffix=self.name)
        def fill_node_positions(
            geo_cell_arg: self.geometry.CellArg,
            basis_arg: self.BasisArg,
            topo_arg: self.topology.TopologyArg,
            node_positions: wp.array(dtype=pos_type),
        ):
            # One thread per cell; each thread writes the position of every node of its cell
            element_index = wp.tid()

            for n in range(NODES_PER_ELEMENT):
                node_index = self.topology.element_node_index(geo_cell_arg, topo_arg, element_index, n)
                coords = node_coords_in_element(geo_cell_arg, basis_arg, element_index, n)

                sample = make_free_sample(element_index, coords)
                pos = self.geometry.cell_position(geo_cell_arg, sample)

                node_positions[node_index] = pos

        shape = (self.topology.node_count(),)
        if out is None:
            node_positions = wp.empty(
                shape=shape,
                dtype=pos_type,
            )
        else:
            if out.shape != shape or not wp.types.types_equal(pos_type, out.dtype):
                raise ValueError(
                    f"Out node positions array must have shape {shape} and data type {wp.types.type_repr(pos_type)}"
                )
            node_positions = out

        # Every argument below is built for the output array's device, so launch there too
        # instead of relying on whatever the current default device happens to be.
        target_device = node_positions.device
        wp.launch(
            dim=self.geometry.cell_count(),
            kernel=fill_node_positions,
            inputs=[
                self.geometry.cell_arg_value(device=target_device),
                self.basis_arg_value(device=target_device),
                self.topology.topo_arg_value(device=target_device),
                node_positions,
            ],
            device=target_device,
        )

        return node_positions

    def make_node_coords_in_element(self):
        """Returns a device function giving the coordinates of a node within its element.

        Must be implemented by subclasses.
        """
        raise NotImplementedError()

    def make_node_quadrature_weight(self):
        """Returns a device function giving the quadrature weight associated to a node.

        Must be implemented by subclasses.
        """
        raise NotImplementedError()

    def make_element_inner_weight(self):
        """Returns a device function evaluating a node's weight at given element coordinates.

        Must be implemented by subclasses.
        """
        raise NotImplementedError()

    def make_element_outer_weight(self):
        """Same as :meth:`make_element_inner_weight` unless overridden."""
        return self.make_element_inner_weight()

    def make_element_inner_weight_gradient(self):
        """Returns a device function evaluating the gradient of a node's weight.

        Must be implemented by subclasses.
        """
        raise NotImplementedError()

    def make_element_outer_weight_gradient(self):
        """Same as :meth:`make_element_inner_weight_gradient` unless overridden."""
        return self.make_element_inner_weight_gradient()

    def make_trace_node_quadrature_weight(self, trace_basis=None):
        """Returns a device function giving node quadrature weights for the trace basis.

        Args:
            trace_basis: the trace space requesting the weights. It is accepted here
                (defaulting to ``None``) because :class:`TraceBasisSpace` always passes
                itself, and subclasses declare this parameter.

        Must be implemented by subclasses.
        """
        raise NotImplementedError()

    def trace(self) -> "TraceBasisSpace":
        """Returns the trace of this basis on the geometry sides."""
        return TraceBasisSpace(self)
class ShapeBasisSpace(BasisSpace):
    """Base class for defining shape-function-based basis spaces."""

    def __init__(self, topology: SpaceTopology, shape: ShapeFunction):
        """Builds a basis from a node topology and a per-element shape function.

        The node-connectivity helpers (`node_triangulation`, `node_tets`, `node_hexes`)
        are only exposed when the shape function provides the corresponding
        `element_node_*` method.
        """
        super().__init__(topology)
        self._shape = shape

        self.ORDER = self._shape.ORDER

        if hasattr(shape, "element_node_triangulation"):
            self.node_triangulation = self._node_triangulation
        if hasattr(shape, "element_node_tets"):
            self.node_tets = self._node_tets
        if hasattr(shape, "element_node_hexes"):
            self.node_hexes = self._node_hexes

    # NOTE(review): the sibling accessors (`topology`, `geometry`, `name`) are read attribute-style
    # in this file; confirm whether callers expect `shape` to be a property as well.
    def shape(self) -> ShapeFunction:
        """Shape functions used for defining individual element basis"""
        return self._shape

    # Property: interpolated attribute-style (`self._basis.name`) and used as a kernel-name suffix
    @property
    def name(self):
        """Name of the basis, combining the topology name and the shape function name."""
        return f"{self.topology.name}_{self._shape.name}"

    def make_node_coords_in_element(self):
        """Returns a device function forwarding to the shape function's node coordinates."""
        shape_node_coords_in_element = self._shape.make_node_coords_in_element()

        @cache.dynamic_func(suffix=self.name)
        def node_coords_in_element(
            elt_arg: self.geometry.CellArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            # Only the index within the element matters; the other arguments keep the common signature
            return shape_node_coords_in_element(node_index_in_elt)

        return node_coords_in_element

    def make_node_quadrature_weight(self):
        """Returns a device function forwarding to the shape function's node quadrature weights."""
        shape_node_quadrature_weight = self._shape.make_node_quadrature_weight()

        @cache.dynamic_func(suffix=self.name)
        def node_quadrature_weight(
            elt_arg: self.geometry.CellArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            return shape_node_quadrature_weight(node_index_in_elt)

        return node_quadrature_weight

    def make_element_inner_weight(self):
        """Returns a device function evaluating the shape function of a node at `coords`."""
        shape_element_inner_weight = self._shape.make_element_inner_weight()

        @cache.dynamic_func(suffix=self.name)
        def element_inner_weight(
            elt_arg: self.geometry.CellArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
        ):
            return shape_element_inner_weight(coords, node_index_in_elt)

        return element_inner_weight

    def make_element_inner_weight_gradient(self):
        """Returns a device function evaluating the shape function gradient of a node at `coords`."""
        shape_element_inner_weight_gradient = self._shape.make_element_inner_weight_gradient()

        @cache.dynamic_func(suffix=self.name)
        def element_inner_weight_gradient(
            elt_arg: self.geometry.CellArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
        ):
            return shape_element_inner_weight_gradient(coords, node_index_in_elt)

        return element_inner_weight_gradient

    def make_trace_node_quadrature_weight(self, trace_basis):
        """Returns a device function giving node quadrature weights on sides.

        Args:
            trace_basis: the trace space the function is generated for; its geometry,
                argument structure and topology are used in the generated function.
        """
        shape_trace_node_quadrature_weight = self._shape.make_trace_node_quadrature_weight()

        @cache.dynamic_func(suffix=trace_basis.name)
        def trace_node_quadrature_weight(
            geo_side_arg: trace_basis.geometry.SideArg,
            basis_arg: trace_basis.BasisArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            # Map the side node to its index within the neighbouring cell, which is
            # what the shape function's trace weight is expressed in terms of.
            neighbour_elem, index_in_neighbour = trace_basis.topology.neighbor_cell_index(
                geo_side_arg, element_index, node_index_in_elt
            )
            return shape_trace_node_quadrature_weight(index_in_neighbour)

        return trace_node_quadrature_weight

    def _node_triangulation(self):
        """Returns node indices, three per row, obtained by applying the shape function's
        per-element triangulation to every element's node indices."""
        element_node_indices = self._topology.element_node_indices().numpy()
        element_triangles = self._shape.element_node_triangulation()

        # Fancy-index each element's nodes by the local pattern, then flatten to rows of 3
        tri_indices = element_node_indices[:, element_triangles].reshape(-1, 3)
        return tri_indices

    def _node_tets(self):
        """Returns node indices, four per row, obtained by applying the shape function's
        per-element tetrahedra to every element's node indices."""
        element_node_indices = self._topology.element_node_indices().numpy()
        element_tets = self._shape.element_node_tets()

        tet_indices = element_node_indices[:, element_tets].reshape(-1, 4)
        return tet_indices

    def _node_hexes(self):
        """Returns node indices, eight per row, obtained by applying the shape function's
        per-element hexahedra to every element's node indices."""
        element_node_indices = self._topology.element_node_indices().numpy()
        element_hexes = self._shape.element_node_hexes()

        hex_indices = element_node_indices[:, element_hexes].reshape(-1, 8)
        return hex_indices
class TraceBasisSpace(BasisSpace):
    """Auto-generated trace space evaluating the cell-defined basis on the geometry sides"""

    def __init__(self, basis: BasisSpace):
        """Builds the trace of `basis`, defined over the trace of its topology."""
        super().__init__(basis.topology.trace())

        self.ORDER = basis.ORDER

        self._basis = basis
        # Device functions of the cell basis are reused, so share its argument structure
        self.BasisArg = self._basis.BasisArg
        self.basis_arg_value = self._basis.basis_arg_value

    # Property: used attribute-style as a suffix for generated device functions
    @property
    def name(self):
        """Name of the trace basis, derived from the name of the cell basis."""
        return f"{self._basis.name}_Trace"

    def make_node_coords_in_element(self):
        """Returns a device function giving the side coordinates of a trace node."""
        node_coords_in_cell = self._basis.make_node_coords_in_element()

        @cache.dynamic_func(suffix=self._basis.name)
        def trace_node_coords_in_element(
            geo_side_arg: self.geometry.SideArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            # Look the node up in its neighbouring cell, get its cell coordinates,
            # then convert those back to coordinates on the side.
            neighbour_elem, index_in_neighbour = self.topology.neighbor_cell_index(
                geo_side_arg, element_index, node_index_in_elt
            )
            geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
            neighbour_coords = node_coords_in_cell(
                geo_cell_arg,
                basis_arg,
                neighbour_elem,
                index_in_neighbour,
            )

            return self.geometry.side_from_cell_coords(geo_side_arg, element_index, neighbour_elem, neighbour_coords)

        return trace_node_coords_in_element

    def make_node_quadrature_weight(self):
        """Returns the trace node quadrature weight function provided by the cell basis."""
        return self._basis.make_trace_node_quadrature_weight(self)

    def make_element_inner_weight(self):
        """Returns a device function evaluating a node's weight from the inner cell of a side."""
        cell_inner_weight = self._basis.make_element_inner_weight()

        @cache.dynamic_func(suffix=self._basis.name)
        def trace_element_inner_weight(
            geo_side_arg: self.geometry.SideArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
        ):
            cell_index, index_in_cell = self.topology.inner_cell_index(geo_side_arg, element_index, node_index_in_elt)
            # A negative index means the node has no counterpart in the inner cell
            if index_in_cell < 0:
                return 0.0

            cell_coords = self.geometry.side_inner_cell_coords(geo_side_arg, element_index, coords)

            geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
            return cell_inner_weight(
                geo_cell_arg,
                basis_arg,
                cell_index,
                cell_coords,
                index_in_cell,
            )

        return trace_element_inner_weight

    def make_element_outer_weight(self):
        """Returns a device function evaluating a node's weight from the outer cell of a side."""
        cell_outer_weight = self._basis.make_element_outer_weight()

        @cache.dynamic_func(suffix=self._basis.name)
        def trace_element_outer_weight(
            geo_side_arg: self.geometry.SideArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
        ):
            cell_index, index_in_cell = self.topology.outer_cell_index(geo_side_arg, element_index, node_index_in_elt)
            # A negative index means the node has no counterpart in the outer cell
            if index_in_cell < 0:
                return 0.0

            cell_coords = self.geometry.side_outer_cell_coords(geo_side_arg, element_index, coords)

            geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
            return cell_outer_weight(
                geo_cell_arg,
                basis_arg,
                cell_index,
                cell_coords,
                index_in_cell,
            )

        return trace_element_outer_weight

    def make_element_inner_weight_gradient(self):
        """Returns a device function evaluating a node's weight gradient from the inner cell of a side."""
        cell_inner_weight_gradient = self._basis.make_element_inner_weight_gradient()
        # Same cached vector type as the other basis spaces in this module
        grad_vec_type = cache.cached_vec_type(length=self.geometry.dimension, dtype=float)

        @cache.dynamic_func(suffix=self._basis.name)
        def trace_element_inner_weight_gradient(
            geo_side_arg: self.geometry.SideArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
        ):
            cell_index, index_in_cell = self.topology.inner_cell_index(geo_side_arg, element_index, node_index_in_elt)
            if index_in_cell < 0:
                return grad_vec_type(0.0)

            cell_coords = self.geometry.side_inner_cell_coords(geo_side_arg, element_index, coords)
            geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
            return cell_inner_weight_gradient(geo_cell_arg, basis_arg, cell_index, cell_coords, index_in_cell)

        return trace_element_inner_weight_gradient

    def make_element_outer_weight_gradient(self):
        """Returns a device function evaluating a node's weight gradient from the outer cell of a side."""
        cell_outer_weight_gradient = self._basis.make_element_outer_weight_gradient()
        # Same cached vector type as the other basis spaces in this module
        grad_vec_type = cache.cached_vec_type(length=self.geometry.dimension, dtype=float)

        @cache.dynamic_func(suffix=self._basis.name)
        def trace_element_outer_weight_gradient(
            geo_side_arg: self.geometry.SideArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
        ):
            cell_index, index_in_cell = self.topology.outer_cell_index(geo_side_arg, element_index, node_index_in_elt)
            if index_in_cell < 0:
                return grad_vec_type(0.0)

            cell_coords = self.geometry.side_outer_cell_coords(geo_side_arg, element_index, coords)
            geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
            return cell_outer_weight_gradient(geo_cell_arg, basis_arg, cell_index, cell_coords, index_in_cell)

        return trace_element_outer_weight_gradient

    def __eq__(self, other: "TraceBasisSpace") -> bool:
        """Two trace basis spaces compare equal when their topologies compare equal."""
        if not isinstance(other, TraceBasisSpace):
            return NotImplemented
        # The topology is stored as `_topology` by BasisSpace.__init__; no `_topo` attribute exists.
        return self._topology == other._topology
class PointBasisSpace(BasisSpace):
    """An unstructured :class:`BasisSpace` that is non-zero at a finite set of points only.

    The node locations and nodal quadrature weights are defined by a :class:`Quadrature` formula.
    """

    def __init__(self, quadrature: Quadrature):
        """Builds a basis with one node per point of `quadrature`.

        Raises:
            NotImplementedError: if the quadrature does not report a fixed
                number of points per element.
        """
        self._quadrature = quadrature

        if quadrature.points_per_element() is None:
            raise NotImplementedError("Varying number of points per element is not supported yet")

        topology = DiscontinuousSpaceTopology(
            geometry=quadrature.domain.geometry, nodes_per_element=quadrature.points_per_element()
        )
        super().__init__(topology)

        # Device functions take the quadrature's own argument structure
        self.BasisArg = quadrature.Arg
        self.basis_arg_value = quadrature.arg_value
        self.ORDER = 0

        # Inner and outer evaluations are identical for this basis
        self.make_element_outer_weight = self.make_element_inner_weight
        self.make_element_outer_weight_gradient = self.make_element_inner_weight_gradient

    # Property: used attribute-style as a suffix for generated device functions
    @property
    def name(self):
        """Name of the basis, derived from the quadrature name."""
        return f"{self._quadrature.name}_Point"

    def make_node_coords_in_element(self):
        """Returns a device function giving a node's coordinates: those of its quadrature point."""

        @cache.dynamic_func(suffix=self.name)
        def node_coords_in_element(
            elt_arg: self._quadrature.domain.ElementArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            return self._quadrature.point_coords(elt_arg, basis_arg, element_index, node_index_in_elt)

        return node_coords_in_element

    def make_node_quadrature_weight(self):
        """Returns a device function giving a node's weight: that of its quadrature point."""

        @cache.dynamic_func(suffix=self.name)
        def node_quadrature_weight(
            elt_arg: self._quadrature.domain.ElementArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            return self._quadrature.point_weight(elt_arg, basis_arg, element_index, node_index_in_elt)

        return node_quadrature_weight

    def make_element_inner_weight(self):
        """Returns a device function that is 1 near a node's quadrature point and 0 elsewhere."""

        @cache.dynamic_func(suffix=self.name)
        def element_inner_weight(
            elt_arg: self._quadrature.domain.ElementArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
        ):
            qp_coord = self._quadrature.point_coords(elt_arg, basis_arg, element_index, node_index_in_elt)
            # wp.select yields its last argument when the condition holds:
            # 1.0 when `coords` is within the squared-distance tolerance of the point, else 0.0
            return wp.select(wp.length_sq(coords - qp_coord) < 0.001, 0.0, 1.0)

        return element_inner_weight

    def make_element_inner_weight_gradient(self):
        """Returns a device function giving the weight gradient, which is always zero."""
        gradient_vec = cache.cached_vec_type(length=self.geometry.dimension, dtype=float)

        @cache.dynamic_func(suffix=self.name)
        def element_inner_weight_gradient(
            elt_arg: self._quadrature.domain.ElementArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
        ):
            return gradient_vec(0.0)

        return element_inner_weight_gradient

    def make_trace_node_quadrature_weight(self, trace_basis):
        """Returns a device function giving node quadrature weights on sides, which are always zero.

        Args:
            trace_basis: the trace space the function is generated for.
        """

        @cache.dynamic_func(suffix=self.name)
        def trace_node_quadrature_weight(
            elt_arg: trace_basis.geometry.SideArg,
            basis_arg: trace_basis.BasisArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            return 0.0

        return trace_node_quadrature_weight