Spaces:
Sleeping
Sleeping
| from typing import Optional, Type | |
| import warp as wp | |
| from warp.fem.types import ElementIndex | |
| from warp.fem.geometry import Geometry, DeformedGeometry | |
| from warp.fem import cache | |
class SpaceTopology:
    """
    Interface class for defining the topology of a function space.

    The topology only considers the indices of the nodes in each element, and as such,
    the connectivity pattern of the function space.
    It does not specify the actual location of the nodes within the elements, or the valuation function.
    """

    dimension: int
    """Embedding dimension of the function space"""

    NODES_PER_ELEMENT: int
    """Number of interpolation nodes per element of the geometry.

    .. note:: This will change to be defined per-element in future versions
    """

    # Instances of this class are passed as kernel arguments (see `element_node_indices`),
    # so it has to be a Warp struct rather than a plain Python class.
    @wp.struct
    class TopologyArg:
        """Structure containing arguments to be passed to device functions"""

        pass

    def __init__(self, geometry: Geometry, nodes_per_element: int):
        """
        Args:
            geometry: Geometry whose cells carry the interpolation nodes
            nodes_per_element: Number of nodes attached to each cell of `geometry`
        """
        self._geometry = geometry
        self.dimension = geometry.dimension
        # Registered as a Warp constant so that device code can capture it (e.g. as a loop bound)
        self.NODES_PER_ELEMENT = wp.constant(nodes_per_element)
        self.ElementArg = geometry.CellArg

    # Read as an attribute everywhere (`self.geometry.CellArg`, `topo.geometry`), hence a property.
    @property
    def geometry(self) -> Geometry:
        """Underlying geometry"""
        return self._geometry

    def node_count(self) -> int:
        """Number of nodes in the interpolation basis"""
        raise NotImplementedError

    def topo_arg_value(self, device) -> "TopologyArg":
        """Value of the topology argument structure to be passed to device functions"""
        return SpaceTopology.TopologyArg()

    # Used as a string (`__str__`, `__eq__`, derived topology names), hence a property.
    @property
    def name(self) -> str:
        """Name identifying this topology; two topologies on the same geometry compare equal iff names match"""
        return f"{self.__class__.__name__}_{self.NODES_PER_ELEMENT}"

    def __str__(self):
        return self.name

    @staticmethod
    def element_node_index(
        geo_arg: "ElementArg", topo_arg: "TopologyArg", element_index: ElementIndex, node_index_in_elt: int
    ):
        """Global node index for a given node in a given element

        Concrete topologies replace this with a generated Warp device function
        (assigned to `self.element_node_index`); the base version only fixes the signature.
        """
        raise NotImplementedError

    def element_node_indices(self, out: Optional[wp.array] = None) -> wp.array:
        """Returns a temporary array containing the global index for each node of each element

        Args:
            out: Optional pre-allocated output array of shape `(cell_count, NODES_PER_ELEMENT)`
                and dtype `int32`. A new array is allocated when omitted.

        Returns:
            The array filled with global node indices (`out` itself when it was provided).

        Raises:
            ValueError: If `out` does not have the expected shape or data type.
        """

        NODES_PER_ELEMENT = self.NODES_PER_ELEMENT

        # The kernel closes over `self`; keying it by the topology name prevents
        # different topologies from sharing a single kernel.
        @cache.dynamic_kernel(suffix=self.name)
        def fill_element_node_indices(
            geo_cell_arg: self.geometry.CellArg,
            topo_arg: self.TopologyArg,
            element_node_indices: wp.array2d(dtype=int),
        ):
            element_index = wp.tid()
            for n in range(NODES_PER_ELEMENT):
                element_node_indices[element_index, n] = self.element_node_index(
                    geo_cell_arg, topo_arg, element_index, n
                )

        shape = (self.geometry.cell_count(), NODES_PER_ELEMENT)
        if out is None:
            element_node_indices = wp.empty(
                shape=shape,
                dtype=int,
            )
        else:
            if out.shape != shape or out.dtype != wp.int32:
                raise ValueError(f"Out element node indices array must have shape {shape} and data type 'int32'")
            element_node_indices = out

        # One thread per cell, on the device that owns the output array
        wp.launch(
            dim=element_node_indices.shape[0],
            kernel=fill_element_node_indices,
            inputs=[
                self.geometry.cell_arg_value(device=element_node_indices.device),
                self.topo_arg_value(device=element_node_indices.device),
                element_node_indices,
            ],
            device=element_node_indices.device,
        )

        return element_node_indices

    # Interface generating trace space topology

    def trace(self) -> "TraceSpaceTopology":
        """Trace of the function space over lower-dimensional elements of the geometry"""
        return TraceSpaceTopology(self)

    def is_trace(self) -> bool:
        """Whether this topology is defined on the trace of the geometry"""
        return self.dimension == self.geometry.dimension - 1

    def full_space_topology(self) -> "SpaceTopology":
        """Returns the full space topology from which this topology is derived"""
        return self

    def __eq__(self, other: object) -> bool:
        """Checks whether two topologies are compatible (same geometry and same name)"""
        if not isinstance(other, SpaceTopology):
            # Let Python fall back to the reflected comparison / identity instead of
            # raising AttributeError on `other.geometry`.
            return NotImplemented
        return self.geometry == other.geometry and self.name == other.name

    def is_derived_from(self, other: "SpaceTopology") -> bool:
        """Checks whether two topologies are equal, or `self` is the trace of `other`"""
        if self.dimension == other.dimension:
            return self == other
        if self.dimension + 1 == other.dimension:
            return self.full_space_topology() == other

        return False
