Spaces:
Sleeping
Sleeping
| import warp as wp | |
| from warp.fem.types import ElementIndex, Coords | |
| from warp.fem.geometry import Tetmesh | |
| from warp.fem import cache | |
| from .topology import SpaceTopology, DiscontinuousSpaceTopologyMixin, forward_base_topology | |
| from .basis_space import ShapeBasisSpace, TraceBasisSpace | |
| from .shape import ShapeFunction, ConstantShapeFunction | |
| from .shape import TetrahedronPolynomialShapeFunctions, TetrahedronNonConformingPolynomialShapeFunctions | |
@wp.struct
class TetmeshTopologyArg:
    """Topology data handed to the Warp functions of tetmesh-based function spaces.

    Instances are assembled by ``TetmeshSpaceTopology.topo_arg_value``.
    """

    # Global edge index of each local tet edge; an empty (0, 0) array when the topology does not need edges
    tet_edge_indices: wp.array2d(dtype=int)
    # Global face index of each local tet face, shape (cell_count, 4); an empty (0, 0) array when not needed
    tet_face_indices: wp.array2d(dtype=int)
    # The three vertex indices of each mesh face
    face_vertex_indices: wp.array(dtype=wp.vec3i)

    vertex_count: int
    edge_count: int
    face_count: int
class TetmeshSpaceTopology(SpaceTopology):
    """Base topology for function spaces defined on a tetrahedral mesh.

    Optionally gathers the per-tet global edge and face indices required by
    spaces that attach nodes to edges or faces.
    """

    TopologyArg = TetmeshTopologyArg

    def __init__(
        self,
        mesh: Tetmesh,
        shape: ShapeFunction,
        need_tet_edge_indices: bool = True,
        need_tet_face_indices: bool = True,
    ):
        """
        Args:
            mesh: the tetrahedral mesh geometry
            shape: shape function; its ``NODES_PER_ELEMENT`` is forwarded to ``SpaceTopology``
            need_tet_edge_indices: whether to fetch the mesh's per-tet edge indices and edge count
            need_tet_face_indices: whether to compute the per-tet face indices
        """
        super().__init__(mesh, shape.NODES_PER_ELEMENT)
        self._mesh = mesh
        self._shape = shape

        if need_tet_edge_indices:
            self._tet_edge_indices = self._mesh.tet_edge_indices
            self._edge_count = self._mesh.edge_count()
        else:
            # Empty placeholder so topo_arg_value can always fill the struct field
            self._tet_edge_indices = wp.empty(shape=(0, 0), dtype=int)
            self._edge_count = 0

        if need_tet_face_indices:
            self._compute_tet_face_indices()
        else:
            self._tet_face_indices = wp.empty(shape=(0, 0), dtype=int)

    @cache.cached_arg_value
    def topo_arg_value(self, device):
        """Build a ``TetmeshTopologyArg`` whose arrays live on ``device``."""
        arg = TetmeshTopologyArg()
        arg.tet_face_indices = self._tet_face_indices.to(device)
        arg.tet_edge_indices = self._tet_edge_indices.to(device)
        arg.face_vertex_indices = self._mesh.face_vertex_indices.to(device)

        arg.vertex_count = self._mesh.vertex_count()
        arg.face_count = self._mesh.side_count()
        arg.edge_count = self._edge_count
        return arg

    def _compute_tet_face_indices(self):
        """Fill ``self._tet_face_indices`` (shape ``(cell_count, 4)``) with the global index of each tet face."""
        device = self._mesh.tet_vertex_indices.device
        face_tet_indices = self._mesh.face_tet_indices

        self._tet_face_indices = wp.empty(dtype=int, device=device, shape=(self._mesh.cell_count(), 4))

        # One thread per mesh face; each thread writes its face index into the matching slot of its adjacent tet(s)
        wp.launch(
            kernel=TetmeshSpaceTopology._compute_tet_face_indices_kernel,
            dim=face_tet_indices.shape,
            device=device,
            inputs=[
                face_tet_indices,
                self._mesh.face_vertex_indices,
                self._mesh.tet_vertex_indices,
                self._tet_face_indices,
            ],
        )

    @wp.func
    def _find_face_index_in_tet(
        face_vtx: wp.vec3i,
        tet_vtx: wp.vec4i,
    ):
        # Local face k of a tet is made of its vertices (k, k+1, k+2) mod 4.
        # Returns the local k whose vertex set matches face_vtx.
        for k in range(3):
            tvk = wp.vec3i(tet_vtx[k], tet_vtx[(k + 1) % 4], tet_vtx[(k + 2) % 4])

            # Use fact that face always start with min vertex
            min_t = wp.min(tvk)
            max_t = wp.max(tvk)
            mid_t = tvk[0] + tvk[1] + tvk[2] - min_t - max_t

            if min_t == face_vtx[0] and (
                (face_vtx[2] == max_t and face_vtx[1] == mid_t) or (face_vtx[1] == max_t and face_vtx[2] == mid_t)
            ):
                return k

        # Only remaining candidate: vertices (3, 0, 1)
        return 3

    @wp.kernel
    def _compute_tet_face_indices_kernel(
        face_tet_indices: wp.array(dtype=wp.vec2i),
        face_vertex_indices: wp.array(dtype=wp.vec3i),
        tet_vertex_indices: wp.array2d(dtype=int),
        tet_face_indices: wp.array2d(dtype=int),
    ):
        # One thread per mesh face
        e = wp.tid()

        face_vtx = face_vertex_indices[e]
        face_tets = face_tet_indices[e]

        t0 = face_tets[0]
        t0_vtx = wp.vec4i(
            tet_vertex_indices[t0, 0], tet_vertex_indices[t0, 1], tet_vertex_indices[t0, 2], tet_vertex_indices[t0, 3]
        )
        t0_face = TetmeshSpaceTopology._find_face_index_in_tet(face_vtx, t0_vtx)
        tet_face_indices[t0, t0_face] = e

        # Second adjacent tet; presumably equal to t0 for boundary faces, in which case it is skipped
        t1 = face_tets[1]
        if t1 != t0:
            t1_vtx = wp.vec4i(
                tet_vertex_indices[t1, 0],
                tet_vertex_indices[t1, 1],
                tet_vertex_indices[t1, 2],
                tet_vertex_indices[t1, 3],
            )
            t1_face = TetmeshSpaceTopology._find_face_index_in_tet(face_vtx, t1_vtx)
            tet_face_indices[t1, t1_face] = e
