from typing import Any, Optional, Union

import warp as wp
from warp.fem.cache import (
    TemporaryStore,
    borrow_temporary,
    borrow_temporary_like,
    cached_arg_value,
)
from warp.fem.geometry import GeometryPartition, WholeGeometryPartition
from warp.fem.types import NULL_NODE_INDEX
from warp.fem.utils import _iota_kernel, compress_node_indices

from .function_space import FunctionSpace
from .topology import SpaceTopology

wp.set_module_options({"enable_backward": False})


class SpacePartition:
    """Base class for a subset of the nodes of a function space topology.

    Subclasses provide node-count queries, the mapping from function space node
    indices to partition node indices, and the argument struct passed to kernels.
    The query methods below are placeholders to be overridden by subclasses.
    """

    class PartitionArg:
        pass

    def __init__(self, space_topology: SpaceTopology, geo_partition: GeometryPartition):
        self.space_topology = space_topology
        self.geo_partition = geo_partition

    def node_count(self):
        """Returns number of nodes in this partition"""

    def owned_node_count(self) -> int:
        """Returns number of nodes in this partition, excluding exterior halo"""

    def interior_node_count(self) -> int:
        """Returns number of interior nodes in this partition"""

    def space_node_indices(self) -> wp.array:
        """Return the global function space indices for nodes in this partition"""

    def partition_arg_value(self, device):
        pass

    @staticmethod
    def partition_node_index(args: "PartitionArg", space_node_index: int):
        """Returns the index in the partition of a function space node,
        or ``NULL_NODE_INDEX`` if the node does not belong to the partition"""

    def __str__(self) -> str:
        return self.name

    @property
    def name(self) -> str:
        return f"{self.__class__.__name__}"


class WholeSpacePartition(SpacePartition):
    """Trivial partition containing every node of the space topology.

    Partition node indices coincide with function space node indices.
    """

    @wp.struct
    class PartitionArg:
        pass

    def __init__(self, space_topology: SpaceTopology):
        super().__init__(space_topology, WholeGeometryPartition(space_topology.geometry))
        # Identity index array, allocated lazily by `space_node_indices`
        self._node_indices = None

    def node_count(self):
        """Returns number of nodes in this partition"""
        return self.space_topology.node_count()

    def owned_node_count(self) -> int:
        """Returns number of nodes in this partition, excluding exterior halo"""
        return self.space_topology.node_count()

    def interior_node_count(self) -> int:
        """Returns number of interior nodes in this partition"""
        return self.space_topology.node_count()

    def space_node_indices(self):
        """Return the global function space indices for nodes in this partition"""
        if self._node_indices is None:
            self._node_indices = borrow_temporary(temporary_store=None, shape=(self.node_count(),), dtype=int)
            # Fill with 0, 1, 2, ... since partition and space indices are identical
            wp.launch(kernel=_iota_kernel, dim=self.node_count(), inputs=[self._node_indices.array, 1])
        return self._node_indices.array

    def partition_arg_value(self, device):
        return WholeSpacePartition.PartitionArg()

    @wp.func
    def partition_node_index(args: Any, space_node_index: int):
        return space_node_index

    def __eq__(self, other: SpacePartition) -> bool:
        # Only another whole partition over the same topology is equivalent; a NodePartition built
        # on the same topology numbers (a subset of) the nodes differently and must not compare equal,
        # including through Python's reflected `node_partition == whole_partition` comparison.
        return isinstance(other, WholeSpacePartition) and self.space_topology == other.space_topology

    @property
    def name(self) -> str:
        return "Whole"


class NodeCategory:
    OWNED_INTERIOR = wp.constant(0)
    """Node is touched exclusively by this partition, not touched by frontier sides"""
    OWNED_FRONTIER = wp.constant(1)
    """Node is touched by a frontier side, but belongs to an element of this partition"""
    HALO_LOCAL_SIDE = wp.constant(2)
    """Node belongs to an element of another partition, but is touched by one of our frontier sides"""
    HALO_OTHER_SIDE = wp.constant(3)
    """Node belongs to an element of another partition, and is not touched by one of our frontier sides"""
    EXTERIOR = wp.constant(4)
    """Node is never referenced by this partition"""

    COUNT = 5


