# qbhf2's picture
# added NvidiaWarp and GarmentCode repos
# 66c9c8a
from typing import Optional
import warp as wp
from warp.fem.types import ElementIndex, Coords, make_free_sample
from warp.fem.geometry import Geometry
from warp.fem.quadrature import Quadrature
from warp.fem import cache
from .topology import SpaceTopology, DiscontinuousSpaceTopology
from .shape import ShapeFunction
class BasisSpace:
    """Interface class for defining a scalar-valued basis over a geometry.

    A basis space makes it easy to define multiple function spaces sharing the same basis (and thus nodes) but with different valuation functions;
    however, it is not a required ingredient of a function space.

    This base class does not define ``name``; :meth:`node_positions` reads ``self.name`` to suffix
    the generated kernel, so concrete subclasses are expected to provide it.

    See also: :func:`make_polynomial_basis_space`, :func:`make_collocated_function_space`
    """

    @wp.struct
    class BasisArg:
        """Argument structure to be passed to device functions"""

        pass

    def __init__(self, topology: SpaceTopology):
        """Stores the node topology and mirrors its ``NODES_PER_ELEMENT`` constant."""
        self._topology = topology
        self.NODES_PER_ELEMENT = self._topology.NODES_PER_ELEMENT

    @property
    def topology(self) -> SpaceTopology:
        """Underlying topology of the basis space"""
        return self._topology

    @property
    def geometry(self) -> Geometry:
        """Underlying geometry of the basis space"""
        return self._topology.geometry

    def basis_arg_value(self, device) -> "BasisArg":
        """Value for the argument structure to be passed to device functions"""
        return BasisSpace.BasisArg()

    # Helpers for generating node positions

    def node_positions(self, out: Optional[wp.array] = None) -> wp.array:
        """Returns a temporary array containing the world position for each node

        Args:
            out: Optional destination array. It must have shape ``(node_count,)`` and a vector
                data type whose length is the geometry dimension. When ``None``, a new array is
                allocated with :func:`wp.empty`.

        Returns:
            The filled array (``out`` itself when it was provided).

        Raises:
            ValueError: If ``out`` does not have the expected shape or data type.
        """

        # Copied to a local so the kernel below captures a plain value
        NODES_PER_ELEMENT = self.NODES_PER_ELEMENT

        pos_type = cache.cached_vec_type(length=self.geometry.dimension, dtype=float)

        node_coords_in_element = self.make_node_coords_in_element()

        @cache.dynamic_kernel(suffix=self.name, kernel_options={"max_unroll": 4, "enable_backward": False})
        def fill_node_positions(
            geo_cell_arg: self.geometry.CellArg,
            basis_arg: self.BasisArg,
            topo_arg: self.topology.TopologyArg,
            node_positions: wp.array(dtype=pos_type),
        ):
            # One thread per cell; each thread writes the position of every node of its cell
            element_index = wp.tid()

            for n in range(NODES_PER_ELEMENT):
                node_index = self.topology.element_node_index(geo_cell_arg, topo_arg, element_index, n)
                coords = node_coords_in_element(geo_cell_arg, basis_arg, element_index, n)

                sample = make_free_sample(element_index, coords)
                pos = self.geometry.cell_position(geo_cell_arg, sample)

                node_positions[node_index] = pos

        shape = (self.topology.node_count(),)
        if out is None:
            node_positions = wp.empty(
                shape=shape,
                dtype=pos_type,
            )
        else:
            if out.shape != shape or not wp.types.types_equal(pos_type, out.dtype):
                raise ValueError(
                    f"Out node positions array must have shape {shape} and data type {wp.types.type_repr(pos_type)}"
                )
            node_positions = out

        # All kernel arguments are built for the output array's device, so the launch must target
        # that same device rather than whatever the current default device happens to be.
        wp.launch(
            dim=self.geometry.cell_count(),
            kernel=fill_node_positions,
            inputs=[
                self.geometry.cell_arg_value(device=node_positions.device),
                self.basis_arg_value(device=node_positions.device),
                self.topology.topo_arg_value(device=node_positions.device),
                node_positions,
            ],
            device=node_positions.device,
        )

        return node_positions

    def make_node_coords_in_element(self):
        """Factory for the per-node element coordinates function; must be overridden."""
        raise NotImplementedError()

    def make_node_quadrature_weight(self):
        """Factory for the per-node quadrature weight function; must be overridden."""
        raise NotImplementedError()

    def make_element_inner_weight(self):
        """Factory for the inner basis weight function; must be overridden."""
        raise NotImplementedError()

    def make_element_outer_weight(self):
        """Factory for the outer basis weight function; defaults to the inner one."""
        return self.make_element_inner_weight()

    def make_element_inner_weight_gradient(self):
        """Factory for the inner basis weight gradient function; must be overridden."""
        raise NotImplementedError()

    def make_element_outer_weight_gradient(self):
        """Factory for the outer basis weight gradient function; defaults to the inner one."""
        return self.make_element_inner_weight_gradient()

    def make_trace_node_quadrature_weight(self, trace_basis=None):
        """Factory for the node quadrature weight function of the trace space; must be overridden.

        Args:
            trace_basis: The :class:`TraceBasisSpace` requesting the function. It is accepted here
                (with a default, for backward compatibility) because ``TraceBasisSpace`` always
                passes itself, so an unimplemented basis reports ``NotImplementedError`` rather
                than a ``TypeError`` about the argument count.
        """
        raise NotImplementedError()

    def trace(self) -> "TraceBasisSpace":
        """Returns a new :class:`TraceBasisSpace` wrapping this basis."""
        return TraceBasisSpace(self)