class TetmeshDiscontinuousSpaceTopology(
    DiscontinuousSpaceTopologyMixin,
    SpaceTopology,
):
    """Discontinuous topology on a Tetmesh.

    Node numbering is presumably provided by ``DiscontinuousSpaceTopologyMixin``;
    this class only forwards the mesh and the shape's nodes-per-element count.
    """

    def __init__(self, mesh: Tetmesh, shape: ShapeFunction):
        nodes_per_element = shape.NODES_PER_ELEMENT
        super().__init__(mesh, nodes_per_element)
class TetmeshBasisSpace(ShapeBasisSpace):
    """Common base for shape-function basis spaces on a Tetmesh; keeps a handle to the mesh."""

    def __init__(self, topology: TetmeshSpaceTopology, shape: ShapeFunction):
        super().__init__(topology, shape)
        # The topology's geometry is the tetrahedral mesh the space lives on
        mesh: Tetmesh = topology.geometry
        self._mesh = mesh
class TetmeshPiecewiseConstantBasis(TetmeshBasisSpace):
    """Piecewise-constant basis on a Tetmesh.

    Combines a ``ConstantShapeFunction`` with the discontinuous tetmesh topology.
    """

    def __init__(self, mesh: Tetmesh):
        shape = ConstantShapeFunction(mesh.reference_cell(), space_dimension=3)
        topology = TetmeshDiscontinuousSpaceTopology(mesh, shape)
        super().__init__(shape=shape, topology=topology)

    class Trace(TraceBasisSpace):
        """Restriction of the piecewise-constant basis to mesh sides."""

        @wp.func
        def _node_coords_in_element(
            side_arg: Tetmesh.SideArg,
            basis_arg: TetmeshBasisSpace.BasisArg,
            element_index: ElementIndex,
            node_index_in_element: int,
        ):
            # Same coordinates for every node; presumably the barycentric centroid of the triangular side
            return Coords(1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0)

        def make_node_coords_in_element(self):
            return self._node_coords_in_element

    def trace(self):
        """Return the trace space of this basis on mesh sides."""
        return TetmeshPiecewiseConstantBasis.Trace(self)
