qbhf2's picture
added NvidiaWarp and GarmentCode repos
66c9c8a
from typing import Any, Optional

import warp as wp
from warp.fem.types import ElementIndex, Coords, Sample, make_free_sample, OUTSIDE
from warp.fem.cache import cached_arg_value

from .geometry import Geometry
from .element import Square, Cube
@wp.struct
class Grid3DCellArg:
    # Number of cells along each axis
    res: wp.vec3i
    # Extent of a single cell along each axis
    cell_size: wp.vec3
    # Position of the lower corner of the grid
    origin: wp.vec3
# Float matrix type with 3 rows and 2 columns (shape=(3, 2)); returned by Grid3D.side_deformation_gradient
_mat32 = wp.mat(shape=(3, 2), dtype=float)
class Grid3D(Geometry):
    """Three-dimensional regular grid geometry"""

    dimension = 3

    # Axis permutation tables, indexed as [side axis, component].
    # Row `axis` of LOC_TO_WORLD lists the world axes in side-local order: the axis normal to
    # the side first ("altitude"), then the two in-plane axes ("longitude", "latitude").
    # Row `axis` of WORLD_TO_LOC is the inverse permutation of that row.
    Permutation = wp.types.matrix(shape=(3, 3), dtype=int)
    LOC_TO_WORLD = wp.constant(Permutation(0, 1, 2, 1, 2, 0, 2, 0, 1))
    WORLD_TO_LOC = wp.constant(Permutation(0, 1, 2, 2, 0, 1, 1, 2, 0))

    def __init__(self, res: wp.vec3i, bounds_lo: Optional[wp.vec3] = None, bounds_hi: Optional[wp.vec3] = None):
        """Constructs a dense 3D grid

        Args:
            res: Resolution of the grid along each dimension
            bounds_lo: Position of the lower bound of the axis-aligned grid, defaults to ``wp.vec3(0.0)``
            bounds_hi: Position of the upper bound of the axis-aligned grid, defaults to ``wp.vec3(1.0)``
        """

        # Build fresh vectors per instance: the bounds are stored as-is on the grid, so a vector
        # created once in the signature would be shared (and mutable) across all default-constructed grids.
        if bounds_lo is None:
            bounds_lo = wp.vec3(0.0)
        if bounds_hi is None:
            bounds_hi = wp.vec3(1.0)

        self.bounds_lo = bounds_lo
        self.bounds_hi = bounds_hi

        self._res = res

    @property
    def extents(self) -> wp.vec3:
        """Size of the grid along each axis (``bounds_hi - bounds_lo``)"""
        # Avoid using native sub due to higher overhead of calling builtins from Python
        return wp.vec3(
            self.bounds_hi[0] - self.bounds_lo[0],
            self.bounds_hi[1] - self.bounds_lo[1],
            self.bounds_hi[2] - self.bounds_lo[2],
        )

    @property
    def cell_size(self) -> wp.vec3:
        """Size of a single cell along each axis (extents divided by resolution)"""
        ex = self.extents
        return wp.vec3(
            ex[0] / self.res[0],
            ex[1] / self.res[1],
            ex[2] / self.res[2],
        )

    def cell_count(self):
        """Number of cells in the grid"""
        return self.res[0] * self.res[1] * self.res[2]

    def vertex_count(self):
        """Number of vertices in the grid (one more than the resolution along each axis)"""
        return (self.res[0] + 1) * (self.res[1] + 1) * (self.res[2] + 1)

    def side_count(self):
        """Number of cell sides (faces) in the grid, summed over the three normal axes"""
        return (
            (self.res[0] + 1) * (self.res[1]) * (self.res[2])
            + (self.res[0]) * (self.res[1] + 1) * (self.res[2])
            + (self.res[0]) * (self.res[1]) * (self.res[2] + 1)
        )

    def edge_count(self):
        """Number of cell edges in the grid, summed over the three edge directions"""
        return (
            (self.res[0] + 1) * (self.res[1] + 1) * (self.res[2])
            + (self.res[0]) * (self.res[1] + 1) * (self.res[2] + 1)
            + (self.res[0] + 1) * (self.res[1]) * (self.res[2] + 1)
        )

    def boundary_side_count(self):
        """Number of sides lying on the boundary of the grid (two opposite faces per axis)"""
        return 2 * (self.res[1]) * (self.res[2]) + (self.res[0]) * 2 * (self.res[2]) + (self.res[0]) * (self.res[1]) * 2

    def reference_cell(self) -> Cube:
        """Reference element for the grid cells"""
        return Cube()

    def reference_side(self) -> Square:
        """Reference element for the grid sides"""
        return Square()

    @property
    def res(self):
        """Resolution of the grid along each dimension, as passed to the constructor"""
        return self._res

    @property
    def origin(self):
        """Position of the lower corner of the grid (same as ``bounds_lo``)"""
        return self.bounds_lo

    @property
    def strides(self):
        """Strides of the (x, y, z) cell coordinates in the linear cell index; z varies fastest"""
        return wp.vec3i(self.res[1] * self.res[2], self.res[2], 1)

    # Utility device functions
    CellArg = Grid3DCellArg
    Cell = wp.vec3i

    @wp.func
    def _to_3d_index(strides: wp.vec2i, index: int):
        """Splits a linear index into (x, y, z) given the x and y strides; the z stride is 1"""
        x = index // strides[0]
        y = (index - strides[0] * x) // strides[1]
        z = index - strides[0] * x - strides[1] * y
        return wp.vec3i(x, y, z)

    @wp.func
    def _from_3d_index(strides: wp.vec2i, index: wp.vec3i):
        """Flattens an (x, y, z) index into a linear index; inverse of ``_to_3d_index``"""
        return strides[0] * index[0] + strides[1] * index[1] + index[2]

    @wp.func
    def cell_index(res: wp.vec3i, cell: Cell):
        """Linear index of the cell with 3D coordinates `cell` in a grid of resolution `res`"""
        strides = wp.vec2i(res[1] * res[2], res[2])
        return Grid3D._from_3d_index(strides, cell)

    @wp.func
    def get_cell(res: wp.vec3i, cell_index: ElementIndex):
        """3D coordinates of the cell with linear index `cell_index` in a grid of resolution `res`"""
        strides = wp.vec2i(res[1] * res[2], res[2])
        return Grid3D._to_3d_index(strides, cell_index)

    @wp.struct
    class Side:
        axis: int  # normal
        origin: wp.vec3i  # index of vertex at corner (0,0,0)

    @wp.struct
    class SideArg:
        # Total number of cells in the grid
        cell_count: int
        # Start offset, per normal axis, of the upper-boundary sides within the
        # range of side indices located after the first 3 * cell_count ones
        axis_offsets: wp.vec3i
        # Argument struct of the underlying cells
        cell_arg: Grid3DCellArg

    SideIndexArg = SideArg

    @wp.func
    def _world_to_local(axis: int, vec: Any):
        """Reorders the components of the world-space `vec` into the local frame of sides normal to `axis`"""
        return type(vec)(
            vec[Grid3D.LOC_TO_WORLD[axis, 0]],
            vec[Grid3D.LOC_TO_WORLD[axis, 1]],
            vec[Grid3D.LOC_TO_WORLD[axis, 2]],
        )

    @wp.func
    def _local_to_world(axis: int, vec: Any):
        """Reorders the components of `vec`, given in the local frame of sides normal to `axis`, into world order"""
        return type(vec)(
            vec[Grid3D.WORLD_TO_LOC[axis, 0]],
            vec[Grid3D.WORLD_TO_LOC[axis, 1]],
            vec[Grid3D.WORLD_TO_LOC[axis, 2]],
        )

    @wp.func
    def side_index(arg: SideArg, side: Side):
        """Linear index of `side`.

        Sides whose altitude is below the grid resolution are numbered ``axis * cell_count + cell_index``,
        using the cell sharing the side's origin. Sides on the upper boundary of an axis have no such cell
        and are numbered after those, starting at ``3 * cell_count``.
        """
        alt_axis = Grid3D.LOC_TO_WORLD[side.axis, 0]
        if side.origin[0] == arg.cell_arg.res[alt_axis]:
            # Upper-boundary side
            longitude = side.origin[1]
            latitude = side.origin[2]

            latitude_res = arg.cell_arg.res[Grid3D.LOC_TO_WORLD[side.axis, 2]]
            lat_long = latitude_res * longitude + latitude

            return 3 * arg.cell_count + arg.axis_offsets[side.axis] + lat_long

        cell_index = Grid3D.cell_index(arg.cell_arg.res, Grid3D._local_to_world(side.axis, side.origin))
        return side.axis * arg.cell_count + cell_index

    @wp.func
    def get_side(arg: SideArg, side_index: ElementIndex):
        """Side (normal axis and local-frame origin) with linear index `side_index`; inverse of ``side_index``"""
        if side_index < 3 * arg.cell_count:
            # Side attached to a cell: decode the axis, then the cell coordinates
            axis = side_index // arg.cell_count
            cell_index = side_index - axis * arg.cell_count
            origin = Grid3D._world_to_local(axis, Grid3D.get_cell(arg.cell_arg.res, cell_index))
            return Grid3D.Side(axis, origin)

        # Upper-boundary side: find which axis range the remaining index falls into
        axis_side_index = side_index - 3 * arg.cell_count
        if axis_side_index < arg.axis_offsets[1]:
            axis = 0
        elif axis_side_index < arg.axis_offsets[2]:
            axis = 1
        else:
            axis = 2

        altitude = arg.cell_arg.res[Grid3D.LOC_TO_WORLD[axis, 0]]

        lat_long = axis_side_index - arg.axis_offsets[axis]
        latitude_res = arg.cell_arg.res[Grid3D.LOC_TO_WORLD[axis, 2]]

        longitude = lat_long // latitude_res
        latitude = lat_long - longitude * latitude_res

        origin_loc = wp.vec3i(altitude, longitude, latitude)

        return Grid3D.Side(axis, origin_loc)

    # Geometry device interface

    @cached_arg_value
    def cell_arg_value(self, device) -> CellArg:
        """Builds the `CellArg` struct (resolution, origin, cell size) passed to the cell device functions"""
        args = self.CellArg()
        args.res = self.res
        args.origin = self.bounds_lo
        args.cell_size = self.cell_size
        return args

    @wp.func
    def cell_position(args: CellArg, s: Sample):
        """World-space position of sample `s`, given by its cell index and coordinates within the cell"""
        cell = Grid3D.get_cell(args.res, s.element_index)
        return (
            wp.vec3(
                (float(cell[0]) + s.element_coords[0]) * args.cell_size[0],
                (float(cell[1]) + s.element_coords[1]) * args.cell_size[1],
                (float(cell[2]) + s.element_coords[2]) * args.cell_size[2],
            )
            + args.origin
        )

    @wp.func
    def cell_deformation_gradient(args: CellArg, s: Sample):
        """Diagonal matrix of the cell sizes; does not depend on the sample"""
        return wp.diag(args.cell_size)

    @wp.func
    def cell_inverse_deformation_gradient(args: CellArg, s: Sample):
        """Diagonal matrix of the inverse cell sizes; does not depend on the sample"""
        return wp.diag(wp.cw_div(wp.vec3(1.0), args.cell_size))

    @wp.func
    def cell_lookup(args: CellArg, pos: wp.vec3):
        """Sample (cell index and coordinates) for the position `pos`.

        Grid-space coordinates are clamped to ``[0, res]`` on each axis, so positions outside
        of the grid yield a sample on the grid boundary.
        """
        loc_pos = wp.cw_div(pos - args.origin, args.cell_size)
        x = wp.clamp(loc_pos[0], 0.0, float(args.res[0]))
        y = wp.clamp(loc_pos[1], 0.0, float(args.res[1]))
        z = wp.clamp(loc_pos[2], 0.0, float(args.res[2]))

        # Cap at res - 1 so that positions on the upper boundary land in the last cell with coordinate 1
        x_cell = wp.min(wp.floor(x), float(args.res[0]) - 1.0)
        y_cell = wp.min(wp.floor(y), float(args.res[1]) - 1.0)
        z_cell = wp.min(wp.floor(z), float(args.res[2]) - 1.0)

        coords = Coords(x - x_cell, y - y_cell, z - z_cell)
        cell_index = Grid3D.cell_index(args.res, Grid3D.Cell(int(x_cell), int(y_cell), int(z_cell)))

        return make_free_sample(cell_index, coords)

    @wp.func
    def cell_lookup(args: CellArg, pos: wp.vec3, guess: Sample):
        """Overload of ``cell_lookup`` accepting an initial guess; `guess` is ignored"""
        return Grid3D.cell_lookup(args, pos)

    @wp.func
    def cell_measure(args: CellArg, s: Sample):
        """Volume of a cell (product of the cell sizes); does not depend on the sample"""
        return args.cell_size[0] * args.cell_size[1] * args.cell_size[2]

    @wp.func
    def cell_normal(args: CellArg, s: Sample):
        """Always returns the zero vector"""
        return wp.vec3(0.0)

    @wp.func
    def cell_transform_reference_gradient(args: CellArg, cell_index: ElementIndex, coords: Coords, ref_grad: wp.vec3):
        """Divides the reference-space gradient `ref_grad` component-wise by the cell size"""
        return wp.cw_div(ref_grad, args.cell_size)

    @cached_arg_value
    def side_arg_value(self, device) -> SideArg:
        """Builds the `SideArg` struct passed to the side device functions"""
        args = self.SideArg()

        # Number of upper-boundary sides for each normal axis
        axis_dims = wp.vec3i(
            self.res[1] * self.res[2],
            self.res[2] * self.res[0],
            self.res[0] * self.res[1],
        )
        # Exclusive prefix sum of the above
        args.axis_offsets = wp.vec3i(
            0,
            axis_dims[0],
            axis_dims[0] + axis_dims[1],
        )
        args.cell_count = self.cell_count()
        args.cell_arg = self.cell_arg_value(device)
        return args

    def side_index_arg_value(self, device) -> SideIndexArg:
        """Same as ``side_arg_value``; `SideIndexArg` is an alias of `SideArg`"""
        return self.side_arg_value(device)

    @wp.func
    def boundary_side_index(args: SideArg, boundary_side_index: int):
        """Boundary side to side index"""

        # Boundary sides come in pairs: even index for the lower border, odd for the upper one
        axis_side_index = boundary_side_index // 2
        border = boundary_side_index - 2 * axis_side_index

        if axis_side_index < args.axis_offsets[1]:
            axis = 0
        elif axis_side_index < args.axis_offsets[2]:
            axis = 1
        else:
            axis = 2

        lat_long = axis_side_index - args.axis_offsets[axis]
        latitude_res = args.cell_arg.res[Grid3D.LOC_TO_WORLD[axis, 2]]

        longitude = lat_long // latitude_res
        latitude = lat_long - longitude * latitude_res

        # Altitude is 0 on the lower border, res[axis] on the upper one
        altitude = border * args.cell_arg.res[axis]

        side = Grid3D.Side(axis, wp.vec3i(altitude, longitude, latitude))
        return Grid3D.side_index(args, side)

    @wp.func
    def side_position(args: SideArg, s: Sample):
        """World-space position of sample `s`, given by its side index and coordinates within the side"""
        side = Grid3D.get_side(args, s.element_index)

        # wp.select returns its last argument when the condition holds:
        # the first side coordinate is mirrored for sides at altitude 0
        coord0 = wp.select(side.origin[0] == 0, s.element_coords[0], 1.0 - s.element_coords[0])

        local_pos = wp.vec3(
            float(side.origin[0]),
            float(side.origin[1]) + coord0,
            float(side.origin[2]) + s.element_coords[1],
        )

        pos = args.cell_arg.origin + wp.cw_mul(Grid3D._local_to_world(side.axis, local_pos), args.cell_arg.cell_size)

        return pos

    @wp.func
    def side_deformation_gradient(args: SideArg, s: Sample):
        """3x2 matrix built from the two in-plane side directions scaled by the cell size.

        The first direction is negated for sides at altitude 0, matching the mirrored
        first coordinate used in ``side_position``.
        """
        side = Grid3D.get_side(args, s.element_index)

        sign = wp.select(side.origin[0] == 0, 1.0, -1.0)

        return _mat32(
            wp.cw_mul(Grid3D._local_to_world(side.axis, wp.vec3(0.0, sign, 0.0)), args.cell_arg.cell_size),
            wp.cw_mul(Grid3D._local_to_world(side.axis, wp.vec3(0.0, 0.0, 1.0)), args.cell_arg.cell_size),
        )

    @wp.func
    def side_inner_inverse_deformation_gradient(args: SideArg, s: Sample):
        """Inverse deformation gradient of the inner cell; identical for all cells of the grid"""
        return Grid3D.cell_inverse_deformation_gradient(args.cell_arg, s)

    @wp.func
    def side_outer_inverse_deformation_gradient(args: SideArg, s: Sample):
        """Inverse deformation gradient of the outer cell; identical for all cells of the grid"""
        return Grid3D.cell_inverse_deformation_gradient(args.cell_arg, s)

    @wp.func
    def side_measure(args: SideArg, s: Sample):
        """Area of the side: product of the cell sizes along its two in-plane axes"""
        side = Grid3D.get_side(args, s.element_index)
        long_axis = Grid3D.LOC_TO_WORLD[side.axis, 1]
        lat_axis = Grid3D.LOC_TO_WORLD[side.axis, 2]
        return args.cell_arg.cell_size[long_axis] * args.cell_arg.cell_size[lat_axis]

    @wp.func
    def side_measure_ratio(args: SideArg, s: Sample):
        """Ratio of side area to cell volume, i.e. the inverse of the cell size along the side normal"""
        side = Grid3D.get_side(args, s.element_index)
        alt_axis = Grid3D.LOC_TO_WORLD[side.axis, 0]
        return 1.0 / args.cell_arg.cell_size[alt_axis]

    @wp.func
    def side_normal(args: SideArg, s: Sample):
        """Unit normal of the side: along the negative side axis at altitude 0, the positive one otherwise"""
        side = Grid3D.get_side(args, s.element_index)

        sign = wp.select(side.origin[0] == 0, 1.0, -1.0)

        local_n = wp.vec3(sign, 0.0, 0.0)
        return Grid3D._local_to_world(side.axis, local_n)

    @wp.func
    def side_inner_cell_index(arg: SideArg, side_index: ElementIndex):
        """Index of the cell on the inner side: the one just below the side, or the first cell at altitude 0"""
        side = Grid3D.get_side(arg, side_index)

        inner_alt = wp.select(side.origin[0] == 0, side.origin[0] - 1, 0)

        inner_origin = wp.vec3i(inner_alt, side.origin[1], side.origin[2])

        cell = Grid3D._local_to_world(side.axis, inner_origin)
        return Grid3D.cell_index(arg.cell_arg.res, cell)

    @wp.func
    def side_outer_cell_index(arg: SideArg, side_index: ElementIndex):
        """Index of the cell on the outer side: the one starting at the side, or the last cell on the upper boundary"""
        side = Grid3D.get_side(arg, side_index)

        alt_axis = Grid3D.LOC_TO_WORLD[side.axis, 0]
        outer_alt = wp.select(
            side.origin[0] == arg.cell_arg.res[alt_axis], side.origin[0], arg.cell_arg.res[alt_axis] - 1
        )

        outer_origin = wp.vec3i(outer_alt, side.origin[1], side.origin[2])

        cell = Grid3D._local_to_world(side.axis, outer_origin)
        return Grid3D.cell_index(arg.cell_arg.res, cell)

    @wp.func
    def side_inner_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
        """Coordinates within the inner cell of the point with coordinates `side_coords` on the side"""
        side = Grid3D.get_side(args, side_index)

        # The side is the upper face of its inner cell, except at altitude 0 where it is the lower face
        inner_alt = wp.select(side.origin[0] == 0, 1.0, 0.0)

        side_coord0 = wp.select(side.origin[0] == 0, side_coords[0], 1.0 - side_coords[0])

        return Grid3D._local_to_world(side.axis, wp.vec3(inner_alt, side_coord0, side_coords[1]))

    @wp.func
    def side_outer_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
        """Coordinates within the outer cell of the point with coordinates `side_coords` on the side"""
        side = Grid3D.get_side(args, side_index)

        # The side is the lower face of its outer cell, except on the upper boundary where it is the upper face
        alt_axis = Grid3D.LOC_TO_WORLD[side.axis, 0]
        outer_alt = wp.select(side.origin[0] == args.cell_arg.res[alt_axis], 0.0, 1.0)

        side_coord0 = wp.select(side.origin[0] == 0, side_coords[0], 1.0 - side_coords[0])

        return Grid3D._local_to_world(side.axis, wp.vec3(outer_alt, side_coord0, side_coords[1]))

    @wp.func
    def side_from_cell_coords(
        args: SideArg,
        side_index: ElementIndex,
        element_index: ElementIndex,
        element_coords: Coords,
    ):
        """Side coordinates of the point with coordinates `element_coords` in cell `element_index`.

        Returns ``Coords(OUTSIDE)`` when the cell coordinate along the side axis does not match
        the offset between the side and the cell along that axis.
        """
        side = Grid3D.get_side(args, side_index)
        cell = Grid3D.get_cell(args.cell_arg.res, element_index)

        if float(side.origin[0] - cell[side.axis]) == element_coords[side.axis]:
            long_axis = Grid3D.LOC_TO_WORLD[side.axis, 1]
            lat_axis = Grid3D.LOC_TO_WORLD[side.axis, 2]
            long_coord = element_coords[long_axis]
            # Same mirroring of the first side coordinate as in side_position
            long_coord = wp.select(side.origin[0] == 0, long_coord, 1.0 - long_coord)
            return Coords(long_coord, element_coords[lat_axis], 0.0)

        return Coords(OUTSIDE)

    @wp.func
    def side_to_cell_arg(side_arg: SideArg):
        """Extracts the cell argument struct embedded in `side_arg`"""
        return side_arg.cell_arg