Spaces:
Sleeping
Sleeping
File size: 6,470 Bytes
66c9c8a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 | import warp as wp
from warp.fem.domain import GeometryDomain
from warp.fem.types import NodeElementIndex
from warp.fem.utils import compress_node_indices
from warp.fem.cache import cached_arg_value, borrow_temporary, borrow_temporary_like, TemporaryStore
from .function_space import FunctionSpace
from .partition import SpacePartition
# Turn off backward (adjoint) code generation for the Warp kernels/functions in this
# module: they only build integer index maps, so no gradients are required.
wp.set_module_options({"enable_backward": False})
class SpaceRestriction:
    """Restriction of a space partition to a given GeometryDomain.

    On construction, builds a compressed node-to-element map: for each node of the
    restriction, the list of ``(domain element index, index of the node within that
    element)`` pairs, addressed through per-node offsets (see ``node_element_count``
    and ``node_element_index``).
    """

    def __init__(
        self,
        space_partition: SpacePartition,
        domain: GeometryDomain,
        device=None,
        temporary_store: TemporaryStore = None,
    ):
        """Build the restriction of ``space_partition`` to ``domain``.

        Args:
            space_partition: Partition of the function space to restrict.
            domain: Geometry domain to restrict to. Its dimension must equal the space
                topology dimension, or be exactly one less, in which case the trace of
                the topology is used.
            device: Device on which the index arrays are built.
            temporary_store: Store passed to ``borrow_temporary`` /
                ``borrow_temporary_like`` when allocating the index arrays.

        Raises:
            ValueError: If the domain dimension is incompatible with the space topology.
        """
        space_topology = space_partition.space_topology

        # A domain one dimension lower than the space is matched against the trace topology
        if domain.dimension == space_topology.dimension - 1:
            space_topology = space_topology.trace()

        if domain.dimension != space_topology.dimension:
            raise ValueError("Incompatible space and domain dimensions")

        self.space_partition = space_partition
        self.space_topology = space_topology
        self.domain = domain

        self._compute_node_element_indices(device=device, temporary_store=temporary_store)

    def _compute_node_element_indices(self, device, temporary_store: TemporaryStore):
        """Populate the node-to-element index arrays of this restriction.

        Steps:
            1. Fill an ``(element_count, NODES_PER_ELEMENT)`` table with the partition
               node index of every node of every domain element.
            2. Compress that table per node with ``compress_node_indices``.
            3. Split each flat table position back into (element, index in element).

        Sets ``_dof_partition_element_offsets``, ``_node_count``,
        ``_dof_partition_indices``, ``_dof_element_indices`` and
        ``_dof_indices_in_element``.
        """
        from warp.fem import cache

        NODES_PER_ELEMENT = self.space_topology.NODES_PER_ELEMENT

        @cache.dynamic_kernel(suffix=f"{self.domain.name}_{self.space_topology.name}_{self.space_partition.name}")
        def fill_element_node_indices(
            element_arg: self.domain.ElementArg,
            domain_index_arg: self.domain.ElementIndexArg,
            topo_arg: self.space_topology.TopologyArg,
            partition_arg: self.space_partition.PartitionArg,
            element_node_indices: wp.array2d(dtype=int),
        ):
            # One thread per domain element: record the partition index of each of its nodes
            domain_element_index = wp.tid()
            element_index = self.domain.element_index(domain_index_arg, domain_element_index)
            for n in range(NODES_PER_ELEMENT):
                space_nidx = self.space_topology.element_node_index(element_arg, topo_arg, element_index, n)
                partition_nidx = self.space_partition.partition_node_index(partition_arg, space_nidx)
                element_node_indices[domain_element_index, n] = partition_nidx

        element_node_indices = borrow_temporary(
            temporary_store,
            shape=(self.domain.element_count(), NODES_PER_ELEMENT),
            dtype=int,
            device=device,
        )
        wp.launch(
            dim=element_node_indices.array.shape[0],
            kernel=fill_element_node_indices,
            inputs=[
                self.domain.element_arg_value(device),
                self.domain.element_index_arg_value(device),
                self.space_topology.topo_arg_value(device),
                self.space_partition.partition_arg_value(device),
                element_node_indices.array,
            ],
            device=device,
        )

        # Build compressed map from node to element indices.
        # flatten() yields a 1D view over the table, so flat position p belongs to
        # element p // NODES_PER_ELEMENT at local node p % NODES_PER_ELEMENT.
        flattened_node_indices = element_node_indices.array.flatten()
        (
            self._dof_partition_element_offsets,
            node_array_indices,
            self._node_count,
            self._dof_partition_indices,
        ) = compress_node_indices(
            self.space_partition.node_count(), flattened_node_indices, temporary_store=temporary_store
        )

        # Extract element index and index in element
        self._dof_element_indices = borrow_temporary_like(flattened_node_indices, temporary_store)
        self._dof_indices_in_element = borrow_temporary_like(flattened_node_indices, temporary_store)
        wp.launch(
            kernel=SpaceRestriction._split_vertex_element_index,
            dim=flattened_node_indices.shape,
            inputs=[
                NODES_PER_ELEMENT,
                node_array_indices.array,
                self._dof_element_indices.array,
                self._dof_indices_in_element.array,
            ],
            device=flattened_node_indices.device,
        )

        # Return scratch buffers to the temporary store.
        # element_node_indices is only released here, after its last use: flattened_node_indices
        # shares its storage and served as the template for borrow_temporary_like above, so an
        # earlier release could let the store hand that same buffer back out.
        node_array_indices.release()
        element_node_indices.release()

    def node_count(self):
        """Number of nodes in the restriction, as returned by ``compress_node_indices``."""
        return self._node_count

    def partition_element_offsets(self):
        """Offsets into the flat node-element arrays, indexed by partition node index."""
        return self._dof_partition_element_offsets.array

    def node_partition_indices(self):
        """Partition node index of each restriction node."""
        return self._dof_partition_indices.array

    def total_node_element_count(self):
        """Total number of (node, element) entries, i.e. element count times nodes per element."""
        return self._dof_element_indices.array.size

    @wp.struct
    class NodeArg:
        # Per-partition-node offsets into the flat arrays below (count = offsets[p + 1] - offsets[p])
        dof_element_offsets: wp.array(dtype=int)
        # Domain element index of each (node, element) entry
        dof_element_indices: wp.array(dtype=int)
        # Partition node index of each restriction node
        dof_partition_indices: wp.array(dtype=int)
        # Index of the node within its element, for each (node, element) entry
        dof_indices_in_element: wp.array(dtype=int)

    @cached_arg_value
    def node_arg(self, device):
        """Return a ``NodeArg`` whose arrays are obtained with ``.to(device)``.

        Memoized through ``cached_arg_value`` (presumably keyed on ``device`` -- TODO confirm).
        """
        arg = SpaceRestriction.NodeArg()
        arg.dof_element_offsets = self._dof_partition_element_offsets.array.to(device)
        arg.dof_element_indices = self._dof_element_indices.array.to(device)
        arg.dof_partition_indices = self._dof_partition_indices.array.to(device)
        arg.dof_indices_in_element = self._dof_indices_in_element.array.to(device)
        return arg

    # Map a restriction node index to its partition node index
    @wp.func
    def node_partition_index(args: NodeArg, node_index: int):
        return args.dof_partition_indices[node_index]

    # Number of domain elements incident to restriction node `node_index`
    @wp.func
    def node_element_count(args: NodeArg, node_index: int):
        partition_node_index = SpaceRestriction.node_partition_index(args, node_index)
        return args.dof_element_offsets[partition_node_index + 1] - args.dof_element_offsets[partition_node_index]

    # Return the `element_index`-th incidence of restriction node `node_index` as a
    # NodeElementIndex(domain element index, index of the node within that element).
    # Note: `element_index` is a local incidence counter in [0, node_element_count), not an element id.
    @wp.func
    def node_element_index(args: NodeArg, node_index: int, element_index: int):
        partition_node_index = SpaceRestriction.node_partition_index(args, node_index)
        offset = args.dof_element_offsets[partition_node_index] + element_index
        domain_element_index = args.dof_element_indices[offset]
        index_in_element = args.dof_indices_in_element[offset]
        return NodeElementIndex(domain_element_index, index_in_element)

    # Decompose each flat position idx = element * vertex_per_element + local into
    # (element, local). `sorted_indices` is the position permutation produced by
    # compress_node_indices (presumably grouped by node -- TODO confirm).
    @wp.kernel
    def _split_vertex_element_index(
        vertex_per_element: int,
        sorted_indices: wp.array(dtype=int),
        vertex_element_index: wp.array(dtype=int),
        vertex_index_in_element: wp.array(dtype=int),
    ):
        tid = wp.tid()
        idx = sorted_indices[tid]
        element_index = idx // vertex_per_element
        vertex_element_index[tid] = element_index
        vertex_index_in_element[tid] = idx - vertex_per_element * element_index
|