class TetmeshPolynomialSpaceTopology(TetmeshSpaceTopology):
    """Node-sharing topology for polynomial spaces on a Tetmesh.

    Global node numbering: all mesh vertices first, then interior edge nodes
    (grouped by edge), then interior face nodes (grouped by face), then interior
    cell nodes (grouped by cell).
    """

    def __init__(self, mesh: Tetmesh, shape: TetrahedronPolynomialShapeFunctions):
        # Edge nodes exist from order 2 on, face nodes from order 3 on
        super().__init__(mesh, shape, need_tet_edge_indices=shape.ORDER >= 2, need_tet_face_indices=shape.ORDER >= 3)

        self.element_node_index = self._make_element_node_index()

    @staticmethod
    def _interior_node_counts(order: int):
        """Return ``(per_edge, per_face, per_cell)``: node counts strictly inside an edge, face and cell."""
        per_edge = max(0, order - 1)
        per_face = max(0, order - 2) * max(0, order - 1) // 2
        per_cell = max(0, order - 3) * max(0, order - 2) * max(0, order - 1) // 6
        return per_edge, per_face, per_cell

    def node_count(self) -> int:
        """Total number of global nodes in the space."""
        per_edge, per_face, per_cell = self._interior_node_counts(self._shape.ORDER)
        # Use the same edge count that topo_arg_value exposes, so this total matches the
        # global offsets computed in element_node_index (it is 0 only when per_edge is 0)
        return (
            self._mesh.vertex_count()
            + self._edge_count * per_edge
            + self._mesh.side_count() * per_face
            + self._mesh.cell_count() * per_cell
        )

    def _make_element_node_index(self):
        """Generate the Warp function mapping (element, local node) to a global node index."""
        per_edge, per_face, per_cell = self._interior_node_counts(self._shape.ORDER)
        # Compile-time constants captured by the generated Warp function
        INTERIOR_NODES_PER_EDGE = wp.constant(per_edge)
        INTERIOR_NODES_PER_FACE = wp.constant(per_face)
        INTERIOR_NODES_PER_CELL = wp.constant(per_cell)

        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            geo_arg: Tetmesh.CellArg,
            topo_arg: TetmeshTopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            node_type, type_index = self._shape.node_type_and_type_index(node_index_in_elt)

            if node_type == TetrahedronPolynomialShapeFunctions.VERTEX:
                return geo_arg.tet_vertex_indices[element_index][type_index]

            global_offset = topo_arg.vertex_count

            if node_type == TetrahedronPolynomialShapeFunctions.EDGE:
                edge = type_index // INTERIOR_NODES_PER_EDGE
                edge_node = type_index - INTERIOR_NODES_PER_EDGE * edge

                global_edge_index = topo_arg.tet_edge_indices[element_index][edge]

                # Test if we need to swap edge direction.
                # Local edges 0-2 join vertices (e, e+1 mod 3); edges 3-5 join vertex e-3 with vertex 3.
                # The local node order is reversed when the first endpoint has the larger global index.
                if INTERIOR_NODES_PER_EDGE > 1:
                    if edge < 3:
                        c1 = edge
                        c2 = (edge + 1) % 3
                    else:
                        c1 = edge - 3
                        c2 = 3

                    if geo_arg.tet_vertex_indices[element_index][c1] > geo_arg.tet_vertex_indices[element_index][c2]:
                        edge_node = INTERIOR_NODES_PER_EDGE - 1 - edge_node

                return global_offset + INTERIOR_NODES_PER_EDGE * global_edge_index + edge_node

            global_offset += INTERIOR_NODES_PER_EDGE * topo_arg.edge_count

            if node_type == TetrahedronPolynomialShapeFunctions.FACE:
                face = type_index // INTERIOR_NODES_PER_FACE
                face_node = type_index - INTERIOR_NODES_PER_FACE * face

                global_face_index = topo_arg.tet_face_indices[element_index][face]

                if INTERIOR_NODES_PER_FACE == 3:
                    # Hard code for P4 case, 3 nodes per face
                    # Higher orders would require rotating triangle coordinates, this is not supported yet
                    # Each face node is matched to the face vertex equal to tet vertex (face + face_node) % 4
                    vidx = geo_arg.tet_vertex_indices[element_index][(face + face_node) % 4]
                    fvi = topo_arg.face_vertex_indices[global_face_index]

                    if vidx == fvi[0]:
                        face_node = 0
                    elif vidx == fvi[1]:
                        face_node = 1
                    else:
                        face_node = 2

                return global_offset + INTERIOR_NODES_PER_FACE * global_face_index + face_node

            global_offset += INTERIOR_NODES_PER_FACE * topo_arg.face_count

            # Remaining case: interior cell nodes
            return global_offset + INTERIOR_NODES_PER_CELL * element_index + type_index

        return element_node_index
class TetmeshPolynomialBasisSpace(TetmeshBasisSpace):
    """Polynomial basis of a given degree on a Tetmesh, with nodes shared between neighboring tets."""

    def __init__(
        self,
        mesh: Tetmesh,
        degree: int,
    ):
        polynomial_shape = TetrahedronPolynomialShapeFunctions(degree)
        super().__init__(
            topology=forward_base_topology(TetmeshPolynomialSpaceTopology, mesh, polynomial_shape),
            shape=polynomial_shape,
        )
class TetmeshDGPolynomialBasisSpace(TetmeshBasisSpace):
    """Discontinuous polynomial basis on a Tetmesh.

    Uses the standard tetrahedron polynomial shape functions with the discontinuous
    topology, so nodes are presumably not shared between elements.
    """

    def __init__(
        self,
        mesh: Tetmesh,
        degree: int,
    ):
        polynomial_shape = TetrahedronPolynomialShapeFunctions(degree)
        super().__init__(
            topology=TetmeshDiscontinuousSpaceTopology(mesh, polynomial_shape),
            shape=polynomial_shape,
        )
class TetmeshNonConformingPolynomialBasisSpace(TetmeshBasisSpace):
    """Polynomial basis on a Tetmesh built from the non-conforming tetrahedron shape functions.

    Paired with the discontinuous topology.
    """

    def __init__(
        self,
        mesh: Tetmesh,
        degree: int,
    ):
        nonconforming_shape = TetrahedronNonConformingPolynomialShapeFunctions(degree)
        super().__init__(
            topology=TetmeshDiscontinuousSpaceTopology(mesh, nonconforming_shape),
            shape=nonconforming_shape,
        )