Spaces:
Sleeping
Sleeping
File size: 10,863 Bytes
66c9c8a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 | from typing import Optional, Type
import warp as wp
from warp.fem.types import ElementIndex
from warp.fem.geometry import Geometry, DeformedGeometry
from warp.fem import cache
class SpaceTopology:
    """
    Interface class for defining the topology of a function space.

    The topology only considers the indices of the nodes in each element, and as such,
    the connectivity pattern of the function space.
    It does not specify the actual location of the nodes within the elements, or the valuation function.
    """

    dimension: int
    """Embedding dimension of the function space"""

    NODES_PER_ELEMENT: int
    """Number of interpolation nodes per element of the geometry.

    .. note:: This will change to be defined per-element in future versions
    """

    @wp.struct
    class TopologyArg:
        """Structure containing arguments to be passed to device functions"""

        pass

    def __init__(self, geometry: Geometry, nodes_per_element: int):
        self._geometry = geometry
        self.dimension = geometry.dimension
        # Wrapped with `wp.constant` so that generated device code can capture it
        # (e.g. as the loop bound in `element_node_indices`).
        self.NODES_PER_ELEMENT = wp.constant(nodes_per_element)
        # Elements of a full-space topology are the geometry cells; trace topologies
        # replace this with the geometry's side argument type.
        self.ElementArg = geometry.CellArg

    @property
    def geometry(self) -> Geometry:
        """Underlying geometry"""
        return self._geometry

    def node_count(self) -> int:
        """Number of nodes in the interpolation basis"""
        raise NotImplementedError

    def topo_arg_value(self, device) -> "TopologyArg":
        """Value of the topology argument structure to be passed to device functions"""
        return SpaceTopology.TopologyArg()

    @property
    def name(self):
        """Name of the topology; also used as the suffix identifying generated device functions and kernels"""
        return f"{self.__class__.__name__}_{self.NODES_PER_ELEMENT}"

    def __str__(self):
        return self.name

    @staticmethod
    def element_node_index(
        geo_arg: "ElementArg", topo_arg: "TopologyArg", element_index: ElementIndex, node_index_in_elt: int
    ):
        """Global node index for a given node in a given element"""
        raise NotImplementedError

    def element_node_indices(self, out: Optional[wp.array] = None) -> wp.array:
        """Returns a temporary array containing the global index for each node of each element

        Args:
            out: Optional pre-allocated destination array. When provided, it must have shape
                ``(geometry.cell_count(), NODES_PER_ELEMENT)`` and dtype ``int32``, and the
                kernel is launched on its device. When omitted, a new array is allocated with
                ``wp.empty`` (no explicit device).

        Returns:
            A 2d array whose entry ``(e, n)`` is the global index of node ``n`` of cell ``e``.

        Raises:
            ValueError: If ``out`` does not have the expected shape or data type.
        """
        NODES_PER_ELEMENT = self.NODES_PER_ELEMENT

        @cache.dynamic_kernel(suffix=self.name)
        def fill_element_node_indices(
            geo_cell_arg: self.geometry.CellArg,
            topo_arg: self.TopologyArg,
            element_node_indices: wp.array2d(dtype=int),
        ):
            # One thread per cell, filling all of its nodes
            element_index = wp.tid()
            for n in range(NODES_PER_ELEMENT):
                element_node_indices[element_index, n] = self.element_node_index(
                    geo_cell_arg, topo_arg, element_index, n
                )

        shape = (self.geometry.cell_count(), NODES_PER_ELEMENT)
        if out is None:
            element_node_indices = wp.empty(
                shape=shape,
                dtype=int,
            )
        else:
            if out.shape != shape or out.dtype != wp.int32:
                raise ValueError(f"Out element node indices array must have shape {shape} and data type 'int32'")
            element_node_indices = out

        wp.launch(
            dim=element_node_indices.shape[0],
            kernel=fill_element_node_indices,
            inputs=[
                self.geometry.cell_arg_value(device=element_node_indices.device),
                self.topo_arg_value(device=element_node_indices.device),
                element_node_indices,
            ],
            device=element_node_indices.device,
        )

        return element_node_indices

    # Interface generating trace space topology

    def trace(self) -> "TraceSpaceTopology":
        """Trace of the function space over lower-dimensional elements of the geometry"""
        return TraceSpaceTopology(self)

    @property
    def is_trace(self) -> bool:
        """Whether this topology is defined on the trace of the geometry"""
        return self.dimension == self.geometry.dimension - 1

    def full_space_topology(self) -> "SpaceTopology":
        """Returns the full space topology from which this topology is derived"""
        return self

    def __eq__(self, other: object) -> bool:
        """Checks whether two topologies are compatible

        Two topologies compare equal when they share the same geometry and the same name.
        For objects that are not :class:`SpaceTopology` instances, ``NotImplemented`` is returned
        so that Python falls back to its default comparison instead of raising ``AttributeError``.
        """
        if not isinstance(other, SpaceTopology):
            return NotImplemented
        return self.geometry == other.geometry and self.name == other.name

    def is_derived_from(self, other: "SpaceTopology") -> bool:
        """Checks whether two topologies are equal, or `self` is the trace of `other`"""
        if self.dimension == other.dimension:
            return self == other

        if self.dimension + 1 == other.dimension:
            return self.full_space_topology() == other

        return False