class TraceSpaceTopology(SpaceTopology):
    """Auto-generated trace topology defining the node indices associated to the geometry sides

    Each side carries twice the full topology's `NODES_PER_ELEMENT` nodes: local indices
    below the full topology's `NODES_PER_ELEMENT` refer to nodes of the side's inner cell,
    the remaining ones (shifted by `NODES_PER_ELEMENT`) to nodes of its outer cell.
    """

    def __init__(self, topo: SpaceTopology):
        super().__init__(topo.geometry, 2 * topo.NODES_PER_ELEMENT)

        self._topo = topo
        self.dimension = topo.dimension - 1
        # Trace elements are the sides of the geometry
        self.ElementArg = topo.geometry.SideArg

        # Device-side arguments are shared with the full topology
        self.TopologyArg = topo.TopologyArg
        self.topo_arg_value = topo.topo_arg_value

        self.inner_cell_index = self._make_inner_cell_index()
        self.outer_cell_index = self._make_outer_cell_index()
        self.neighbor_cell_index = self._make_neighbor_cell_index()

        self.element_node_index = self._make_element_node_index()

    def node_count(self) -> int:
        return self._topo.node_count()

    @property
    def name(self) -> str:
        return f"{self._topo.name}_Trace"

    def _make_inner_cell_index(self):
        """Builds a device function returning `(inner cell index, node index in that cell)`.

        The node index is -1 for nodes that belong to the outer cell.
        """
        NODES_PER_ELEMENT = self._topo.NODES_PER_ELEMENT

        @cache.dynamic_func(suffix=self.name)
        def inner_cell_index(args: self.geometry.SideArg, element_index: ElementIndex, node_index_in_elt: int):
            # wp.select(cond, value_if_false, value_if_true)
            index_in_inner_cell = wp.select(node_index_in_elt < NODES_PER_ELEMENT, -1, node_index_in_elt)
            return self.geometry.side_inner_cell_index(args, element_index), index_in_inner_cell

        return inner_cell_index

    def _make_outer_cell_index(self):
        """Builds a device function returning `(outer cell index, node index in that cell)`.

        The node index is only meaningful for outer-cell nodes; it is negative for inner-cell nodes.
        """
        NODES_PER_ELEMENT = self._topo.NODES_PER_ELEMENT

        @cache.dynamic_func(suffix=self.name)
        def outer_cell_index(args: self.geometry.SideArg, element_index: ElementIndex, node_index_in_elt: int):
            return self.geometry.side_outer_cell_index(args, element_index), node_index_in_elt - NODES_PER_ELEMENT

        return outer_cell_index

    def _make_neighbor_cell_index(self):
        """Builds a device function returning the adjacent cell (inner or outer) owning a side node,
        together with the node's index within that cell."""
        NODES_PER_ELEMENT = self._topo.NODES_PER_ELEMENT

        @cache.dynamic_func(suffix=self.name)
        def neighbor_cell_index(args: self.geometry.SideArg, element_index: ElementIndex, node_index_in_elt: int):
            if node_index_in_elt < NODES_PER_ELEMENT:
                return self.geometry.side_inner_cell_index(args, element_index), node_index_in_elt
            else:
                return (
                    self.geometry.side_outer_cell_index(args, element_index),
                    node_index_in_elt - NODES_PER_ELEMENT,
                )

        return neighbor_cell_index

    def _make_element_node_index(self):
        """Builds the device function mapping a (side, local node) pair to a global node index
        by forwarding to the full topology on the adjacent cell."""

        @cache.dynamic_func(suffix=self.name)
        def trace_element_node_index(
            geo_side_arg: self.geometry.SideArg,
            topo_arg: self._topo.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            cell_index, index_in_cell = self.neighbor_cell_index(geo_side_arg, element_index, node_index_in_elt)
            geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
            return self._topo.element_node_index(geo_cell_arg, topo_arg, cell_index, index_in_cell)

        return trace_element_node_index

    def full_space_topology(self) -> SpaceTopology:
        """Returns the full space topology from which this topology is derived"""
        return self._topo

    def __eq__(self, other: object) -> bool:
        """Two traces are equal when they derive from equal full topologies"""
        if not isinstance(other, TraceSpaceTopology):
            # Previously raised AttributeError on `other._topo` (e.g. via `is_derived_from`
            # comparing a trace with a full topology of the same dimension).
            return NotImplemented
        return self._topo == other._topo
class DiscontinuousSpaceTopologyMixin:
    """Helper for defining discontinuous topologies (per-element nodes)

    Every cell owns `NODES_PER_ELEMENT` consecutive global nodes, so node `n` of cell `e`
    has global index `NODES_PER_ELEMENT * e + n` and no node is shared between cells.
    Intended to be listed before a `SpaceTopology` base, which provides `geometry`,
    `NODES_PER_ELEMENT` and `TopologyArg`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.element_node_index = self._make_element_node_index()

    def node_count(self):
        """Total number of nodes: one independent set per cell"""
        return self.geometry.cell_count() * self.NODES_PER_ELEMENT

    @property
    def name(self) -> str:
        return f"{self.geometry.name}_D{self.NODES_PER_ELEMENT}"

    def _make_element_node_index(self):
        """Builds the device function computing the global index of a node within a cell"""
        NODES_PER_ELEMENT = self.NODES_PER_ELEMENT

        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            elt_arg: self.geometry.CellArg,
            topo_arg: self.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            return NODES_PER_ELEMENT * element_index + node_index_in_elt

        return element_node_index
class DiscontinuousSpaceTopology(DiscontinuousSpaceTopologyMixin, SpaceTopology):
    """Topology for generic discontinuous spaces.

    Combines the per-cell node numbering of `DiscontinuousSpaceTopologyMixin`
    with the generic `SpaceTopology` interface; no behavior is added.
    """
class DeformedGeometrySpaceTopology(SpaceTopology):
    """Topology over a :class:`DeformedGeometry` that forwards node indexing to a topology
    defined on the undeformed base geometry."""

    def __init__(self, geometry: DeformedGeometry, base_topology: SpaceTopology):
        """
        Args:
            geometry: The deformed geometry this topology is attached to
            base_topology: Topology built over the base (undeformed) geometry
        """
        super().__init__(geometry, base_topology.NODES_PER_ELEMENT)

        self.base = base_topology
        # Node numbering and device-side arguments are shared with the base topology
        self.node_count = self.base.node_count
        self.topo_arg_value = self.base.topo_arg_value
        self.TopologyArg = self.base.TopologyArg

        self.element_node_index = self._make_element_node_index()

    @property
    def name(self) -> str:
        return f"{self.base.name}_{self.geometry.field.name}"

    def _make_element_node_index(self):
        """Builds the device function forwarding node-index queries to the base topology"""

        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            elt_arg: self.geometry.CellArg,
            topo_arg: self.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            # `elt_arg.elt_arg` is presumably the base geometry's cell argument wrapped by the
            # deformed geometry's CellArg -- it is what `self.base` expects.
            return self.base.element_node_index(elt_arg.elt_arg, topo_arg, element_index, node_index_in_elt)

        return element_node_index
def forward_base_topology(topology_class: Type[SpaceTopology], geometry: Geometry, *args, **kwargs) -> SpaceTopology:
    """
    Builds an instance of `topology_class` suited to `geometry`.

    For an ordinary geometry, `topology_class` is instantiated directly over `geometry`,
    forwarding any additional arguments.
    For a :class:`DeformedGeometry`, `topology_class` is instantiated over the base (undeformed)
    geometry instead, and the result is wrapped in a :class:`DeformedGeometrySpaceTopology`
    that forwards calls to it.
    """
    if not isinstance(geometry, DeformedGeometry):
        return topology_class(geometry, *args, **kwargs)

    undeformed_topology = topology_class(geometry.base, *args, **kwargs)
    return DeformedGeometrySpaceTopology(geometry, undeformed_topology)