GarmentCode / NvidiaWarp-GarmentCode /warp /fem /space /trimesh_2d_function_space.py
qbhf2's picture
added NvidiaWarp and GarmentCode repos
66c9c8a
import warp as wp
from warp.fem.types import ElementIndex, Coords
from warp.fem.geometry import Trimesh2D
from warp.fem import cache
from .topology import SpaceTopology, DiscontinuousSpaceTopologyMixin, forward_base_topology
from .basis_space import ShapeBasisSpace, TraceBasisSpace
from .shape import ShapeFunction, ConstantShapeFunction
from .shape import Triangle2DPolynomialShapeFunctions, Triangle2DNonConformingPolynomialShapeFunctions
@wp.struct
class Trimesh2DTopologyArg:
edge_vertex_indices: wp.array(dtype=wp.vec2i)
tri_edge_indices: wp.array2d(dtype=int)
vertex_count: int
edge_count: int
class Trimesh2DSpaceTopology(SpaceTopology):
TopologyArg = Trimesh2DTopologyArg
def __init__(self, mesh: Trimesh2D, shape: ShapeFunction):
super().__init__(mesh, shape.NODES_PER_ELEMENT)
self._mesh = mesh
self._shape = shape
self._compute_tri_edge_indices()
@cache.cached_arg_value
def topo_arg_value(self, device):
arg = Trimesh2DTopologyArg()
arg.tri_edge_indices = self._tri_edge_indices.to(device)
arg.edge_vertex_indices = self._mesh.edge_vertex_indices.to(device)
arg.vertex_count = self._mesh.vertex_count()
arg.edge_count = self._mesh.side_count()
return arg
def _compute_tri_edge_indices(self):
self._tri_edge_indices = wp.empty(
dtype=int, device=self._mesh.tri_vertex_indices.device, shape=(self._mesh.cell_count(), 3)
)
wp.launch(
kernel=Trimesh2DSpaceTopology._compute_tri_edge_indices_kernel,
dim=self._mesh.edge_tri_indices.shape,
device=self._mesh.tri_vertex_indices.device,
inputs=[
self._mesh.edge_tri_indices,
self._mesh.edge_vertex_indices,
self._mesh.tri_vertex_indices,
self._tri_edge_indices,
],
)
@wp.func
def _find_edge_index_in_tri(
edge_vtx: wp.vec2i,
tri_vtx: wp.vec3i,
):
for k in range(2):
if (edge_vtx[0] == tri_vtx[k] and edge_vtx[1] == tri_vtx[k + 1]) or (
edge_vtx[1] == tri_vtx[k] and edge_vtx[0] == tri_vtx[k + 1]
):
return k
return 2
@wp.kernel
def _compute_tri_edge_indices_kernel(
edge_tri_indices: wp.array(dtype=wp.vec2i),
edge_vertex_indices: wp.array(dtype=wp.vec2i),
tri_vertex_indices: wp.array2d(dtype=int),
tri_edge_indices: wp.array2d(dtype=int),
):
e = wp.tid()
edge_vtx = edge_vertex_indices[e]
edge_tris = edge_tri_indices[e]
t0 = edge_tris[0]
t0_vtx = wp.vec3i(tri_vertex_indices[t0, 0], tri_vertex_indices[t0, 1], tri_vertex_indices[t0, 2])
t0_edge = Trimesh2DSpaceTopology._find_edge_index_in_tri(edge_vtx, t0_vtx)
tri_edge_indices[t0, t0_edge] = e
t1 = edge_tris[1]
if t1 != t0:
t1_vtx = wp.vec3i(tri_vertex_indices[t1, 0], tri_vertex_indices[t1, 1], tri_vertex_indices[t1, 2])
t1_edge = Trimesh2DSpaceTopology._find_edge_index_in_tri(edge_vtx, t1_vtx)
tri_edge_indices[t1, t1_edge] = e
class Trimesh2DDiscontinuousSpaceTopology(
DiscontinuousSpaceTopologyMixin,
SpaceTopology,
):
def __init__(self, mesh: Trimesh2D, shape: ShapeFunction):
super().__init__(mesh, shape.NODES_PER_ELEMENT)
class Trimesh2DBasisSpace(ShapeBasisSpace):
def __init__(self, topology: Trimesh2DSpaceTopology, shape: ShapeFunction):
super().__init__(topology, shape)
self._mesh: Trimesh2D = topology.geometry
class Trimesh2DPiecewiseConstantBasis(Trimesh2DBasisSpace):
def __init__(self, mesh: Trimesh2D):
shape = ConstantShapeFunction(mesh.reference_cell(), space_dimension=2)
topology = Trimesh2DDiscontinuousSpaceTopology(mesh, shape)
super().__init__(shape=shape, topology=topology)
class Trace(TraceBasisSpace):
@wp.func
def _node_coords_in_element(
side_arg: Trimesh2D.SideArg,
basis_arg: Trimesh2DBasisSpace.BasisArg,
element_index: ElementIndex,
node_index_in_element: int,
):
return Coords(0.5, 0.0, 0.0)
def make_node_coords_in_element(self):
return self._node_coords_in_element
def trace(self):
return Trimesh2DPiecewiseConstantBasis.Trace(self)
class Trimesh2DPolynomialSpaceTopology(Trimesh2DSpaceTopology):
def __init__(self, mesh: Trimesh2D, shape: Triangle2DPolynomialShapeFunctions):
super().__init__(mesh, shape)
self.element_node_index = self._make_element_node_index()
def node_count(self) -> int:
INTERIOR_NODES_PER_SIDE = max(0, self._shape.ORDER - 1)
INTERIOR_NODES_PER_CELL = max(0, self._shape.ORDER - 2) * max(0, self._shape.ORDER - 1) // 2
return (
self._mesh.vertex_count()
+ self._mesh.side_count() * INTERIOR_NODES_PER_SIDE
+ self._mesh.cell_count() * INTERIOR_NODES_PER_CELL
)
def _make_element_node_index(self):
INTERIOR_NODES_PER_SIDE = wp.constant(max(0, self._shape.ORDER - 1))
INTERIOR_NODES_PER_CELL = wp.constant(max(0, self._shape.ORDER - 2) * max(0, self._shape.ORDER - 1) // 2)
@cache.dynamic_func(suffix=self.name)
def element_node_index(
geo_arg: Trimesh2D.CellArg,
topo_arg: Trimesh2DTopologyArg,
element_index: ElementIndex,
node_index_in_elt: int,
):
node_type, type_index = self._shape.node_type_and_type_index(node_index_in_elt)
if node_type == Triangle2DPolynomialShapeFunctions.VERTEX:
return geo_arg.tri_vertex_indices[element_index][type_index]
global_offset = topo_arg.vertex_count
if node_type == Triangle2DPolynomialShapeFunctions.EDGE:
edge = type_index // INTERIOR_NODES_PER_SIDE
edge_node = type_index - INTERIOR_NODES_PER_SIDE * edge
global_edge_index = topo_arg.tri_edge_indices[element_index][edge]
if (
topo_arg.edge_vertex_indices[global_edge_index][0]
!= geo_arg.tri_vertex_indices[element_index][edge]
):
edge_node = INTERIOR_NODES_PER_SIDE - 1 - edge_node
return global_offset + INTERIOR_NODES_PER_SIDE * global_edge_index + edge_node
global_offset += INTERIOR_NODES_PER_SIDE * topo_arg.edge_count
return global_offset + INTERIOR_NODES_PER_CELL * element_index + type_index
return element_node_index
class Trimesh2DPolynomialBasisSpace(Trimesh2DBasisSpace):
def __init__(
self,
mesh: Trimesh2D,
degree: int,
):
shape = Triangle2DPolynomialShapeFunctions(degree)
topology = forward_base_topology(Trimesh2DPolynomialSpaceTopology, mesh, shape)
super().__init__(topology, shape)
class Trimesh2DDGPolynomialBasisSpace(Trimesh2DBasisSpace):
def __init__(
self,
mesh: Trimesh2D,
degree: int,
):
shape = Triangle2DPolynomialShapeFunctions(degree)
topology = Trimesh2DDiscontinuousSpaceTopology(mesh, shape)
super().__init__(topology, shape)
class Trimesh2DNonConformingPolynomialBasisSpace(Trimesh2DBasisSpace):
def __init__(
self,
mesh: Trimesh2D,
degree: int,
):
shape = Triangle2DNonConformingPolynomialShapeFunctions(degree)
topology = Trimesh2DDiscontinuousSpaceTopology(mesh, shape)
super().__init__(topology, shape)