# qbhf2's picture
# added NvidiaWarp and GarmentCode repos
# 66c9c8a
import warp as wp
from warp.fem.domain import GeometryDomain
from warp.fem.types import NodeElementIndex
from warp.fem.utils import compress_node_indices
from warp.fem.cache import cached_arg_value, borrow_temporary, borrow_temporary_like, TemporaryStore
from .function_space import FunctionSpace
from .partition import SpacePartition
# Module-level Warp options for this file: switch the "enable_backward" option off.
# (Per the Warp documentation this option governs generation of backward/adjoint code.)
wp.set_module_options(dict(enable_backward=False))
class SpaceRestriction:
    """Restriction of a space partition to a given GeometryDomain.

    On construction, the partition node index of every node of every domain
    element is gathered, then compressed into a node-to-element map from which
    the owning domain element and the index within that element are extracted
    for each entry.
    """

    def __init__(
        self,
        space_partition: SpacePartition,
        domain: GeometryDomain,
        device=None,
        temporary_store: TemporaryStore = None,
    ):
        """Build the restriction of ``space_partition`` to ``domain``.

        Args:
            space_partition: Partition whose ``space_topology`` is restricted.
            domain: Geometry domain. Its ``dimension`` must match that of the
                space topology, or be exactly one less, in which case the
                topology's ``trace()`` is used instead.
            device: Forwarded to temporary allocations and kernel launches.
            temporary_store: Forwarded to the temporary-borrowing helpers.

        Raises:
            ValueError: If the domain and (possibly traced) topology
                dimensions differ.
        """
        topology = space_partition.space_topology

        # A domain one dimension below the space is matched against the trace topology
        if topology.dimension - 1 == domain.dimension:
            topology = topology.trace()

        if topology.dimension != domain.dimension:
            raise ValueError("Incompatible space and domain dimensions")

        self.domain = domain
        self.space_topology = topology
        self.space_partition = space_partition

        self._compute_node_element_indices(device=device, temporary_store=temporary_store)

    def _compute_node_element_indices(self, device, temporary_store: TemporaryStore):
        """Fill the private ``_dof_*`` arrays and ``_node_count``.

        Args:
            device: Device used for the temporary table and the first launch.
            temporary_store: Store handed to ``borrow_temporary``,
                ``borrow_temporary_like`` and ``compress_node_indices``.
        """
        from warp.fem import cache

        NODES_PER_ELEMENT = self.space_topology.NODES_PER_ELEMENT
        kernel_suffix = f"{self.domain.name}_{self.space_topology.name}_{self.space_partition.name}"

        # Kernel specialised on the domain / topology / partition argument types.
        # Each thread handles one domain element and writes one row of the table.
        @cache.dynamic_kernel(suffix=kernel_suffix)
        def fill_element_node_indices(
            element_arg: self.domain.ElementArg,
            domain_index_arg: self.domain.ElementIndexArg,
            topo_arg: self.space_topology.TopologyArg,
            partition_arg: self.space_partition.PartitionArg,
            element_node_indices: wp.array2d(dtype=int),
        ):
            row = wp.tid()
            cell = self.domain.element_index(domain_index_arg, row)
            for k in range(NODES_PER_ELEMENT):
                # Space-level node index first, then its index within the partition
                node_in_partition = self.space_partition.partition_node_index(
                    partition_arg,
                    self.space_topology.element_node_index(element_arg, topo_arg, cell, k),
                )
                element_node_indices[row, k] = node_in_partition

        # (domain element, node in element) -> partition node index
        node_index_table = borrow_temporary(
            temporary_store,
            shape=(self.domain.element_count(), NODES_PER_ELEMENT),
            dtype=int,
            device=device,
        )

        fill_inputs = [
            self.domain.element_arg_value(device),
            self.domain.element_index_arg_value(device),
            self.space_topology.topo_arg_value(device),
            self.space_partition.partition_arg_value(device),
            node_index_table.array,
        ]
        wp.launch(
            kernel=fill_element_node_indices,
            dim=node_index_table.array.shape[0],
            inputs=fill_inputs,
            device=device,
        )

        # Build compressed map from node to element indices
        flat_node_indices = node_index_table.array.flatten()
        element_offsets, sorted_array_indices, node_count, partition_indices = compress_node_indices(
            self.space_partition.node_count(), flat_node_indices, temporary_store=temporary_store
        )
        self._dof_partition_element_offsets = element_offsets
        self._node_count = node_count
        self._dof_partition_indices = partition_indices

        # Extract element index and index in element from each flattened table index
        self._dof_element_indices = borrow_temporary_like(flat_node_indices, temporary_store)
        self._dof_indices_in_element = borrow_temporary_like(flat_node_indices, temporary_store)

        split_inputs = [
            NODES_PER_ELEMENT,
            sorted_array_indices.array,
            self._dof_element_indices.array,
            self._dof_indices_in_element.array,
        ]
        wp.launch(
            kernel=SpaceRestriction._split_vertex_element_index,
            dim=flat_node_indices.shape,
            inputs=split_inputs,
            device=flat_node_indices.device,
        )

        # The intermediate index array is no longer needed once it has been split
        sorted_array_indices.release()

    def node_count(self):
        """Return the node count reported by ``compress_node_indices``."""
        return self._node_count

    def partition_element_offsets(self):
        """Return the array of per-partition-node offsets into the node-element lists."""
        return self._dof_partition_element_offsets.array

    def node_partition_indices(self):
        """Return the array mapping restriction node indices to partition node indices."""
        return self._dof_partition_indices.array

    def total_node_element_count(self):
        """Return the total number of (node, element) entries, i.e. the size of the element index array."""
        return self._dof_element_indices.array.size

    # Device-side view of the compressed node-to-element map.
    @wp.struct
    class NodeArg:
        # Indexed by partition node index; consecutive entries delimit that node's entries below
        dof_element_offsets: wp.array(dtype=int)
        # Domain element index of each (node, element) entry
        dof_element_indices: wp.array(dtype=int)
        # Partition node index of each restriction node
        dof_partition_indices: wp.array(dtype=int)
        # Index of the node within the element, for each (node, element) entry
        dof_indices_in_element: wp.array(dtype=int)

    @cached_arg_value
    def node_arg(self, device):
        """Return a ``NodeArg`` whose arrays have been moved to ``device``."""
        node_args = SpaceRestriction.NodeArg()
        node_args.dof_element_offsets = self._dof_partition_element_offsets.array.to(device)
        node_args.dof_element_indices = self._dof_element_indices.array.to(device)
        node_args.dof_partition_indices = self._dof_partition_indices.array.to(device)
        node_args.dof_indices_in_element = self._dof_indices_in_element.array.to(device)
        return node_args

    # Partition node index associated with restriction node `node_index`.
    @wp.func
    def node_partition_index(args: NodeArg, node_index: int):
        return args.dof_partition_indices[node_index]

    # Number of (element, index-in-element) entries recorded for restriction node `node_index`.
    @wp.func
    def node_element_count(args: NodeArg, node_index: int):
        pidx = SpaceRestriction.node_partition_index(args, node_index)
        range_begin = args.dof_element_offsets[pidx]
        range_end = args.dof_element_offsets[pidx + 1]
        return range_end - range_begin

    # `element_index`-th (domain element, index in element) pair of restriction node `node_index`.
    @wp.func
    def node_element_index(args: NodeArg, node_index: int, element_index: int):
        pidx = SpaceRestriction.node_partition_index(args, node_index)
        entry = args.dof_element_offsets[pidx] + element_index
        return NodeElementIndex(args.dof_element_indices[entry], args.dof_indices_in_element[entry])

    # Split each flattened table index into (element index, index within element).
    @wp.kernel
    def _split_vertex_element_index(
        vertex_per_element: int,
        sorted_indices: wp.array(dtype=int),
        vertex_element_index: wp.array(dtype=int),
        vertex_index_in_element: wp.array(dtype=int),
    ):
        i = wp.tid()
        flat_index = sorted_indices[i]
        owner = flat_index // vertex_per_element
        vertex_element_index[i] = owner
        vertex_index_in_element[i] = flat_index - vertex_per_element * owner