class NodePartition(SpacePartition):
    """Partition of the space nodes referenced by a geometry partition.

    Partition nodes are numbered contiguously by :class:`NodeCategory`: owned interior
    nodes first, then owned frontier nodes, then halo nodes. Exterior nodes are not
    part of the partition.
    """

    @wp.struct
    class PartitionArg:
        space_to_partition: wp.array(dtype=int)

    def __init__(
        self,
        space_topology: SpaceTopology,
        geo_partition: GeometryPartition,
        with_halo: bool = True,
        device=None,
        temporary_store: TemporaryStore = None,
    ):
        super().__init__(space_topology=space_topology, geo_partition=geo_partition)

        self._compute_node_indices_from_sides(device, with_halo, temporary_store)

    def _category_end_offset(self, category: int) -> int:
        """Number of partition nodes whose category is less than or equal to `category`"""
        return int(self._category_offsets.array.numpy()[category + 1])

    def node_count(self) -> int:
        """Returns number of nodes referenced by this partition, including exterior halo"""
        return self._category_end_offset(NodeCategory.HALO_OTHER_SIDE)

    def owned_node_count(self) -> int:
        """Returns number of nodes in this partition, excluding exterior halo"""
        return self._category_end_offset(NodeCategory.OWNED_FRONTIER)

    def interior_node_count(self) -> int:
        """Returns number of interior nodes in this partition"""
        return self._category_end_offset(NodeCategory.OWNED_INTERIOR)

    def space_node_indices(self):
        """Return the global function space indices for nodes in this partition"""
        return self._node_indices.array

    @cached_arg_value
    def partition_arg_value(self, device):
        arg = NodePartition.PartitionArg()
        arg.space_to_partition = self._space_to_partition.array.to(device)
        return arg

    @wp.func
    def partition_node_index(args: PartitionArg, space_node_index: int):
        return args.space_to_partition[space_node_index]

    def _compute_node_indices_from_sides(self, device, with_halo: bool, temporary_store: TemporaryStore):
        """Assigns a NodeCategory to every space node, then builds the partition index arrays.

        Categories are assigned in successive passes, each only upgrading specific prior values:
         1. nodes of every partition cell -> OWNED_INTERIOR
         2. (with_halo only) nodes of partition sides still EXTERIOR -> HALO_LOCAL_SIDE
         3. (with_halo only) nodes of frontier sides: EXTERIOR -> HALO_OTHER_SIDE,
            OWNED_INTERIOR -> OWNED_FRONTIER
        """
        from warp.fem import cache

        trace_topology = self.space_topology.trace()
        NODES_PER_CELL = self.space_topology.NODES_PER_ELEMENT
        NODES_PER_SIDE = trace_topology.NODES_PER_ELEMENT

        @cache.dynamic_kernel(suffix=f"{self.geo_partition.name}_{self.space_topology.name}")
        def node_category_from_cells_kernel(
            geo_arg: self.geo_partition.geometry.CellArg,
            geo_partition_arg: self.geo_partition.CellArg,
            space_arg: self.space_topology.TopologyArg,
            node_mask: wp.array(dtype=int),
        ):
            partition_cell_index = wp.tid()

            cell_index = self.geo_partition.cell_index(geo_partition_arg, partition_cell_index)

            for n in range(NODES_PER_CELL):
                space_nidx = self.space_topology.element_node_index(geo_arg, space_arg, cell_index, n)
                node_mask[space_nidx] = NodeCategory.OWNED_INTERIOR

        @cache.dynamic_kernel(suffix=f"{self.geo_partition.name}_{self.space_topology.name}")
        def node_category_from_owned_sides_kernel(
            geo_arg: self.geo_partition.geometry.SideArg,
            geo_partition_arg: self.geo_partition.SideArg,
            space_arg: trace_topology.TopologyArg,
            node_mask: wp.array(dtype=int),
        ):
            partition_side_index = wp.tid()

            side_index = self.geo_partition.side_index(geo_partition_arg, partition_side_index)

            for n in range(NODES_PER_SIDE):
                space_nidx = trace_topology.element_node_index(geo_arg, space_arg, side_index, n)
                if node_mask[space_nidx] == NodeCategory.EXTERIOR:
                    node_mask[space_nidx] = NodeCategory.HALO_LOCAL_SIDE

        @cache.dynamic_kernel(suffix=f"{self.geo_partition.name}_{self.space_topology.name}")
        def node_category_from_frontier_sides_kernel(
            geo_arg: self.geo_partition.geometry.SideArg,
            geo_partition_arg: self.geo_partition.SideArg,
            space_arg: trace_topology.TopologyArg,
            node_mask: wp.array(dtype=int),
        ):
            frontier_side_index = wp.tid()

            side_index = self.geo_partition.frontier_side_index(geo_partition_arg, frontier_side_index)

            for n in range(NODES_PER_SIDE):
                space_nidx = trace_topology.element_node_index(geo_arg, space_arg, side_index, n)
                if node_mask[space_nidx] == NodeCategory.EXTERIOR:
                    node_mask[space_nidx] = NodeCategory.HALO_OTHER_SIDE
                elif node_mask[space_nidx] == NodeCategory.OWNED_INTERIOR:
                    node_mask[space_nidx] = NodeCategory.OWNED_FRONTIER

        node_category = borrow_temporary(
            temporary_store,
            shape=(self.space_topology.node_count(),),
            dtype=int,
            device=device,
        )
        try:
            # Every node starts as EXTERIOR; the passes below only promote categories
            node_category.array.fill_(value=NodeCategory.EXTERIOR)

            wp.launch(
                dim=self.geo_partition.cell_count(),
                kernel=node_category_from_cells_kernel,
                inputs=[
                    self.geo_partition.geometry.cell_arg_value(device),
                    self.geo_partition.cell_arg_value(device),
                    self.space_topology.topo_arg_value(device),
                    node_category.array,
                ],
                device=device,
            )

            if with_halo:
                wp.launch(
                    dim=self.geo_partition.side_count(),
                    kernel=node_category_from_owned_sides_kernel,
                    inputs=[
                        self.geo_partition.geometry.side_arg_value(device),
                        self.geo_partition.side_arg_value(device),
                        self.space_topology.topo_arg_value(device),
                        node_category.array,
                    ],
                    device=device,
                )

                wp.launch(
                    dim=self.geo_partition.frontier_side_count(),
                    kernel=node_category_from_frontier_sides_kernel,
                    inputs=[
                        self.geo_partition.geometry.side_arg_value(device),
                        self.geo_partition.side_arg_value(device),
                        self.space_topology.topo_arg_value(device),
                        node_category.array,
                    ],
                    device=device,
                )

            self._finalize_node_indices(node_category.array, temporary_store)
        finally:
            # Return the scratch category array even if a launch fails
            node_category.release()

    def _finalize_node_indices(self, node_category: wp.array(dtype=int), temporary_store: TemporaryStore):
        """Groups space nodes by category and builds the partition <-> space index mappings.

        Sets `_category_offsets` (host-side), `_space_to_partition` and `_node_indices`.
        """
        category_offsets, node_indices, _, __ = compress_node_indices(NodeCategory.COUNT, node_category)

        # Copy offsets to cpu
        device = node_category.device
        self._category_offsets = borrow_temporary(
            temporary_store,
            shape=category_offsets.array.shape,
            dtype=category_offsets.array.dtype,
            pinned=device.is_cuda,
            device="cpu",
        )
        wp.copy(src=category_offsets.array, dest=self._category_offsets.array)

        if device.is_cuda:
            # The host-side offsets are read right below, so the copy must have completed
            # TODO switch to synchronize_event once available
            wp.synchronize_stream(wp.get_stream(device))

        category_offsets.release()

        # Read the (host-side) node count once rather than once per use
        node_count = self.node_count()

        # Compute global to local indices
        self._space_to_partition = borrow_temporary_like(node_indices, temporary_store)
        wp.launch(
            kernel=NodePartition._scatter_partition_indices,
            dim=self.space_topology.node_count(),
            device=device,
            inputs=[node_count, node_indices.array, self._space_to_partition.array],
        )

        # Copy to shrinked-to-fit array
        self._node_indices = borrow_temporary(temporary_store, shape=(node_count,), dtype=int, device=device)
        wp.copy(dest=self._node_indices.array, src=node_indices.array, count=node_count)

        node_indices.release()

    @wp.kernel
    def _scatter_partition_indices(
        local_node_count: int,
        node_indices: wp.array(dtype=int),
        space_to_partition_indices: wp.array(dtype=int),
    ):
        # Nodes past `local_node_count` are EXTERIOR and map to NULL_NODE_INDEX
        local_idx = wp.tid()
        space_idx = node_indices[local_idx]

        if local_idx < local_node_count:
            space_to_partition_indices[space_idx] = local_idx
        else:
            space_to_partition_indices[space_idx] = NULL_NODE_INDEX


