qbhf2's picture
added NvidiaWarp and GarmentCode repos
66c9c8a
from typing import Optional
import warp as wp
from warp.fem.cache import (
TemporaryStore,
borrow_temporary,
borrow_temporary_like,
cached_arg_value,
)
from warp.fem.types import OUTSIDE, Coords, ElementIndex, Sample, make_free_sample
from .element import Cube, Square
from .geometry import Geometry
@wp.struct
class HexmeshCellArg:
hex_vertex_indices: wp.array2d(dtype=int)
positions: wp.array(dtype=wp.vec3)
# for neighbor cell lookup
vertex_hex_offsets: wp.array(dtype=int)
vertex_hex_indices: wp.array(dtype=int)
@wp.struct
class HexmeshSideArg:
cell_arg: HexmeshCellArg
face_vertex_indices: wp.array(dtype=wp.vec4i)
face_hex_indices: wp.array(dtype=wp.vec2i)
face_hex_face_orientation: wp.array(dtype=wp.vec4i)
_mat32 = wp.mat(shape=(3, 2), dtype=float)
FACE_VERTEX_INDICES = wp.constant(
wp.mat(shape=(6, 4), dtype=int)(
[
[0, 4, 7, 3], # x = 0
[1, 2, 6, 5], # x = 1
[0, 1, 5, 4], # y = 0
[3, 7, 6, 2], # y = 1
[0, 3, 2, 1], # z = 0
[4, 5, 6, 7], # z = 1
]
)
)
EDGE_VERTEX_INDICES = wp.constant(
wp.mat(shape=(12, 2), dtype=int)(
[
[0, 1],
[1, 2],
[3, 2],
[0, 3],
[4, 5],
[5, 6],
[7, 6],
[4, 7],
[0, 4],
[1, 5],
[2, 6],
[3, 7],
]
)
)
# orthogal transform for face coordinates given first vertex + winding
# (two rows per entry)
FACE_ORIENTATION = [
[1, 0], # FV: 0, det: +
[0, 1],
[0, 1], # FV: 0, det: -
[1, 0],
[0, -1], # FV: 1, det: +
[1, 0],
[-1, 0], # FV: 1, det: -
[0, 1],
[-1, 0], # FV: 2, det: +
[0, -1],
[0, -1], # FV: 2, det: -
[-1, 0],
[0, 1], # FV: 3, det: +
[-1, 0],
[1, 0], # FV: 3, det: -
[0, -1],
]
FACE_TRANSLATION = [
[0, 0],
[1, 0],
[1, 1],
[0, 1],
]
# local face coordinate system
_FACE_COORD_INDICES = wp.constant(
wp.mat(shape=(6, 4), dtype=int)(
[
[2, 1, 0, 0], # 0: z y -x
[1, 2, 0, 1], # 1: y z x-1
[0, 2, 1, 0], # 2: x z -y
[2, 0, 1, 1], # 3: z x y-1
[1, 0, 2, 0], # 4: y x -z
[0, 1, 2, 1], # 5: x y z-1
]
)
)
_FACE_ORIENTATION_F = wp.constant(wp.mat(shape=(16, 2), dtype=float)(FACE_ORIENTATION))
_FACE_TRANSLATION_F = wp.constant(wp.mat(shape=(4, 2), dtype=float)(FACE_TRANSLATION))
class Hexmesh(Geometry):
"""Hexahedral mesh geometry"""
dimension = 3
def __init__(
self, hex_vertex_indices: wp.array, positions: wp.array, temporary_store: Optional[TemporaryStore] = None
):
"""
Constructs a tetrahedral mesh.
Args:
hex_vertex_indices: warp array of shape (num_hexes, 8) containing vertex indices for each hex
following standard ordering (bottom face vertices in counter-clockwise order, then similarly for upper face)
positions: warp array of shape (num_vertices, 3) containing 3d position for each vertex
temporary_store: shared pool from which to allocate temporary arrays
"""
self.hex_vertex_indices = hex_vertex_indices
self.positions = positions
self._face_vertex_indices: wp.array = None
self._face_hex_indices: wp.array = None
self._face_hex_face_orientation: wp.array = None
self._vertex_hex_offsets: wp.array = None
self._vertex_hex_indices: wp.array = None
self._hex_edge_indices: wp.array = None
self._edge_count = 0
self._build_topology(temporary_store)
def cell_count(self):
return self.hex_vertex_indices.shape[0]
def vertex_count(self):
return self.positions.shape[0]
def side_count(self):
return self._face_vertex_indices.shape[0]
def edge_count(self):
if self._hex_edge_indices is None:
self._compute_hex_edges()
return self._edge_count
def boundary_side_count(self):
return self._boundary_face_indices.shape[0]
def reference_cell(self) -> Cube:
return Cube()
def reference_side(self) -> Square:
return Square()
@property
def hex_edge_indices(self) -> wp.array:
if self._hex_edge_indices is None:
self._compute_hex_edges()
return self._hex_edge_indices
@property
def face_hex_indices(self) -> wp.array:
return self._face_hex_indices
@property
def face_vertex_indices(self) -> wp.array:
return self._face_vertex_indices
CellArg = HexmeshCellArg
SideArg = HexmeshSideArg
@wp.struct
class SideIndexArg:
boundary_face_indices: wp.array(dtype=int)
# Geometry device interface
@cached_arg_value
def cell_arg_value(self, device) -> CellArg:
args = self.CellArg()
args.hex_vertex_indices = self.hex_vertex_indices.to(device)
args.positions = self.positions.to(device)
args.vertex_hex_offsets = self._vertex_hex_offsets.to(device)
args.vertex_hex_indices = self._vertex_hex_indices.to(device)
return args
@wp.func
def cell_position(args: CellArg, s: Sample):
hex_idx = args.hex_vertex_indices[s.element_index]
w_p = s.element_coords
w_m = Coords(1.0) - s.element_coords
# 0 : m m m
# 1 : p m m
# 2 : p p m
# 3 : m p m
# 4 : m m p
# 5 : p m p
# 6 : p p p
# 7 : m p p
return (
w_m[0] * w_m[1] * w_m[2] * args.positions[hex_idx[0]]
+ w_p[0] * w_m[1] * w_m[2] * args.positions[hex_idx[1]]
+ w_p[0] * w_p[1] * w_m[2] * args.positions[hex_idx[2]]
+ w_m[0] * w_p[1] * w_m[2] * args.positions[hex_idx[3]]
+ w_m[0] * w_m[1] * w_p[2] * args.positions[hex_idx[4]]
+ w_p[0] * w_m[1] * w_p[2] * args.positions[hex_idx[5]]
+ w_p[0] * w_p[1] * w_p[2] * args.positions[hex_idx[6]]
+ w_m[0] * w_p[1] * w_p[2] * args.positions[hex_idx[7]]
)
@wp.func
def cell_deformation_gradient(cell_arg: CellArg, s: Sample):
"""Deformation gradient at `coords`"""
"""Transposed deformation gradient at `coords`"""
hex_idx = cell_arg.hex_vertex_indices[s.element_index]
w_p = s.element_coords
w_m = Coords(1.0) - s.element_coords
return (
wp.outer(cell_arg.positions[hex_idx[0]], wp.vec3(-w_m[1] * w_m[2], -w_m[0] * w_m[2], -w_m[0] * w_m[1]))
+ wp.outer(cell_arg.positions[hex_idx[1]], wp.vec3(w_m[1] * w_m[2], -w_p[0] * w_m[2], -w_p[0] * w_m[1]))
+ wp.outer(cell_arg.positions[hex_idx[2]], wp.vec3(w_p[1] * w_m[2], w_p[0] * w_m[2], -w_p[0] * w_p[1]))
+ wp.outer(cell_arg.positions[hex_idx[3]], wp.vec3(-w_p[1] * w_m[2], w_m[0] * w_m[2], -w_m[0] * w_p[1]))
+ wp.outer(cell_arg.positions[hex_idx[4]], wp.vec3(-w_m[1] * w_p[2], -w_m[0] * w_p[2], w_m[0] * w_m[1]))
+ wp.outer(cell_arg.positions[hex_idx[5]], wp.vec3(w_m[1] * w_p[2], -w_p[0] * w_p[2], w_p[0] * w_m[1]))
+ wp.outer(cell_arg.positions[hex_idx[6]], wp.vec3(w_p[1] * w_p[2], w_p[0] * w_p[2], w_p[0] * w_p[1]))
+ wp.outer(cell_arg.positions[hex_idx[7]], wp.vec3(-w_p[1] * w_p[2], w_m[0] * w_p[2], w_m[0] * w_p[1]))
)
@wp.func
def cell_inverse_deformation_gradient(cell_arg: CellArg, s: Sample):
return wp.inverse(Hexmesh.cell_deformation_gradient(cell_arg, s))
@wp.func
def cell_measure(args: CellArg, s: Sample):
return wp.abs(wp.determinant(Hexmesh.cell_deformation_gradient(args, s)))
@wp.func
def cell_normal(args: CellArg, s: Sample):
return wp.vec3(0.0)
@cached_arg_value
def side_index_arg_value(self, device) -> SideIndexArg:
args = self.SideIndexArg()
args.boundary_face_indices = self._boundary_face_indices.to(device)
return args
@wp.func
def boundary_side_index(args: SideIndexArg, boundary_side_index: int):
"""Boundary side to side index"""
return args.boundary_face_indices[boundary_side_index]
@cached_arg_value
def side_arg_value(self, device) -> CellArg:
args = self.SideArg()
args.cell_arg = self.cell_arg_value(device)
args.face_vertex_indices = self._face_vertex_indices.to(device)
args.face_hex_indices = self._face_hex_indices.to(device)
args.face_hex_face_orientation = self._face_hex_face_orientation.to(device)
return args
@wp.func
def side_position(args: SideArg, s: Sample):
face_idx = args.face_vertex_indices[s.element_index]
w_p = s.element_coords
w_m = Coords(1.0) - s.element_coords
return (
w_m[0] * w_m[1] * args.cell_arg.positions[face_idx[0]]
+ w_p[0] * w_m[1] * args.cell_arg.positions[face_idx[1]]
+ w_p[0] * w_p[1] * args.cell_arg.positions[face_idx[2]]
+ w_m[0] * w_p[1] * args.cell_arg.positions[face_idx[3]]
)
@wp.func
def _side_deformation_vecs(args: SideArg, side_index: ElementIndex, coords: Coords):
face_idx = args.face_vertex_indices[side_index]
p0 = args.cell_arg.positions[face_idx[0]]
p1 = args.cell_arg.positions[face_idx[1]]
p2 = args.cell_arg.positions[face_idx[2]]
p3 = args.cell_arg.positions[face_idx[3]]
w_p = coords
w_m = Coords(1.0) - coords
v1 = w_m[1] * (p1 - p0) + w_p[1] * (p2 - p3)
v2 = w_p[0] * (p2 - p1) + w_m[0] * (p3 - p0)
return v1, v2
@wp.func
def side_deformation_gradient(args: SideArg, s:Sample):
"""Transposed side deformation gradient at `coords`"""
v1, v2 = Hexmesh._side_deformation_vecs(args, s.element_index, s.element_coords)
return _mat32(v1, v2)
@wp.func
def side_inner_inverse_deformation_gradient(args: SideArg, s:Sample):
cell_index = Hexmesh.side_inner_cell_index(args, s.element_index)
cell_coords = Hexmesh.side_inner_cell_coords(args, s.element_index, s.element_coords)
return Hexmesh.cell_inverse_deformation_gradient(args.cell_arg, make_free_sample(cell_index, cell_coords))
@wp.func
def side_outer_inverse_deformation_gradient(args: SideArg, s:Sample):
cell_index = Hexmesh.side_outer_cell_index(args, s.element_index)
cell_coords = Hexmesh.side_outer_cell_coords(args, s.element_index, s.element_coords)
return Hexmesh.cell_inverse_deformation_gradient(args.cell_arg, make_free_sample(cell_index, cell_coords))
@wp.func
def side_measure(args: SideArg, s: Sample):
v1, v2 = Hexmesh._side_deformation_vecs(args, s.element_index, s.element_coords)
return wp.length(wp.cross(v1, v2))
@wp.func
def side_measure_ratio(args: SideArg, s: Sample):
inner = Hexmesh.side_inner_cell_index(args, s.element_index)
outer = Hexmesh.side_outer_cell_index(args, s.element_index)
inner_coords = Hexmesh.side_inner_cell_coords(args, s.element_index, s.element_coords)
outer_coords = Hexmesh.side_outer_cell_coords(args, s.element_index, s.element_coords)
return Hexmesh.side_measure(args, s) / wp.min(
Hexmesh.cell_measure(args.cell_arg, make_free_sample(inner, inner_coords)),
Hexmesh.cell_measure(args.cell_arg, make_free_sample(outer, outer_coords)),
)
@wp.func
def side_normal(args: SideArg, s: Sample):
v1, v2 = Hexmesh._side_deformation_vecs(args, s.element_index, s.element_coords)
return wp.normalize(wp.cross(v1, v2))
@wp.func
def side_inner_cell_index(arg: SideArg, side_index: ElementIndex):
return arg.face_hex_indices[side_index][0]
@wp.func
def side_outer_cell_index(arg: SideArg, side_index: ElementIndex):
return arg.face_hex_indices[side_index][1]
@wp.func
def _hex_local_face_coords(hex_coords: Coords, face_index: int):
# Coordinatex in local face coordinates system
# Sign of last coordinate (out of face)
face_coords = wp.vec2(
hex_coords[_FACE_COORD_INDICES[face_index, 0]], hex_coords[_FACE_COORD_INDICES[face_index, 1]]
)
normal_coord = hex_coords[_FACE_COORD_INDICES[face_index, 2]]
normal_coord = wp.select(_FACE_COORD_INDICES[face_index, 3] == 0, normal_coord - 1.0, -normal_coord)
return face_coords, normal_coord
@wp.func
def _local_face_hex_coords(face_coords: wp.vec2, face_index: int):
# Coordinates in hex from local face coordinates system
hex_coords = Coords()
hex_coords[_FACE_COORD_INDICES[face_index, 0]] = face_coords[0]
hex_coords[_FACE_COORD_INDICES[face_index, 1]] = face_coords[1]
hex_coords[_FACE_COORD_INDICES[face_index, 2]] = wp.select(_FACE_COORD_INDICES[face_index, 3] == 0, 1.0, 0.0)
return hex_coords
@wp.func
def _local_from_oriented_face_coords(ori: int, oriented_coords: Coords):
fv = ori // 2
return (oriented_coords[0] - _FACE_TRANSLATION_F[fv, 0]) * _FACE_ORIENTATION_F[2 * ori] + (
oriented_coords[1] - _FACE_TRANSLATION_F[fv, 1]
) * _FACE_ORIENTATION_F[2 * ori + 1]
@wp.func
def _local_to_oriented_face_coords(ori: int, coords: wp.vec2):
fv = ori // 2
return Coords(
wp.dot(_FACE_ORIENTATION_F[2 * ori], coords) + _FACE_TRANSLATION_F[fv, 0],
wp.dot(_FACE_ORIENTATION_F[2 * ori + 1], coords) + _FACE_TRANSLATION_F[fv, 1],
0.0,
)
@wp.func
def face_to_hex_coords(local_face_index: int, face_orientation: int, side_coords: Coords):
local_coords = Hexmesh._local_from_oriented_face_coords(face_orientation, side_coords)
return Hexmesh._local_face_hex_coords(local_coords, local_face_index)
@wp.func
def side_inner_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
local_face_index = args.face_hex_face_orientation[side_index][0]
face_orientation = args.face_hex_face_orientation[side_index][1]
return Hexmesh.face_to_hex_coords(local_face_index, face_orientation, side_coords)
@wp.func
def side_outer_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
local_face_index = args.face_hex_face_orientation[side_index][2]
face_orientation = args.face_hex_face_orientation[side_index][3]
return Hexmesh.face_to_hex_coords(local_face_index, face_orientation, side_coords)
@wp.func
def side_from_cell_coords(args: SideArg, side_index: ElementIndex, hex_index: ElementIndex, hex_coords: Coords):
if Hexmesh.side_inner_cell_index(args, side_index) == hex_index:
local_face_index = args.face_hex_face_orientation[side_index][0]
face_orientation = args.face_hex_face_orientation[side_index][1]
else:
local_face_index = args.face_hex_face_orientation[side_index][2]
face_orientation = args.face_hex_face_orientation[side_index][3]
face_coords, normal_coord = Hexmesh._hex_local_face_coords(hex_coords, local_face_index)
return wp.select(
normal_coord == 0.0, Coords(OUTSIDE), Hexmesh._local_to_oriented_face_coords(face_orientation, face_coords)
)
@wp.func
def side_to_cell_arg(side_arg: SideArg):
return side_arg.cell_arg
def _build_topology(self, temporary_store: TemporaryStore):
from warp.fem.utils import compress_node_indices, masked_indices
from warp.utils import array_scan
device = self.hex_vertex_indices.device
vertex_hex_offsets, vertex_hex_indices, _, __ = compress_node_indices(
self.vertex_count(), self.hex_vertex_indices, temporary_store=temporary_store
)
self._vertex_hex_offsets = vertex_hex_offsets.detach()
self._vertex_hex_indices = vertex_hex_indices.detach()
vertex_start_face_count = borrow_temporary(temporary_store, dtype=int, device=device, shape=self.vertex_count())
vertex_start_face_count.array.zero_()
vertex_start_face_offsets = borrow_temporary_like(vertex_start_face_count, temporary_store=temporary_store)
vertex_face_other_vs = borrow_temporary(
temporary_store, dtype=wp.vec3i, device=device, shape=(8 * self.cell_count())
)
vertex_face_hexes = borrow_temporary(
temporary_store, dtype=int, device=device, shape=(8 * self.cell_count(), 2)
)
# Count face edges starting at each vertex
wp.launch(
kernel=Hexmesh._count_starting_faces_kernel,
device=device,
dim=self.cell_count(),
inputs=[self.hex_vertex_indices, vertex_start_face_count.array],
)
array_scan(in_array=vertex_start_face_count.array, out_array=vertex_start_face_offsets.array, inclusive=False)
# Count number of unique edges (deduplicate across faces)
vertex_unique_face_count = vertex_start_face_count
wp.launch(
kernel=Hexmesh._count_unique_starting_faces_kernel,
device=device,
dim=self.vertex_count(),
inputs=[
self._vertex_hex_offsets,
self._vertex_hex_indices,
self.hex_vertex_indices,
vertex_start_face_offsets.array,
vertex_unique_face_count.array,
vertex_face_other_vs.array,
vertex_face_hexes.array,
],
)
vertex_unique_face_offsets = borrow_temporary_like(vertex_start_face_offsets, temporary_store=temporary_store)
array_scan(in_array=vertex_start_face_count.array, out_array=vertex_unique_face_offsets.array, inclusive=False)
# Get back edge count to host
if device.is_cuda:
face_count = borrow_temporary(temporary_store, shape=(1,), dtype=int, device="cpu", pinned=True)
# Last vertex will not own any edge, so its count will be zero; just fetching last prefix count is ok
wp.copy(
dest=face_count.array, src=vertex_unique_face_offsets.array, src_offset=self.vertex_count() - 1, count=1
)
wp.synchronize_stream(wp.get_stream(device))
face_count = int(face_count.array.numpy()[0])
else:
face_count = int(vertex_unique_face_offsets.array.numpy()[self.vertex_count() - 1])
self._face_vertex_indices = wp.empty(shape=(face_count,), dtype=wp.vec4i, device=device)
self._face_hex_indices = wp.empty(shape=(face_count,), dtype=wp.vec2i, device=device)
self._face_hex_face_orientation = wp.empty(shape=(face_count,), dtype=wp.vec4i, device=device)
boundary_mask = borrow_temporary(temporary_store, shape=(face_count,), dtype=int, device=device)
# Compress edge data
wp.launch(
kernel=Hexmesh._compress_faces_kernel,
device=device,
dim=self.vertex_count(),
inputs=[
vertex_start_face_offsets.array,
vertex_unique_face_offsets.array,
vertex_unique_face_count.array,
vertex_face_other_vs.array,
vertex_face_hexes.array,
self._face_vertex_indices,
self._face_hex_indices,
boundary_mask.array,
],
)
vertex_start_face_offsets.release()
vertex_unique_face_offsets.release()
vertex_unique_face_count.release()
vertex_face_other_vs.release()
vertex_face_hexes.release()
# Flip normals if necessary
wp.launch(
kernel=Hexmesh._flip_face_normals,
device=device,
dim=self.side_count(),
inputs=[self._face_vertex_indices, self._face_hex_indices, self.hex_vertex_indices, self.positions],
)
# Compute and store face orientation
wp.launch(
kernel=Hexmesh._compute_face_orientation,
device=device,
dim=self.side_count(),
inputs=[
self._face_vertex_indices,
self._face_hex_indices,
self.hex_vertex_indices,
self._face_hex_face_orientation,
],
)
boundary_face_indices, _ = masked_indices(boundary_mask.array)
self._boundary_face_indices = boundary_face_indices.detach()
def _compute_hex_edges(self, temporary_store: Optional[TemporaryStore] = None):
from warp.utils import array_scan
device = self.hex_vertex_indices.device
vertex_start_edge_count = borrow_temporary(temporary_store, dtype=int, device=device, shape=self.vertex_count())
vertex_start_edge_count.array.zero_()
vertex_start_edge_offsets = borrow_temporary_like(vertex_start_edge_count, temporary_store=temporary_store)
vertex_edge_ends = borrow_temporary(temporary_store, dtype=int, device=device, shape=(12 * self.cell_count()))
# Count face edges starting at each vertex
wp.launch(
kernel=Hexmesh._count_starting_edges_kernel,
device=device,
dim=self.cell_count(),
inputs=[self.hex_vertex_indices, vertex_start_edge_count.array],
)
array_scan(in_array=vertex_start_edge_count.array, out_array=vertex_start_edge_offsets.array, inclusive=False)
# Count number of unique edges (deduplicate across faces)
vertex_unique_edge_count = vertex_start_edge_count
wp.launch(
kernel=Hexmesh._count_unique_starting_edges_kernel,
device=device,
dim=self.vertex_count(),
inputs=[
self._vertex_hex_offsets,
self._vertex_hex_indices,
self.hex_vertex_indices,
vertex_start_edge_offsets.array,
vertex_unique_edge_count.array,
vertex_edge_ends.array,
],
)
vertex_unique_edge_offsets = borrow_temporary_like(
vertex_start_edge_offsets.array, temporary_store=temporary_store
)
array_scan(in_array=vertex_start_edge_count.array, out_array=vertex_unique_edge_offsets.array, inclusive=False)
# Get back edge count to host
if device.is_cuda:
edge_count = borrow_temporary(temporary_store, shape=(1,), dtype=int, device="cpu", pinned=True)
# Last vertex will not own any edge, so its count will be zero; just fetching last prefix count is ok
wp.copy(
dest=edge_count.array,
src=vertex_unique_edge_offsets.array,
src_offset=self.vertex_count() - 1,
count=1,
)
wp.synchronize_stream(wp.get_stream(device))
self._edge_count = int(edge_count.array.numpy()[0])
else:
self._edge_count = int(vertex_unique_edge_offsets.array.numpy()[self.vertex_count() - 1])
self._hex_edge_indices = wp.empty(
dtype=int, device=self.hex_vertex_indices.device, shape=(self.cell_count(), 12)
)
# Compress edge data
wp.launch(
kernel=Hexmesh._compress_edges_kernel,
device=device,
dim=self.vertex_count(),
inputs=[
self._vertex_hex_offsets,
self._vertex_hex_indices,
self.hex_vertex_indices,
vertex_start_edge_offsets.array,
vertex_unique_edge_offsets.array,
vertex_unique_edge_count.array,
vertex_edge_ends.array,
self._hex_edge_indices,
],
)
vertex_start_edge_offsets.release()
vertex_unique_edge_offsets.release()
vertex_unique_edge_count.release()
vertex_edge_ends.release()
@wp.kernel
def _count_starting_faces_kernel(
hex_vertex_indices: wp.array2d(dtype=int), vertex_start_face_count: wp.array(dtype=int)
):
t = wp.tid()
for k in range(6):
vi = wp.vec4i(
hex_vertex_indices[t, FACE_VERTEX_INDICES[k, 0]],
hex_vertex_indices[t, FACE_VERTEX_INDICES[k, 1]],
hex_vertex_indices[t, FACE_VERTEX_INDICES[k, 2]],
hex_vertex_indices[t, FACE_VERTEX_INDICES[k, 3]],
)
vm = wp.min(vi)
for i in range(4):
if vm == vi[i]:
wp.atomic_add(vertex_start_face_count, vm, 1)
@wp.func
def _face_sort(vidx: wp.vec4i, min_k: int):
v1 = vidx[(min_k + 1) % 4]
v2 = vidx[(min_k + 2) % 4]
v3 = vidx[(min_k + 3) % 4]
if v1 < v3:
return wp.vec3i(v1, v2, v3)
return wp.vec3i(v3, v2, v1)
@wp.func
def _find_face(
needle: wp.vec3i,
values: wp.array(dtype=wp.vec3i),
beg: int,
end: int,
):
for i in range(beg, end):
if values[i] == needle:
return i
return -1
@wp.kernel
def _count_unique_starting_faces_kernel(
vertex_hex_offsets: wp.array(dtype=int),
vertex_hex_indices: wp.array(dtype=int),
hex_vertex_indices: wp.array2d(dtype=int),
vertex_start_face_offsets: wp.array(dtype=int),
vertex_start_face_count: wp.array(dtype=int),
face_other_vs: wp.array(dtype=wp.vec3i),
face_hexes: wp.array2d(dtype=int),
):
v = wp.tid()
face_beg = vertex_start_face_offsets[v]
hex_beg = vertex_hex_offsets[v]
hex_end = vertex_hex_offsets[v + 1]
face_cur = face_beg
for hexa in range(hex_beg, hex_end):
hx = vertex_hex_indices[hexa]
for k in range(6):
vi = wp.vec4i(
hex_vertex_indices[hx, FACE_VERTEX_INDICES[k, 0]],
hex_vertex_indices[hx, FACE_VERTEX_INDICES[k, 1]],
hex_vertex_indices[hx, FACE_VERTEX_INDICES[k, 2]],
hex_vertex_indices[hx, FACE_VERTEX_INDICES[k, 3]],
)
min_i = int(wp.argmin(vi))
if v == vi[min_i]:
other_v = Hexmesh._face_sort(vi, min_i)
# Check if other_v has been seen
seen_idx = Hexmesh._find_face(other_v, face_other_vs, face_beg, face_cur)
if seen_idx == -1:
face_other_vs[face_cur] = other_v
face_hexes[face_cur, 0] = hx
face_hexes[face_cur, 1] = hx
face_cur += 1
else:
face_hexes[seen_idx, 1] = hx
vertex_start_face_count[v] = face_cur - face_beg
@wp.kernel
def _compress_faces_kernel(
vertex_start_face_offsets: wp.array(dtype=int),
vertex_unique_face_offsets: wp.array(dtype=int),
vertex_unique_face_count: wp.array(dtype=int),
uncompressed_face_other_vs: wp.array(dtype=wp.vec3i),
uncompressed_face_hexes: wp.array2d(dtype=int),
face_vertex_indices: wp.array(dtype=wp.vec4i),
face_hex_indices: wp.array(dtype=wp.vec2i),
boundary_mask: wp.array(dtype=int),
):
v = wp.tid()
start_beg = vertex_start_face_offsets[v]
unique_beg = vertex_unique_face_offsets[v]
unique_count = vertex_unique_face_count[v]
for f in range(unique_count):
src_index = start_beg + f
face_index = unique_beg + f
face_vertex_indices[face_index] = wp.vec4i(
v,
uncompressed_face_other_vs[src_index][0],
uncompressed_face_other_vs[src_index][1],
uncompressed_face_other_vs[src_index][2],
)
hx0 = uncompressed_face_hexes[src_index, 0]
hx1 = uncompressed_face_hexes[src_index, 1]
face_hex_indices[face_index] = wp.vec2i(hx0, hx1)
if hx0 == hx1:
boundary_mask[face_index] = 1
else:
boundary_mask[face_index] = 0
@wp.kernel
def _flip_face_normals(
face_vertex_indices: wp.array(dtype=wp.vec4i),
face_hex_indices: wp.array(dtype=wp.vec2i),
hex_vertex_indices: wp.array2d(dtype=int),
positions: wp.array(dtype=wp.vec3),
):
f = wp.tid()
hexa = face_hex_indices[f][0]
hex_vidx = hex_vertex_indices[hexa]
face_vidx = face_vertex_indices[f]
hex_centroid = (
positions[hex_vidx[0]]
+ positions[hex_vidx[1]]
+ positions[hex_vidx[2]]
+ positions[hex_vidx[3]]
+ positions[hex_vidx[4]]
+ positions[hex_vidx[5]]
+ positions[hex_vidx[6]]
+ positions[hex_vidx[7]]
) / 8.0
v0 = positions[face_vidx[0]]
v1 = positions[face_vidx[1]]
v2 = positions[face_vidx[2]]
v3 = positions[face_vidx[3]]
face_center = (v1 + v0 + v2 + v3) / 4.0
face_normal = wp.cross(v2 - v0, v3 - v1)
# if face normal points toward first tet centroid, flip indices
if wp.dot(hex_centroid - face_center, face_normal) > 0.0:
face_vertex_indices[f] = wp.vec4i(face_vidx[0], face_vidx[3], face_vidx[2], face_vidx[1])
@wp.func
def _find_face_orientation(face_vidx: wp.vec4i, hex_index: int, hex_vertex_indices: wp.array2d(dtype=int)):
hex_vidx = hex_vertex_indices[hex_index]
# Find local index in hex corresponding to face
face_min_i = int(wp.argmin(face_vidx))
face_other_v = Hexmesh._face_sort(face_vidx, face_min_i)
for k in range(6):
hex_face_vi = wp.vec4i(
hex_vidx[FACE_VERTEX_INDICES[k, 0]],
hex_vidx[FACE_VERTEX_INDICES[k, 1]],
hex_vidx[FACE_VERTEX_INDICES[k, 2]],
hex_vidx[FACE_VERTEX_INDICES[k, 3]],
)
hex_min_i = int(wp.argmin(hex_face_vi))
hex_other_v = Hexmesh._face_sort(hex_face_vi, hex_min_i)
if hex_other_v == face_other_v:
local_face_index = k
break
# Find starting vertex index
for k in range(4):
if face_vidx[k] == hex_face_vi[0]:
face_orientation = 2 * k
if face_vidx[(k + 1) % 4] != hex_face_vi[1]:
face_orientation += 1
return local_face_index, face_orientation
@wp.kernel
def _compute_face_orientation(
face_vertex_indices: wp.array(dtype=wp.vec4i),
face_hex_indices: wp.array(dtype=wp.vec2i),
hex_vertex_indices: wp.array2d(dtype=int),
face_hex_face_ori: wp.array(dtype=wp.vec4i),
):
f = wp.tid()
face_vidx = face_vertex_indices[f]
hx0 = face_hex_indices[f][0]
local_face_0, ori_0 = Hexmesh._find_face_orientation(face_vidx, hx0, hex_vertex_indices)
hx1 = face_hex_indices[f][1]
if hx0 == hx1:
face_hex_face_ori[f] = wp.vec4i(local_face_0, ori_0, local_face_0, ori_0)
else:
local_face_1, ori_1 = Hexmesh._find_face_orientation(face_vidx, hx1, hex_vertex_indices)
face_hex_face_ori[f] = wp.vec4i(local_face_0, ori_0, local_face_1, ori_1)
@wp.kernel
def _count_starting_edges_kernel(
hex_vertex_indices: wp.array2d(dtype=int), vertex_start_edge_count: wp.array(dtype=int)
):
t = wp.tid()
for k in range(12):
v0 = hex_vertex_indices[t, EDGE_VERTEX_INDICES[k, 0]]
v1 = hex_vertex_indices[t, EDGE_VERTEX_INDICES[k, 1]]
if v0 < v1:
wp.atomic_add(vertex_start_edge_count, v0, 1)
else:
wp.atomic_add(vertex_start_edge_count, v1, 1)
@wp.func
def _find_edge(
needle: int,
values: wp.array(dtype=int),
beg: int,
end: int,
):
for i in range(beg, end):
if values[i] == needle:
return i
return -1
@wp.kernel
def _count_unique_starting_edges_kernel(
vertex_hex_offsets: wp.array(dtype=int),
vertex_hex_indices: wp.array(dtype=int),
hex_vertex_indices: wp.array2d(dtype=int),
vertex_start_edge_offsets: wp.array(dtype=int),
vertex_start_edge_count: wp.array(dtype=int),
edge_ends: wp.array(dtype=int),
):
v = wp.tid()
edge_beg = vertex_start_edge_offsets[v]
hex_beg = vertex_hex_offsets[v]
hex_end = vertex_hex_offsets[v + 1]
edge_cur = edge_beg
for tet in range(hex_beg, hex_end):
t = vertex_hex_indices[tet]
for k in range(12):
v0 = hex_vertex_indices[t, EDGE_VERTEX_INDICES[k, 0]]
v1 = hex_vertex_indices[t, EDGE_VERTEX_INDICES[k, 1]]
if v == wp.min(v0, v1):
other_v = wp.max(v0, v1)
if Hexmesh._find_edge(other_v, edge_ends, edge_beg, edge_cur) == -1:
edge_ends[edge_cur] = other_v
edge_cur += 1
vertex_start_edge_count[v] = edge_cur - edge_beg
@wp.kernel
def _compress_edges_kernel(
vertex_hex_offsets: wp.array(dtype=int),
vertex_hex_indices: wp.array(dtype=int),
hex_vertex_indices: wp.array2d(dtype=int),
vertex_start_edge_offsets: wp.array(dtype=int),
vertex_unique_edge_offsets: wp.array(dtype=int),
vertex_unique_edge_count: wp.array(dtype=int),
uncompressed_edge_ends: wp.array(dtype=int),
hex_edge_indices: wp.array2d(dtype=int),
):
v = wp.tid()
uncompressed_beg = vertex_start_edge_offsets[v]
unique_beg = vertex_unique_edge_offsets[v]
unique_count = vertex_unique_edge_count[v]
hex_beg = vertex_hex_offsets[v]
hex_end = vertex_hex_offsets[v + 1]
for tet in range(hex_beg, hex_end):
t = vertex_hex_indices[tet]
for k in range(12):
v0 = hex_vertex_indices[t, EDGE_VERTEX_INDICES[k, 0]]
v1 = hex_vertex_indices[t, EDGE_VERTEX_INDICES[k, 1]]
if v == wp.min(v0, v1):
other_v = wp.max(v0, v1)
edge_id = (
Hexmesh._find_edge(
other_v, uncompressed_edge_ends, uncompressed_beg, uncompressed_beg + unique_count
)
- uncompressed_beg
+ unique_beg
)
hex_edge_indices[t][k] = edge_id