class ShapeBasisSpace(BasisSpace):
    """Base class for defining shape-function-based basis spaces."""

    def __init__(self, topology: SpaceTopology, shape: ShapeFunction):
        """Binds a shape function to a topology and exposes the optional mesh-export helpers."""
        super().__init__(topology)

        self._shape = shape
        self.ORDER = self._shape.ORDER

        # (capability looked up on the shape, public attribute to publish, implementation)
        optional_exports = (
            ("element_node_triangulation", "node_triangulation", self._node_triangulation),
            ("element_node_tets", "node_tets", self._node_tets),
            ("element_node_hexes", "node_hexes", self._node_hexes),
        )
        for shape_attr, public_name, implementation in optional_exports:
            if hasattr(shape, shape_attr):
                setattr(self, public_name, implementation)

    @property
    def shape(self) -> ShapeFunction:
        """Shape functions used for defining individual element basis"""
        return self._shape

    @property
    def name(self):
        """Identifier built from the topology name and the shape function name."""
        return f"{self.topology.name}_{self._shape.name}"

    def make_node_coords_in_element(self):
        """Wraps the shape's node-coordinates function in the basis-space calling convention."""
        reference_node_coords = self._shape.make_node_coords_in_element()

        @cache.dynamic_func(suffix=self.name)
        def node_coords_in_element(
            elt_arg: self.geometry.CellArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            # Only the local node index is forwarded to the shape function
            return reference_node_coords(node_index_in_elt)

        return node_coords_in_element

    def make_node_quadrature_weight(self):
        """Wraps the shape's node quadrature weight function in the basis-space calling convention."""
        reference_node_weight = self._shape.make_node_quadrature_weight()

        @cache.dynamic_func(suffix=self.name)
        def node_quadrature_weight(
            elt_arg: self.geometry.CellArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            return reference_node_weight(node_index_in_elt)

        return node_quadrature_weight

    def make_element_inner_weight(self):
        """Wraps the shape's inner weight function in the basis-space calling convention."""
        reference_inner_weight = self._shape.make_element_inner_weight()

        @cache.dynamic_func(suffix=self.name)
        def element_inner_weight(
            elt_arg: self.geometry.CellArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
        ):
            return reference_inner_weight(coords, node_index_in_elt)

        return element_inner_weight

    def make_element_inner_weight_gradient(self):
        """Wraps the shape's inner weight gradient function in the basis-space calling convention."""
        reference_inner_weight_gradient = self._shape.make_element_inner_weight_gradient()

        @cache.dynamic_func(suffix=self.name)
        def element_inner_weight_gradient(
            elt_arg: self.geometry.CellArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
        ):
            return reference_inner_weight_gradient(coords, node_index_in_elt)

        return element_inner_weight_gradient

    def make_trace_node_quadrature_weight(self, trace_basis):
        """Builds the node quadrature weight function used by the trace of this basis.

        Args:
            trace_basis: Trace basis space whose geometry, argument type and topology are
                used by the generated function.
        """
        reference_trace_node_weight = self._shape.make_trace_node_quadrature_weight()

        @cache.dynamic_func(suffix=self.name)
        def trace_node_quadrature_weight(
            geo_side_arg: trace_basis.geometry.SideArg,
            basis_arg: trace_basis.BasisArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            # Map the side-local node to its index within the neighbouring cell
            neighbour_elem, index_in_neighbour = trace_basis.topology.neighbor_cell_index(
                geo_side_arg, element_index, node_index_in_elt
            )
            return reference_trace_node_weight(index_in_neighbour)

        return trace_node_quadrature_weight

    def _gather_sub_cells(self, shape_method_name, nodes_per_sub_cell):
        """Expands the shape's per-element sub-cell node table into global node indices.

        Args:
            shape_method_name: Name of the shape method returning the local sub-cell node table.
            nodes_per_sub_cell: Number of nodes per sub-cell; sets the output row width.
        """
        nodes_by_element = self._topology.element_node_indices().numpy()
        local_sub_cells = getattr(self._shape, shape_method_name)()
        return nodes_by_element[:, local_sub_cells].reshape(-1, nodes_per_sub_cell)

    def _node_triangulation(self):
        """Global node indices of the triangles subdividing every element, three per row."""
        return self._gather_sub_cells("element_node_triangulation", 3)

    def _node_tets(self):
        """Global node indices of the tetrahedra subdividing every element, four per row."""
        return self._gather_sub_cells("element_node_tets", 4)

    def _node_hexes(self):
        """Global node indices of the hexahedra subdividing every element, eight per row."""
        return self._gather_sub_cells("element_node_hexes", 8)
class TraceBasisSpace(BasisSpace):
    """Auto-generated trace space evaluating the cell-defined basis on the geometry sides"""

    def __init__(self, basis: BasisSpace):
        """Wraps ``basis``, using the trace of its topology and sharing its argument structure."""
        super().__init__(basis.topology.trace())

        self.ORDER = basis.ORDER

        self._basis = basis
        # The trace reuses the cell basis' device argument type and value verbatim
        self.BasisArg = self._basis.BasisArg
        self.basis_arg_value = self._basis.basis_arg_value

    @property
    def name(self):
        """Name of the wrapped basis with a ``_Trace`` suffix."""
        return f"{self._basis.name}_Trace"

    def make_node_coords_in_element(self):
        """Builds the function returning a trace node's coordinates within its side."""
        node_coords_in_cell = self._basis.make_node_coords_in_element()

        @cache.dynamic_func(suffix=self._basis.name)
        def trace_node_coords_in_element(
            geo_side_arg: self.geometry.SideArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            # Find the cell owning this node, fetch the node's coordinates in that cell,
            # then convert those cell coordinates to side coordinates.
            neighbour_elem, index_in_neighbour = self.topology.neighbor_cell_index(
                geo_side_arg, element_index, node_index_in_elt
            )
            geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
            neighbour_coords = node_coords_in_cell(
                geo_cell_arg,
                basis_arg,
                neighbour_elem,
                index_in_neighbour,
            )

            return self.geometry.side_from_cell_coords(geo_side_arg, element_index, neighbour_elem, neighbour_coords)

        return trace_node_coords_in_element

    def make_node_quadrature_weight(self):
        """Delegates to the wrapped basis, passing this trace space as context."""
        return self._basis.make_trace_node_quadrature_weight(self)

    def make_element_inner_weight(self):
        """Builds the function evaluating the wrapped basis' inner weight from a side sample."""
        cell_inner_weight = self._basis.make_element_inner_weight()

        @cache.dynamic_func(suffix=self._basis.name)
        def trace_element_inner_weight(
            geo_side_arg: self.geometry.SideArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
        ):
            cell_index, index_in_cell = self.topology.inner_cell_index(geo_side_arg, element_index, node_index_in_elt)
            # A negative local index is treated as "no contribution from the inner cell"
            if index_in_cell < 0:
                return 0.0

            cell_coords = self.geometry.side_inner_cell_coords(geo_side_arg, element_index, coords)

            geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
            return cell_inner_weight(
                geo_cell_arg,
                basis_arg,
                cell_index,
                cell_coords,
                index_in_cell,
            )

        return trace_element_inner_weight

    def make_element_outer_weight(self):
        """Builds the function evaluating the wrapped basis' outer weight from a side sample."""
        cell_outer_weight = self._basis.make_element_outer_weight()

        @cache.dynamic_func(suffix=self._basis.name)
        def trace_element_outer_weight(
            geo_side_arg: self.geometry.SideArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
        ):
            cell_index, index_in_cell = self.topology.outer_cell_index(geo_side_arg, element_index, node_index_in_elt)
            # A negative local index is treated as "no contribution from the outer cell"
            if index_in_cell < 0:
                return 0.0

            cell_coords = self.geometry.side_outer_cell_coords(geo_side_arg, element_index, coords)

            geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
            return cell_outer_weight(
                geo_cell_arg,
                basis_arg,
                cell_index,
                cell_coords,
                index_in_cell,
            )

        return trace_element_outer_weight

    def make_element_inner_weight_gradient(self):
        """Builds the function evaluating the wrapped basis' inner weight gradient from a side sample."""
        cell_inner_weight_gradient = self._basis.make_element_inner_weight_gradient()
        grad_vec_type = wp.vec(length=self.geometry.dimension, dtype=float)

        @cache.dynamic_func(suffix=self._basis.name)
        def trace_element_inner_weight_gradient(
            geo_side_arg: self.geometry.SideArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
        ):
            cell_index, index_in_cell = self.topology.inner_cell_index(geo_side_arg, element_index, node_index_in_elt)
            # A negative local index yields a zero gradient
            if index_in_cell < 0:
                return grad_vec_type(0.0)

            cell_coords = self.geometry.side_inner_cell_coords(geo_side_arg, element_index, coords)
            geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
            return cell_inner_weight_gradient(geo_cell_arg, basis_arg, cell_index, cell_coords, index_in_cell)

        return trace_element_inner_weight_gradient

    def make_element_outer_weight_gradient(self):
        """Builds the function evaluating the wrapped basis' outer weight gradient from a side sample."""
        cell_outer_weight_gradient = self._basis.make_element_outer_weight_gradient()
        grad_vec_type = wp.vec(length=self.geometry.dimension, dtype=float)

        @cache.dynamic_func(suffix=self._basis.name)
        def trace_element_outer_weight_gradient(
            geo_side_arg: self.geometry.SideArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
        ):
            cell_index, index_in_cell = self.topology.outer_cell_index(geo_side_arg, element_index, node_index_in_elt)
            # A negative local index yields a zero gradient
            if index_in_cell < 0:
                return grad_vec_type(0.0)

            cell_coords = self.geometry.side_outer_cell_coords(geo_side_arg, element_index, coords)
            geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
            return cell_outer_weight_gradient(geo_cell_arg, basis_arg, cell_index, cell_coords, index_in_cell)

        return trace_element_outer_weight_gradient

    def __eq__(self, other: "TraceBasisSpace") -> bool:
        """Two trace spaces are equal when they wrap equal cell bases.

        The comparison uses ``_basis``, the only state that distinguishes one trace space from
        another; this class never defines a ``_topo`` attribute, so reading it would raise
        ``AttributeError``.
        """
        if not isinstance(other, TraceBasisSpace):
            return NotImplemented
        return self._basis == other._basis

    def __hash__(self) -> int:
        """Hash consistent with :meth:`__eq__` (defining ``__eq__`` alone disables hashing)."""
        return hash((TraceBasisSpace, self._basis))
class PointBasisSpace(BasisSpace):
    """An unstructured :class:`BasisSpace` that is non-zero at a finite set of points only.

    The node locations and nodal quadrature weights are defined by a :class:`Quadrature` formula.
    """

    def __init__(self, quadrature: Quadrature):
        """Creates one node per quadrature point, on a discontinuous topology.

        Args:
            quadrature: Quadrature formula providing the node locations and weights.

        Raises:
            NotImplementedError: If the quadrature does not report a fixed number of points
                per element (``points_per_element()`` returns ``None``).
        """
        self._quadrature = quadrature

        points_per_element = quadrature.points_per_element()
        if points_per_element is None:
            raise NotImplementedError("Varying number of points per element is not supported yet")

        topology = DiscontinuousSpaceTopology(
            geometry=quadrature.domain.geometry, nodes_per_element=points_per_element
        )
        super().__init__(topology)

        # The quadrature's device argument doubles as the basis argument
        self.BasisArg = quadrature.Arg
        self.basis_arg_value = quadrature.arg_value
        self.ORDER = 0

        # Outer quantities are the same as the inner ones for this basis
        self.make_element_outer_weight = self.make_element_inner_weight
        self.make_element_outer_weight_gradient = self.make_element_inner_weight_gradient

    @property
    def name(self):
        """Name of the quadrature with a ``_Point`` suffix."""
        return f"{self._quadrature.name}_Point"

    def make_node_coords_in_element(self):
        """Builds the function returning a node's coordinates, i.e. its quadrature point coordinates."""

        @cache.dynamic_func(suffix=self.name)
        def node_coords_in_element(
            elt_arg: self._quadrature.domain.ElementArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            return self._quadrature.point_coords(elt_arg, basis_arg, element_index, node_index_in_elt)

        return node_coords_in_element

    def make_node_quadrature_weight(self):
        """Builds the function returning a node's weight, i.e. its quadrature point weight."""

        @cache.dynamic_func(suffix=self.name)
        def node_quadrature_weight(
            elt_arg: self._quadrature.domain.ElementArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            return self._quadrature.point_weight(elt_arg, basis_arg, element_index, node_index_in_elt)

        return node_quadrature_weight

    def make_element_inner_weight(self):
        """Builds the basis weight function, an indicator of proximity to the node's point."""

        @cache.dynamic_func(suffix=self.name)
        def element_inner_weight(
            elt_arg: self._quadrature.domain.ElementArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
        ):
            qp_coord = self._quadrature.point_coords(elt_arg, basis_arg, element_index, node_index_in_elt)
            # 0.001 is a tolerance on the squared distance in element coordinates.
            # wp.select(cond, a, b) yields `a` when cond is false and `b` when true, so the
            # weight is 1.0 within the tolerance and 0.0 otherwise.
            return wp.select(wp.length_sq(coords - qp_coord) < 0.001, 0.0, 1.0)

        return element_inner_weight

    def make_element_inner_weight_gradient(self):
        """Builds the basis weight gradient function, which always returns a zero vector."""
        gradient_vec = cache.cached_vec_type(length=self.geometry.dimension, dtype=float)

        @cache.dynamic_func(suffix=self.name)
        def element_inner_weight_gradient(
            elt_arg: self._quadrature.domain.ElementArg,
            basis_arg: self.BasisArg,
            element_index: ElementIndex,
            coords: Coords,
            node_index_in_elt: int,
        ):
            return gradient_vec(0.0)

        return element_inner_weight_gradient

    def make_trace_node_quadrature_weight(self, trace_basis):
        """Builds the trace node quadrature weight function, which always returns zero.

        Args:
            trace_basis: Trace basis space providing the side argument and basis argument types.
        """

        @cache.dynamic_func(suffix=self.name)
        def trace_node_quadrature_weight(
            elt_arg: trace_basis.geometry.SideArg,
            basis_arg: trace_basis.BasisArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            return 0.0

        return trace_node_quadrature_weight