def make_space_partition(
    space: Optional[FunctionSpace] = None,
    geometry_partition: Optional[GeometryPartition] = None,
    space_topology: Optional[SpaceTopology] = None,
    with_halo: bool = True,
    device=None,
    temporary_store: TemporaryStore = None,
) -> SpacePartition:
    """Computes the subset of nodes from a function space topology that touch a geometry partition

    Either `space_topology` or `space` must be provided (and will be considered in that order).

    Args:
        space: (deprecated) the function space defining the topology if `space_topology` is ``None``.
        geometry_partition: The subset of the space geometry.  If not provided, use the whole geometry.
        space_topology: the topology of the function space to consider. If ``None``, deduced from `space`.
        with_halo: if True, include the halo nodes (nodes from exterior frontier cells to the partition)
        device: Warp device on which to perform and store computations
        temporary_store: shared pool from which to allocate temporary arrays

    Returns:
        the resulting space partition. A :class:`WholeSpacePartition` is returned when
        `geometry_partition` is ``None`` or covers every cell of its geometry.

    Raises:
        ValueError: if neither `space_topology` nor `space` is provided
    """

    if space_topology is None:
        if space is None:
            raise ValueError("Either `space_topology` or `space` must be provided")
        space_topology = space.topology

    space_topology = space_topology.full_space_topology()

    # Only build an explicit node partition when the geometry partition is a strict subset
    if geometry_partition is not None and geometry_partition.cell_count() < geometry_partition.geometry.cell_count():
        return NodePartition(
            space_topology=space_topology,
            geo_partition=geometry_partition,
            with_halo=with_halo,
            device=device,
            temporary_store=temporary_store,
        )

    return WholeSpacePartition(space_topology)