GarmentCode / NvidiaWarp-GarmentCode /warp /fem /space /quadmesh_2d_function_space.py
qbhf2's picture
added NvidiaWarp and GarmentCode repos
66c9c8a
import warp as wp
from warp.fem.types import ElementIndex, Coords
from warp.fem.polynomial import Polynomial, is_closed
from warp.fem.geometry import Quadmesh2D
from warp.fem import cache
from .topology import SpaceTopology, DiscontinuousSpaceTopologyMixin, forward_base_topology
from .basis_space import ShapeBasisSpace, TraceBasisSpace
from .shape import ShapeFunction, ConstantShapeFunction
from .shape import (
SquareBipolynomialShapeFunctions,
SquareSerendipityShapeFunctions,
SquareNonConformingPolynomialShapeFunctions,
)
@wp.struct
class Quadmesh2DTopologyArg:
edge_vertex_indices: wp.array(dtype=wp.vec2i)
quad_edge_indices: wp.array2d(dtype=int)
vertex_count: int
edge_count: int
class Quadmesh2DSpaceTopology(SpaceTopology):
TopologyArg = Quadmesh2DTopologyArg
def __init__(self, mesh: Quadmesh2D, shape: ShapeFunction):
super().__init__(mesh, shape.NODES_PER_ELEMENT)
self._mesh = mesh
self._shape = shape
self._compute_quad_edge_indices()
@cache.cached_arg_value
def topo_arg_value(self, device):
arg = Quadmesh2DTopologyArg()
arg.quad_edge_indices = self._quad_edge_indices.to(device)
arg.edge_vertex_indices = self._mesh.edge_vertex_indices.to(device)
arg.vertex_count = self._mesh.vertex_count()
arg.edge_count = self._mesh.side_count()
return arg
def _compute_quad_edge_indices(self):
self._quad_edge_indices = wp.empty(
dtype=int, device=self._mesh.quad_vertex_indices.device, shape=(self._mesh.cell_count(), 4)
)
wp.launch(
kernel=Quadmesh2DSpaceTopology._compute_quad_edge_indices_kernel,
dim=self._mesh.edge_quad_indices.shape,
device=self._mesh.quad_vertex_indices.device,
inputs=[
self._mesh.edge_quad_indices,
self._mesh.edge_vertex_indices,
self._mesh.quad_vertex_indices,
self._quad_edge_indices,
],
)
@wp.func
def _find_edge_index_in_quad(
edge_vtx: wp.vec2i,
quad_vtx: wp.vec4i,
):
for k in range(3):
if (edge_vtx[0] == quad_vtx[k] and edge_vtx[1] == quad_vtx[k + 1]) or (
edge_vtx[1] == quad_vtx[k] and edge_vtx[0] == quad_vtx[k + 1]
):
return k
return 3
@wp.kernel
def _compute_quad_edge_indices_kernel(
edge_quad_indices: wp.array(dtype=wp.vec2i),
edge_vertex_indices: wp.array(dtype=wp.vec2i),
quad_vertex_indices: wp.array2d(dtype=int),
quad_edge_indices: wp.array2d(dtype=int),
):
e = wp.tid()
edge_vtx = edge_vertex_indices[e]
edge_quads = edge_quad_indices[e]
q0 = edge_quads[0]
q0_vtx = wp.vec4i(
quad_vertex_indices[q0, 0],
quad_vertex_indices[q0, 1],
quad_vertex_indices[q0, 2],
quad_vertex_indices[q0, 3],
)
q0_edge = Quadmesh2DSpaceTopology._find_edge_index_in_quad(edge_vtx, q0_vtx)
quad_edge_indices[q0, q0_edge] = e
q1 = edge_quads[1]
if q1 != q0:
t1_vtx = wp.vec4i(
quad_vertex_indices[q1, 0],
quad_vertex_indices[q1, 1],
quad_vertex_indices[q1, 2],
quad_vertex_indices[q1, 3],
)
t1_edge = Quadmesh2DSpaceTopology._find_edge_index_in_quad(edge_vtx, t1_vtx)
quad_edge_indices[q1, t1_edge] = e
class Quadmesh2DDiscontinuousSpaceTopology(
DiscontinuousSpaceTopologyMixin,
SpaceTopology,
):
def __init__(self, mesh: Quadmesh2D, shape: ShapeFunction):
super().__init__(mesh, shape.NODES_PER_ELEMENT)
class Quadmesh2DBasisSpace(ShapeBasisSpace):
def __init__(self, topology: Quadmesh2DSpaceTopology, shape: ShapeFunction):
super().__init__(topology, shape)
self._mesh: Quadmesh2D = topology.geometry
class Quadmesh2DPiecewiseConstantBasis(Quadmesh2DBasisSpace):
def __init__(self, mesh: Quadmesh2D):
shape = ConstantShapeFunction(mesh.reference_cell(), space_dimension=2)
topology = Quadmesh2DDiscontinuousSpaceTopology(mesh, shape)
super().__init__(shape=shape, topology=topology)
class Trace(TraceBasisSpace):
@wp.func
def _node_coords_in_element(
side_arg: Quadmesh2D.SideArg,
basis_arg: Quadmesh2DBasisSpace.BasisArg,
element_index: ElementIndex,
node_index_in_element: int,
):
return Coords(0.5, 0.0, 0.0)
def make_node_coords_in_element(self):
return self._node_coords_in_element
def trace(self):
return Quadmesh2DPiecewiseConstantBasis.Trace(self)
class Quadmesh2DBipolynomialSpaceTopology(Quadmesh2DSpaceTopology):
def __init__(self, mesh: Quadmesh2D, shape: SquareBipolynomialShapeFunctions):
super().__init__(mesh, shape)
self.element_node_index = self._make_element_node_index()
def node_count(self) -> int:
ORDER = self._shape.ORDER
INTERIOR_NODES_PER_SIDE = max(0, ORDER - 1)
INTERIOR_NODES_PER_CELL = INTERIOR_NODES_PER_SIDE**2
return (
self._mesh.vertex_count()
+ self._mesh.side_count() * INTERIOR_NODES_PER_SIDE
+ self._mesh.cell_count() * INTERIOR_NODES_PER_CELL
)
def _make_element_node_index(self):
ORDER = self._shape.ORDER
INTERIOR_NODES_PER_SIDE = wp.constant(max(0, ORDER - 1))
INTERIOR_NODES_PER_CELL = wp.constant(INTERIOR_NODES_PER_SIDE**2)
@cache.dynamic_func(suffix=self.name)
def element_node_index(
geo_arg: Quadmesh2D.CellArg,
topo_arg: Quadmesh2DTopologyArg,
element_index: ElementIndex,
node_index_in_elt: int,
):
node_i = node_index_in_elt // (ORDER + 1)
node_j = node_index_in_elt - (ORDER + 1) * node_i
# Vertices
if node_i == 0:
if node_j == 0:
return geo_arg.quad_vertex_indices[element_index, 0]
elif node_j == ORDER:
return geo_arg.quad_vertex_indices[element_index, 3]
# 3-0 edge
side_index = topo_arg.quad_edge_indices[element_index, 3]
local_vs = geo_arg.quad_vertex_indices[element_index, 3]
global_vs = topo_arg.edge_vertex_indices[side_index][0]
index_in_side = wp.select(local_vs == global_vs, ORDER - node_j, node_j) - 1
return topo_arg.vertex_count + (ORDER - 1) * side_index + index_in_side
elif node_i == ORDER:
if node_j == 0:
return geo_arg.quad_vertex_indices[element_index, 1]
elif node_j == ORDER:
return geo_arg.quad_vertex_indices[element_index, 2]
# 1-2 edge
side_index = topo_arg.quad_edge_indices[element_index, 1]
local_vs = geo_arg.quad_vertex_indices[element_index, 1]
global_vs = topo_arg.edge_vertex_indices[side_index][0]
index_in_side = wp.select(local_vs == global_vs, ORDER - node_j, node_j) - 1
return topo_arg.vertex_count + (ORDER - 1) * side_index + index_in_side
if node_j == 0:
# 0-1 edge
side_index = topo_arg.quad_edge_indices[element_index, 0]
local_vs = geo_arg.quad_vertex_indices[element_index, 0]
global_vs = topo_arg.edge_vertex_indices[side_index][0]
index_in_side = wp.select(local_vs == global_vs, node_i, ORDER - node_i) - 1
return topo_arg.vertex_count + (ORDER - 1) * side_index + index_in_side
elif node_j == ORDER:
# 2-3 edge
side_index = topo_arg.quad_edge_indices[element_index, 2]
local_vs = geo_arg.quad_vertex_indices[element_index, 2]
global_vs = topo_arg.edge_vertex_indices[side_index][0]
index_in_side = wp.select(local_vs == global_vs, node_i, ORDER - node_i) - 1
return topo_arg.vertex_count + (ORDER - 1) * side_index + index_in_side
return (
topo_arg.vertex_count
+ topo_arg.edge_count * INTERIOR_NODES_PER_SIDE
+ element_index * INTERIOR_NODES_PER_CELL
+ (node_i - 1) * INTERIOR_NODES_PER_SIDE
+ node_j
- 1
)
return element_node_index
class Quadmesh2DBipolynomialBasisSpace(Quadmesh2DBasisSpace):
def __init__(
self,
mesh: Quadmesh2D,
degree: int,
family: Polynomial,
):
if family is None:
family = Polynomial.LOBATTO_GAUSS_LEGENDRE
if not is_closed(family):
raise ValueError("A closed polynomial family is required to define a continuous function space")
shape = SquareBipolynomialShapeFunctions(degree, family=family)
topology = forward_base_topology(Quadmesh2DBipolynomialSpaceTopology, mesh, shape)
super().__init__(topology, shape)
class Quadmesh2DDGBipolynomialBasisSpace(Quadmesh2DBasisSpace):
def __init__(
self,
mesh: Quadmesh2D,
degree: int,
family: Polynomial,
):
if family is None:
family = Polynomial.LOBATTO_GAUSS_LEGENDRE
shape = SquareBipolynomialShapeFunctions(degree, family=family)
topology = Quadmesh2DDiscontinuousSpaceTopology(mesh, shape)
super().__init__(topology, shape)
class Quadmesh2DSerendipitySpaceTopology(Quadmesh2DSpaceTopology):
def __init__(self, grid: Quadmesh2D, shape: SquareSerendipityShapeFunctions):
super().__init__(grid, shape)
self.element_node_index = self._make_element_node_index()
def node_count(self) -> int:
return self.geometry.vertex_count() + (self._shape.ORDER - 1) * self.geometry.side_count()
def _make_element_node_index(self):
ORDER = self._shape.ORDER
SHAPE_TO_QUAD_IDX = wp.constant(wp.vec4i([0, 3, 1, 2]))
@cache.dynamic_func(suffix=self.name)
def element_node_index(
cell_arg: Quadmesh2D.CellArg,
topo_arg: Quadmesh2DSpaceTopology.TopologyArg,
element_index: ElementIndex,
node_index_in_elt: int,
):
node_type, type_index = self._shape.node_type_and_type_index(node_index_in_elt)
if node_type == SquareSerendipityShapeFunctions.VERTEX:
return cell_arg.quad_vertex_indices[element_index, SHAPE_TO_QUAD_IDX[type_index]]
side_offset, index_in_side = SquareSerendipityShapeFunctions.side_offset_and_index(type_index)
if node_type == SquareSerendipityShapeFunctions.EDGE_X:
if side_offset == 0:
side_start = 0
else:
side_start = 2
index_in_side = ORDER - 2 - index_in_side
else:
if side_offset == 0:
side_start = 3
index_in_side = ORDER - 2 - index_in_side
else:
side_start = 1
side_index = topo_arg.quad_edge_indices[element_index, side_start]
local_vs = cell_arg.quad_vertex_indices[element_index, side_start]
global_vs = topo_arg.edge_vertex_indices[side_index][0]
if local_vs != global_vs:
# Flip indexing direction
index_in_side = ORDER - 2 - index_in_side
return topo_arg.vertex_count + (ORDER - 1) * side_index + index_in_side
return element_node_index
class Quadmesh2DSerendipityBasisSpace(Quadmesh2DBasisSpace):
def __init__(
self,
mesh: Quadmesh2D,
degree: int,
family: Polynomial,
):
if family is None:
family = Polynomial.LOBATTO_GAUSS_LEGENDRE
shape = SquareSerendipityShapeFunctions(degree, family=family)
topology = forward_base_topology(Quadmesh2DSerendipitySpaceTopology, mesh, shape=shape)
super().__init__(topology=topology, shape=shape)
class Quadmesh2DDGSerendipityBasisSpace(Quadmesh2DBasisSpace):
def __init__(
self,
mesh: Quadmesh2D,
degree: int,
family: Polynomial,
):
if family is None:
family = Polynomial.LOBATTO_GAUSS_LEGENDRE
shape = SquareSerendipityShapeFunctions(degree, family=family)
topology = Quadmesh2DDiscontinuousSpaceTopology(mesh, shape=shape)
super().__init__(topology=topology, shape=shape)
class Quadmesh2DPolynomialBasisSpace(Quadmesh2DBasisSpace):
def __init__(
self,
mesh: Quadmesh2D,
degree: int,
):
shape = SquareNonConformingPolynomialShapeFunctions(degree)
topology = Quadmesh2DDiscontinuousSpaceTopology(mesh, shape)
super().__init__(topology, shape)