class TraceSpaceTopology(SpaceTopology):
    """Auto-generated trace topology defining the node indices associated to the geometry sides

    Each side element exposes twice as many nodes as a cell of the wrapped topology.
    Let ``N`` denote the wrapped topology's ``NODES_PER_ELEMENT``:

    - side-local indices ``[0, N)`` refer to the nodes of the side's inner cell;
    - side-local indices ``[N, 2N)`` refer to the nodes of the side's outer cell.
    """

    def __init__(self, topo: SpaceTopology):
        super().__init__(topo.geometry, 2 * topo.NODES_PER_ELEMENT)

        self._topo = topo
        self.dimension = topo.dimension - 1
        self.ElementArg = topo.geometry.SideArg

        # Topology arguments are shared with the wrapped full-space topology
        self.TopologyArg = topo.TopologyArg
        self.topo_arg_value = topo.topo_arg_value

        # Device functions are generated per instance; `_topo` must be set before this point
        # because their cache suffix is derived from `self.name`.
        self.inner_cell_index = self._make_inner_cell_index()
        self.outer_cell_index = self._make_outer_cell_index()
        self.neighbor_cell_index = self._make_neighbor_cell_index()
        # Relies on `neighbor_cell_index` created just above
        self.element_node_index = self._make_element_node_index()

    def node_count(self) -> int:
        """Number of nodes, identical to that of the wrapped full-space topology"""
        return self._topo.node_count()

    @property
    def name(self):
        return f"{self._topo.name}_Trace"

    def _make_inner_cell_index(self):
        """Builds a device function mapping a side node to (inner cell index, index within that cell)"""
        NODES_PER_ELEMENT = self._topo.NODES_PER_ELEMENT

        @cache.dynamic_func(suffix=self.name)
        def inner_cell_index(args: self.geometry.SideArg, element_index: ElementIndex, node_index_in_elt: int):
            # Warp's `select(cond, a, b)` yields `b` when `cond` holds: inner-cell nodes keep their
            # index, while outer-cell nodes are mapped to -1.
            index_in_inner_cell = wp.select(node_index_in_elt < NODES_PER_ELEMENT, -1, node_index_in_elt)
            return self.geometry.side_inner_cell_index(args, element_index), index_in_inner_cell

        return inner_cell_index

    def _make_outer_cell_index(self):
        """Builds a device function mapping a side node to (outer cell index, index within that cell)"""
        NODES_PER_ELEMENT = self._topo.NODES_PER_ELEMENT

        @cache.dynamic_func(suffix=self.name)
        def outer_cell_index(args: self.geometry.SideArg, element_index: ElementIndex, node_index_in_elt: int):
            # No range check: inner-cell node indices yield a negative local index here
            return self.geometry.side_outer_cell_index(args, element_index), node_index_in_elt - NODES_PER_ELEMENT

        return outer_cell_index

    def _make_neighbor_cell_index(self):
        """Builds a device function mapping a side node to its owning cell (inner or outer) and local index"""
        NODES_PER_ELEMENT = self._topo.NODES_PER_ELEMENT

        @cache.dynamic_func(suffix=self.name)
        def neighbor_cell_index(args: self.geometry.SideArg, element_index: ElementIndex, node_index_in_elt: int):
            if node_index_in_elt < NODES_PER_ELEMENT:
                return self.geometry.side_inner_cell_index(args, element_index), node_index_in_elt
            else:
                return (
                    self.geometry.side_outer_cell_index(args, element_index),
                    node_index_in_elt - NODES_PER_ELEMENT,
                )

        return neighbor_cell_index

    def _make_element_node_index(self):
        """Builds a device function returning the global node index of a side node"""

        @cache.dynamic_func(suffix=self.name)
        def trace_element_node_index(
            geo_side_arg: self.geometry.SideArg,
            topo_arg: self._topo.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            # Resolve the side node to a (cell, local node) pair, then defer to the full-space topology
            cell_index, index_in_cell = self.neighbor_cell_index(geo_side_arg, element_index, node_index_in_elt)

            geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
            return self._topo.element_node_index(geo_cell_arg, topo_arg, cell_index, index_in_cell)

        return trace_element_node_index

    def full_space_topology(self) -> SpaceTopology:
        """Returns the full space topology from which this topology is derived"""
        return self._topo

    def __eq__(self, other: object) -> bool:
        """Checks whether two trace topologies wrap equal full-space topologies

        Returns ``NotImplemented`` when ``other`` is not a trace topology, instead of failing on
        the missing ``_topo`` attribute. Comparison with a full-space topology then falls back to
        :meth:`SpaceTopology.__eq__`. Python gives the subclass's reflected ``__eq__`` priority, so
        this matters for ``full_topology == trace_topology`` as well as the reverse.
        """
        if not isinstance(other, TraceSpaceTopology):
            return NotImplemented
        return self._topo == other._topo
class DiscontinuousSpaceTopologyMixin:
    """
    Helper for defining discontinuous topologies (per-element nodes).

    No node is shared between elements: node ``n`` of element ``e`` is assigned the global
    index ``e * NODES_PER_ELEMENT + n``. Intended to be listed before a :class:`SpaceTopology`
    base, whose constructor (reached through ``super().__init__``) provides ``geometry`` and
    ``NODES_PER_ELEMENT``.
    """

    def __init__(self, *args, **kwargs):
        # Let the topology base set up geometry / NODES_PER_ELEMENT first,
        # since the generated device function captures them.
        super().__init__(*args, **kwargs)
        self.element_node_index = self._make_element_node_index()

    def node_count(self):
        """Total number of nodes: every cell owns ``NODES_PER_ELEMENT`` private nodes"""
        cell_total = self.geometry.cell_count()
        return cell_total * self.NODES_PER_ELEMENT

    @property
    def name(self):
        """Geometry name followed by ``_D`` and the per-element node count"""
        geo_name = self.geometry.name
        per_element = self.NODES_PER_ELEMENT
        return f"{geo_name}_D{per_element}"

    def _make_element_node_index(self):
        """Builds the device function computing the global index of a node from its element"""
        NODES_PER_ELEMENT = self.NODES_PER_ELEMENT
        func_suffix = self.name

        @cache.dynamic_func(suffix=func_suffix)
        def element_node_index(
            elt_arg: self.geometry.CellArg,
            topo_arg: self.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            return NODES_PER_ELEMENT * element_index + node_index_in_elt

        return element_node_index
class DiscontinuousSpaceTopology(DiscontinuousSpaceTopologyMixin, SpaceTopology):
    """Topology for generic discontinuous spaces.

    All behavior comes from :class:`DiscontinuousSpaceTopologyMixin` (per-element node numbering),
    layered over :class:`SpaceTopology`. The mixin is listed first so that its ``__init__`` runs and
    forwards to ``SpaceTopology.__init__``.
    """
class DeformedGeometrySpaceTopology(SpaceTopology):
    """
    Topology over a :class:`DeformedGeometry` that forwards node numbering to a topology
    defined on the undeformed base geometry.

    Node counts, topology arguments and element node indices come from ``base_topology``.
    Only the geometry, and hence the cell argument type, differs.
    """

    def __init__(self, geometry: DeformedGeometry, base_topology: SpaceTopology):
        super().__init__(geometry, base_topology.NODES_PER_ELEMENT)

        self.base = base_topology

        # Forward node numbering and topology arguments to the base topology
        forwarded = self.base
        self.node_count = forwarded.node_count
        self.topo_arg_value = forwarded.topo_arg_value
        self.TopologyArg = forwarded.TopologyArg

        # Must come last: relies on `base` (via `name`) and on `TopologyArg`
        self.element_node_index = self._make_element_node_index()

    @property
    def name(self):
        """Base topology name, suffixed with the name of the geometry's ``field``"""
        base_name = self.base.name
        field_name = self.geometry.field.name
        return f"{base_name}_{field_name}"

    def _make_element_node_index(self):
        """Builds a device function delegating node indexing to the base topology"""

        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            elt_arg: self.geometry.CellArg,
            topo_arg: self.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            # The deformed cell argument wraps an inner `elt_arg`, presumably the base geometry's
            # cell argument, which is what the base topology expects.
            return self.base.element_node_index(elt_arg.elt_arg, topo_arg, element_index, node_index_in_elt)

        return element_node_index
def forward_base_topology(topology_class: Type[SpaceTopology], geometry: Geometry, *args, **kwargs) -> SpaceTopology:
    """
    Instantiates `topology_class` over `geometry`, transparently handling deformed geometries.

    If `geometry` is *not* a :class:`DeformedGeometry`, constructs a normal instance of `topology_class`
    over `geometry`, forwarding additional arguments.

    If `geometry` *is* a :class:`DeformedGeometry`, constructs an instance of `topology_class` over the
    base (undeformed) geometry of `geometry`. It then wraps that instance in a
    :class:`DeformedGeometrySpaceTopology`, which forwards calls to the underlying topology.
    """
    if not isinstance(geometry, DeformedGeometry):
        return topology_class(geometry, *args, **kwargs)

    undeformed_topology = topology_class(geometry.base, *args, **kwargs)
    return DeformedGeometrySpaceTopology(geometry, undeformed_topology)
|