from typing import Optional import warp as wp from warp.fem.cache import ( TemporaryStore, borrow_temporary, borrow_temporary_like, cached_arg_value, ) from warp.fem.types import OUTSIDE, Coords, ElementIndex, Sample, make_free_sample from .element import LinearEdge, Square from .geometry import Geometry # from .closest_point import project_on_tet_at_origin @wp.struct class Quadmesh2DCellArg: quad_vertex_indices: wp.array2d(dtype=int) positions: wp.array(dtype=wp.vec2) # for neighbor cell lookup vertex_quad_offsets: wp.array(dtype=int) vertex_quad_indices: wp.array(dtype=int) @wp.struct class Quadmesh2DSideArg: cell_arg: Quadmesh2DCellArg edge_vertex_indices: wp.array(dtype=wp.vec2i) edge_quad_indices: wp.array(dtype=wp.vec2i) class Quadmesh2D(Geometry): """Two-dimensional quadrilateral mesh geometry""" dimension = 2 def __init__( self, quad_vertex_indices: wp.array, positions: wp.array, temporary_store: Optional[TemporaryStore] = None ): """ Constructs a two-dimensional quadrilateral mesh. Args: quad_vertex_indices: warp array of shape (num_tris, 4) containing vertex indices for each quad, in counter-clockwise order positions: warp array of shape (num_vertices, 2) containing 2d position for each vertex temporary_store: shared pool from which to allocate temporary arrays """ self.quad_vertex_indices = quad_vertex_indices self.positions = positions self._edge_vertex_indices: wp.array = None self._edge_quad_indices: wp.array = None self._vertex_quad_offsets: wp.array = None self._vertex_quad_indices: wp.array = None self._build_topology(temporary_store) def cell_count(self): return self.quad_vertex_indices.shape[0] def vertex_count(self): return self.positions.shape[0] def side_count(self): return self._edge_vertex_indices.shape[0] def boundary_side_count(self): return self._boundary_edge_indices.shape[0] def reference_cell(self) -> Square: return Square() def reference_side(self) -> LinearEdge: return LinearEdge() @property def edge_quad_indices(self) -> wp.array: return self._edge_quad_indices @property def edge_vertex_indices(self) -> wp.array: return self._edge_vertex_indices CellArg = Quadmesh2DCellArg SideArg = Quadmesh2DSideArg @wp.struct class SideIndexArg: boundary_edge_indices: wp.array(dtype=int) # Geometry device interface @cached_arg_value def cell_arg_value(self, device) -> CellArg: args = self.CellArg() args.quad_vertex_indices = self.quad_vertex_indices.to(device) args.positions = self.positions.to(device) args.vertex_quad_offsets = self._vertex_quad_offsets.to(device) args.vertex_quad_indices = self._vertex_quad_indices.to(device) return args @wp.func def cell_position(args: CellArg, s: Sample): quad_idx = args.quad_vertex_indices[s.element_index] w_p = s.element_coords w_m = Coords(1.0) - s.element_coords # 0 : m m # 1 : p m # 2 : p p # 3 : m p return ( w_m[0] * w_m[1] * args.positions[quad_idx[0]] + w_p[0] * w_m[1] * args.positions[quad_idx[1]] + w_p[0] * w_p[1] * args.positions[quad_idx[2]] + w_m[0] * w_p[1] * args.positions[quad_idx[3]] ) @wp.func def cell_deformation_gradient(cell_arg: CellArg, s: Sample): """Deformation gradient at `coords`""" quad_idx = cell_arg.quad_vertex_indices[s.element_index] w_p = s.element_coords w_m = Coords(1.0) - s.element_coords return ( wp.outer(cell_arg.positions[quad_idx[0]], wp.vec2(-w_m[1], -w_m[0])) + wp.outer(cell_arg.positions[quad_idx[1]], wp.vec2(w_m[1], -w_p[0])) + wp.outer(cell_arg.positions[quad_idx[2]], wp.vec2(w_p[1], w_p[0])) + wp.outer(cell_arg.positions[quad_idx[3]], wp.vec2(-w_p[1], w_m[0])) ) @wp.func def cell_inverse_deformation_gradient(cell_arg: CellArg, s: Sample): return wp.inverse(Quadmesh2D.cell_deformation_gradient(cell_arg, s)) @wp.func def cell_measure(args: CellArg, s: Sample): return wp.abs(wp.determinant(Quadmesh2D.cell_deformation_gradient(args, s))) @wp.func def cell_normal(args: CellArg, s: Sample): return wp.vec2(0.0) @cached_arg_value def side_index_arg_value(self, device) -> SideIndexArg: args = self.SideIndexArg() args.boundary_edge_indices = self._boundary_edge_indices.to(device) return args @wp.func def boundary_side_index(args: SideIndexArg, boundary_side_index: int): """Boundary side to side index""" return args.boundary_edge_indices[boundary_side_index] @cached_arg_value def side_arg_value(self, device) -> CellArg: args = self.SideArg() args.cell_arg = self.cell_arg_value(device) args.edge_vertex_indices = self._edge_vertex_indices.to(device) args.edge_quad_indices = self._edge_quad_indices.to(device) return args @wp.func def side_position(args: SideArg, s: Sample): edge_idx = args.edge_vertex_indices[s.element_index] return (1.0 - s.element_coords[0]) * args.cell_arg.positions[edge_idx[0]] + s.element_coords[ 0 ] * args.cell_arg.positions[edge_idx[1]] @wp.func def side_deformation_gradient(args: SideArg, s: Sample): edge_idx = args.edge_vertex_indices[s.element_index] v0 = args.cell_arg.positions[edge_idx[0]] v1 = args.cell_arg.positions[edge_idx[1]] return v1 - v0 @wp.func def side_inner_inverse_deformation_gradient(args: SideArg, s: Sample): cell_index = Quadmesh2D.side_inner_cell_index(args, s.element_index) cell_coords = Quadmesh2D.side_inner_cell_coords(args, s.element_index, s.element_coords) return Quadmesh2D.cell_inverse_deformation_gradient(args.cell_arg, make_free_sample(cell_index, cell_coords)) @wp.func def side_outer_inverse_deformation_gradient(args: SideArg, s: Sample): cell_index = Quadmesh2D.side_outer_cell_index(args, s.element_index) cell_coords = Quadmesh2D.side_outer_cell_coords(args, s.element_index, s.element_coords) return Quadmesh2D.cell_inverse_deformation_gradient(args.cell_arg, make_free_sample(cell_index, cell_coords)) @wp.func def side_measure(args: SideArg, s: Sample): edge_idx = args.edge_vertex_indices[s.element_index] v0 = args.cell_arg.positions[edge_idx[0]] v1 = args.cell_arg.positions[edge_idx[1]] return wp.length(v1 - v0) @wp.func def side_measure_ratio(args: SideArg, s: Sample): inner = Quadmesh2D.side_inner_cell_index(args, s.element_index) outer = Quadmesh2D.side_outer_cell_index(args, s.element_index) inner_coords = Quadmesh2D.side_inner_cell_coords(args, s.element_index, s.element_coords) outer_coords = Quadmesh2D.side_outer_cell_coords(args, s.element_index, s.element_coords) return Quadmesh2D.side_measure(args, s) / wp.min( Quadmesh2D.cell_measure(args.cell_arg, make_free_sample(inner, inner_coords)), Quadmesh2D.cell_measure(args.cell_arg, make_free_sample(outer, outer_coords)), ) @wp.func def side_normal(args: SideArg, s: Sample): edge_idx = args.edge_vertex_indices[s.element_index] v0 = args.cell_arg.positions[edge_idx[0]] v1 = args.cell_arg.positions[edge_idx[1]] e = v1 - v0 return wp.normalize(wp.vec2(-e[1], e[0])) @wp.func def side_inner_cell_index(arg: SideArg, side_index: ElementIndex): return arg.edge_quad_indices[side_index][0] @wp.func def side_outer_cell_index(arg: SideArg, side_index: ElementIndex): return arg.edge_quad_indices[side_index][1] @wp.func def edge_to_quad_coords(args: SideArg, side_index: ElementIndex, quad_index: ElementIndex, side_coords: Coords): edge_vidx = args.edge_vertex_indices[side_index] quad_vidx = args.cell_arg.quad_vertex_indices[quad_index] vs = edge_vidx[0] ve = edge_vidx[1] s = side_coords[0] if vs == quad_vidx[0]: return wp.select(ve == quad_vidx[1], Coords(0.0, s, 0.0), Coords(s, 0.0, 0.0)) elif vs == quad_vidx[1]: return wp.select(ve == quad_vidx[2], Coords(1.0 - s, 0.0, 0.0), Coords(1.0, s, 0.0)) elif vs == quad_vidx[2]: return wp.select(ve == quad_vidx[3], Coords(1.0, 1.0 - s, 0.0), Coords(1.0 - s, 1.0, 0.0)) return wp.select(ve == quad_vidx[0], Coords(s, 1.0, 0.0), Coords(0.0, 1.0 - s, 0.0)) @wp.func def side_inner_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords): inner_cell_index = Quadmesh2D.side_inner_cell_index(args, side_index) return Quadmesh2D.edge_to_quad_coords(args, side_index, inner_cell_index, side_coords) @wp.func def side_outer_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords): outer_cell_index = Quadmesh2D.side_outer_cell_index(args, side_index) return Quadmesh2D.edge_to_quad_coords(args, side_index, outer_cell_index, side_coords) @wp.func def side_from_cell_coords( args: SideArg, side_index: ElementIndex, quad_index: ElementIndex, quad_coords: Coords, ): edge_vidx = args.edge_vertex_indices[side_index] quad_vidx = args.cell_arg.quad_vertex_indices[quad_index] vs = edge_vidx[0] ve = edge_vidx[1] cx = quad_coords[0] cy = quad_coords[1] if vs == quad_vidx[0]: oc = wp.select(ve == quad_vidx[1], cx, cy) ec = wp.select(ve == quad_vidx[1], cy, cx) elif vs == quad_vidx[1]: oc = wp.select(ve == quad_vidx[2], cy, 1.0 - cx) ec = wp.select(ve == quad_vidx[2], 1.0 - cx, cy) elif vs == quad_vidx[2]: oc = wp.select(ve == quad_vidx[3], 1.0 - cx, 1.0 - cy) ec = wp.select(ve == quad_vidx[3], 1.0 - cy, 1.0 - cx) else: oc = wp.select(ve == quad_vidx[0], 1.0 - cy, cx) ec = wp.select(ve == quad_vidx[0], cx, 1.0 - cy) return wp.select(oc == 0.0, Coords(OUTSIDE), Coords(ec, 0.0, 0.0)) @wp.func def side_to_cell_arg(side_arg: SideArg): return side_arg.cell_arg def _build_topology(self, temporary_store: TemporaryStore): from warp.fem.utils import compress_node_indices, masked_indices from warp.utils import array_scan device = self.quad_vertex_indices.device vertex_quad_offsets, vertex_quad_indices, _, __ = compress_node_indices( self.vertex_count(), self.quad_vertex_indices, temporary_store=temporary_store ) self._vertex_quad_offsets = vertex_quad_offsets.detach() self._vertex_quad_indices = vertex_quad_indices.detach() vertex_start_edge_count = borrow_temporary(temporary_store, dtype=int, device=device, shape=self.vertex_count()) vertex_start_edge_count.array.zero_() vertex_start_edge_offsets = borrow_temporary_like(vertex_start_edge_count, temporary_store=temporary_store) vertex_edge_ends = borrow_temporary(temporary_store, dtype=int, device=device, shape=(4 * self.cell_count())) vertex_edge_quads = borrow_temporary( temporary_store, dtype=int, device=device, shape=(4 * self.cell_count(), 2) ) # Count face edges starting at each vertex wp.launch( kernel=Quadmesh2D._count_starting_edges_kernel, device=device, dim=self.cell_count(), inputs=[self.quad_vertex_indices, vertex_start_edge_count.array], ) array_scan(in_array=vertex_start_edge_count.array, out_array=vertex_start_edge_offsets.array, inclusive=False) # Count number of unique edges (deduplicate across faces) vertex_unique_edge_count = vertex_start_edge_count wp.launch( kernel=Quadmesh2D._count_unique_starting_edges_kernel, device=device, dim=self.vertex_count(), inputs=[ self._vertex_quad_offsets, self._vertex_quad_indices, self.quad_vertex_indices, vertex_start_edge_offsets.array, vertex_unique_edge_count.array, vertex_edge_ends.array, vertex_edge_quads.array, ], ) vertex_unique_edge_offsets = borrow_temporary_like(vertex_start_edge_offsets, temporary_store=temporary_store) array_scan(in_array=vertex_start_edge_count.array, out_array=vertex_unique_edge_offsets.array, inclusive=False) # Get back edge count to host if device.is_cuda: edge_count = borrow_temporary(temporary_store, shape=(1,), dtype=int, device="cpu", pinned=True) # Last vertex will not own any edge, so its count will be zero; just fetching last prefix count is ok wp.copy( dest=edge_count.array, src=vertex_unique_edge_offsets.array, src_offset=self.vertex_count() - 1, count=1 ) wp.synchronize_stream(wp.get_stream(device)) edge_count = int(edge_count.array.numpy()[0]) else: edge_count = int(vertex_unique_edge_offsets.array.numpy()[self.vertex_count() - 1]) self._edge_vertex_indices = wp.empty(shape=(edge_count,), dtype=wp.vec2i, device=device) self._edge_quad_indices = wp.empty(shape=(edge_count,), dtype=wp.vec2i, device=device) boundary_mask = borrow_temporary(temporary_store=temporary_store, shape=(edge_count,), dtype=int, device=device) # Compress edge data wp.launch( kernel=Quadmesh2D._compress_edges_kernel, device=device, dim=self.vertex_count(), inputs=[ vertex_start_edge_offsets.array, vertex_unique_edge_offsets.array, vertex_unique_edge_count.array, vertex_edge_ends.array, vertex_edge_quads.array, self._edge_vertex_indices, self._edge_quad_indices, boundary_mask.array, ], ) vertex_start_edge_offsets.release() vertex_unique_edge_offsets.release() vertex_unique_edge_count.release() vertex_edge_ends.release() vertex_edge_quads.release() # Flip normals if necessary wp.launch( kernel=Quadmesh2D._flip_edge_normals, device=device, dim=self.side_count(), inputs=[self._edge_vertex_indices, self._edge_quad_indices, self.quad_vertex_indices, self.positions], ) boundary_edge_indices, _ = masked_indices(boundary_mask.array, temporary_store=temporary_store) self._boundary_edge_indices = boundary_edge_indices.detach() boundary_mask.release() @wp.kernel def _count_starting_edges_kernel( quad_vertex_indices: wp.array2d(dtype=int), vertex_start_edge_count: wp.array(dtype=int) ): t = wp.tid() for k in range(4): v0 = quad_vertex_indices[t, k] v1 = quad_vertex_indices[t, (k + 1) % 4] if v0 < v1: wp.atomic_add(vertex_start_edge_count, v0, 1) else: wp.atomic_add(vertex_start_edge_count, v1, 1) @wp.func def _find( needle: int, values: wp.array(dtype=int), beg: int, end: int, ): for i in range(beg, end): if values[i] == needle: return i return -1 @wp.kernel def _count_unique_starting_edges_kernel( vertex_quad_offsets: wp.array(dtype=int), vertex_quad_indices: wp.array(dtype=int), quad_vertex_indices: wp.array2d(dtype=int), vertex_start_edge_offsets: wp.array(dtype=int), vertex_start_edge_count: wp.array(dtype=int), edge_ends: wp.array(dtype=int), edge_quads: wp.array2d(dtype=int), ): v = wp.tid() edge_beg = vertex_start_edge_offsets[v] quad_beg = vertex_quad_offsets[v] quad_end = vertex_quad_offsets[v + 1] edge_cur = edge_beg for quad in range(quad_beg, quad_end): q = vertex_quad_indices[quad] for k in range(4): v0 = quad_vertex_indices[q, k] v1 = quad_vertex_indices[q, (k + 1) % 4] if v == wp.min(v0, v1): other_v = wp.max(v0, v1) # Check if other_v has been seen seen_idx = Quadmesh2D._find(other_v, edge_ends, edge_beg, edge_cur) if seen_idx == -1: edge_ends[edge_cur] = other_v edge_quads[edge_cur, 0] = q edge_quads[edge_cur, 1] = q edge_cur += 1 else: edge_quads[seen_idx, 1] = q vertex_start_edge_count[v] = edge_cur - edge_beg @wp.kernel def _compress_edges_kernel( vertex_start_edge_offsets: wp.array(dtype=int), vertex_unique_edge_offsets: wp.array(dtype=int), vertex_unique_edge_count: wp.array(dtype=int), uncompressed_edge_ends: wp.array(dtype=int), uncompressed_edge_quads: wp.array2d(dtype=int), edge_vertex_indices: wp.array(dtype=wp.vec2i), edge_quad_indices: wp.array(dtype=wp.vec2i), boundary_mask: wp.array(dtype=int), ): v = wp.tid() start_beg = vertex_start_edge_offsets[v] unique_beg = vertex_unique_edge_offsets[v] unique_count = vertex_unique_edge_count[v] for e in range(unique_count): src_index = start_beg + e edge_index = unique_beg + e edge_vertex_indices[edge_index] = wp.vec2i(v, uncompressed_edge_ends[src_index]) q0 = uncompressed_edge_quads[src_index, 0] q1 = uncompressed_edge_quads[src_index, 1] edge_quad_indices[edge_index] = wp.vec2i(q0, q1) if q0 == q1: boundary_mask[edge_index] = 1 else: boundary_mask[edge_index] = 0 @wp.kernel def _flip_edge_normals( edge_vertex_indices: wp.array(dtype=wp.vec2i), edge_quad_indices: wp.array(dtype=wp.vec2i), quad_vertex_indices: wp.array2d(dtype=int), positions: wp.array(dtype=wp.vec2), ): e = wp.tid() tri = edge_quad_indices[e][0] quad_vidx = quad_vertex_indices[tri] edge_vidx = edge_vertex_indices[e] quad_centroid = ( positions[quad_vidx[0]] + positions[quad_vidx[1]] + positions[quad_vidx[2]] + positions[quad_vidx[3]] ) / 4.0 v0 = positions[edge_vidx[0]] v1 = positions[edge_vidx[1]] edge_center = 0.5 * (v1 + v0) edge_vec = v1 - v0 edge_normal = wp.vec2(-edge_vec[1], edge_vec[0]) # if edge normal points toward first triangle centroid, flip indices if wp.dot(quad_centroid - edge_center, edge_normal) > 0.0: edge_vertex_indices[e] = wp.vec2i(edge_vidx[1], edge_vidx[0])