Spaces:
Sleeping
Sleeping
| from typing import Any | |
| import warp as wp | |
| from warp.fem.types import ElementIndex, Coords, Sample, make_free_sample, OUTSIDE | |
| from warp.fem.cache import cached_arg_value | |
| from .geometry import Geometry | |
| from .element import Square, Cube | |
| class Grid3DCellArg: | |
| res: wp.vec3i | |
| cell_size: wp.vec3 | |
| origin: wp.vec3 | |
| _mat32 = wp.mat(shape=(3, 2), dtype=float) | |
| class Grid3D(Geometry): | |
| """Three-dimensional regular grid geometry""" | |
| dimension = 3 | |
| Permutation = wp.types.matrix(shape=(3, 3), dtype=int) | |
| LOC_TO_WORLD = wp.constant(Permutation(0, 1, 2, 1, 2, 0, 2, 0, 1)) | |
| WORLD_TO_LOC = wp.constant(Permutation(0, 1, 2, 2, 0, 1, 1, 2, 0)) | |
| def __init__(self, res: wp.vec3i, bounds_lo: wp.vec3 = wp.vec3(0.0), bounds_hi: wp.vec3 = wp.vec3(1.0)): | |
| """Constructs a dense 3D grid | |
| Args: | |
| res: Resolution of the grid along each dimension | |
| bounds_lo: Position of the lower bound of the axis-aligned grid | |
| bounds_up: Position of the upper bound of the axis-aligned grid | |
| """ | |
| self.bounds_lo = bounds_lo | |
| self.bounds_hi = bounds_hi | |
| self._res = res | |
| def extents(self) -> wp.vec3: | |
| # Avoid using native sub due to higher over of calling builtins from Python | |
| return wp.vec3( | |
| self.bounds_hi[0] - self.bounds_lo[0], | |
| self.bounds_hi[1] - self.bounds_lo[1], | |
| self.bounds_hi[2] - self.bounds_lo[2], | |
| ) | |
| def cell_size(self) -> wp.vec3: | |
| ex = self.extents | |
| return wp.vec3( | |
| ex[0] / self.res[0], | |
| ex[1] / self.res[1], | |
| ex[2] / self.res[2], | |
| ) | |
| def cell_count(self): | |
| return self.res[0] * self.res[1] * self.res[2] | |
| def vertex_count(self): | |
| return (self.res[0] + 1) * (self.res[1] + 1) * (self.res[2] + 1) | |
| def side_count(self): | |
| return ( | |
| (self.res[0] + 1) * (self.res[1]) * (self.res[2]) | |
| + (self.res[0]) * (self.res[1] + 1) * (self.res[2]) | |
| + (self.res[0]) * (self.res[1]) * (self.res[2] + 1) | |
| ) | |
| def edge_count(self): | |
| return ( | |
| (self.res[0] + 1) * (self.res[1] + 1) * (self.res[2]) | |
| + (self.res[0]) * (self.res[1] + 1) * (self.res[2] + 1) | |
| + (self.res[0] + 1) * (self.res[1]) * (self.res[2] + 1) | |
| ) | |
| def boundary_side_count(self): | |
| return 2 * (self.res[1]) * (self.res[2]) + (self.res[0]) * 2 * (self.res[2]) + (self.res[0]) * (self.res[1]) * 2 | |
| def reference_cell(self) -> Cube: | |
| return Cube() | |
| def reference_side(self) -> Square: | |
| return Square() | |
| def res(self): | |
| return self._res | |
| def origin(self): | |
| return self.bounds_lo | |
| def strides(self): | |
| return wp.vec3i(self.res[1] * self.res[2], self.res[2], 1) | |
| # Utility device functions | |
| CellArg = Grid3DCellArg | |
| Cell = wp.vec3i | |
| def _to_3d_index(strides: wp.vec2i, index: int): | |
| x = index // strides[0] | |
| y = (index - strides[0] * x) // strides[1] | |
| z = index - strides[0] * x - strides[1] * y | |
| return wp.vec3i(x, y, z) | |
| def _from_3d_index(strides: wp.vec2i, index: wp.vec3i): | |
| return strides[0] * index[0] + strides[1] * index[1] + index[2] | |
| def cell_index(res: wp.vec3i, cell: Cell): | |
| strides = wp.vec2i(res[1] * res[2], res[2]) | |
| return Grid3D._from_3d_index(strides, cell) | |
| def get_cell(res: wp.vec3i, cell_index: ElementIndex): | |
| strides = wp.vec2i(res[1] * res[2], res[2]) | |
| return Grid3D._to_3d_index(strides, cell_index) | |
| class Side: | |
| axis: int # normal | |
| origin: wp.vec3i # index of vertex at corner (0,0,0) | |
| class SideArg: | |
| cell_count: int | |
| axis_offsets: wp.vec3i | |
| cell_arg: Grid3DCellArg | |
| SideIndexArg = SideArg | |
| def _world_to_local(axis: int, vec: Any): | |
| return type(vec)( | |
| vec[Grid3D.LOC_TO_WORLD[axis, 0]], | |
| vec[Grid3D.LOC_TO_WORLD[axis, 1]], | |
| vec[Grid3D.LOC_TO_WORLD[axis, 2]], | |
| ) | |
| def _local_to_world(axis: int, vec: Any): | |
| return type(vec)( | |
| vec[Grid3D.WORLD_TO_LOC[axis, 0]], | |
| vec[Grid3D.WORLD_TO_LOC[axis, 1]], | |
| vec[Grid3D.WORLD_TO_LOC[axis, 2]], | |
| ) | |
| def side_index(arg: SideArg, side: Side): | |
| alt_axis = Grid3D.LOC_TO_WORLD[side.axis, 0] | |
| if side.origin[0] == arg.cell_arg.res[alt_axis]: | |
| # Upper-boundary side | |
| longitude = side.origin[1] | |
| latitude = side.origin[2] | |
| latitude_res = arg.cell_arg.res[Grid3D.LOC_TO_WORLD[side.axis, 2]] | |
| lat_long = latitude_res * longitude + latitude | |
| return 3 * arg.cell_count + arg.axis_offsets[side.axis] + lat_long | |
| cell_index = Grid3D.cell_index(arg.cell_arg.res, Grid3D._local_to_world(side.axis, side.origin)) | |
| return side.axis * arg.cell_count + cell_index | |
| def get_side(arg: SideArg, side_index: ElementIndex): | |
| if side_index < 3 * arg.cell_count: | |
| axis = side_index // arg.cell_count | |
| cell_index = side_index - axis * arg.cell_count | |
| origin = Grid3D._world_to_local(axis, Grid3D.get_cell(arg.cell_arg.res, cell_index)) | |
| return Grid3D.Side(axis, origin) | |
| axis_side_index = side_index - 3 * arg.cell_count | |
| if axis_side_index < arg.axis_offsets[1]: | |
| axis = 0 | |
| elif axis_side_index < arg.axis_offsets[2]: | |
| axis = 1 | |
| else: | |
| axis = 2 | |
| altitude = arg.cell_arg.res[Grid3D.LOC_TO_WORLD[axis, 0]] | |
| lat_long = axis_side_index - arg.axis_offsets[axis] | |
| latitude_res = arg.cell_arg.res[Grid3D.LOC_TO_WORLD[axis, 2]] | |
| longitude = lat_long // latitude_res | |
| latitude = lat_long - longitude * latitude_res | |
| origin_loc = wp.vec3i(altitude, longitude, latitude) | |
| return Grid3D.Side(axis, origin_loc) | |
| # Geometry device interface | |
| def cell_arg_value(self, device) -> CellArg: | |
| args = self.CellArg() | |
| args.res = self.res | |
| args.origin = self.bounds_lo | |
| args.cell_size = self.cell_size | |
| return args | |
| def cell_position(args: CellArg, s: Sample): | |
| cell = Grid3D.get_cell(args.res, s.element_index) | |
| return ( | |
| wp.vec3( | |
| (float(cell[0]) + s.element_coords[0]) * args.cell_size[0], | |
| (float(cell[1]) + s.element_coords[1]) * args.cell_size[1], | |
| (float(cell[2]) + s.element_coords[2]) * args.cell_size[2], | |
| ) | |
| + args.origin | |
| ) | |
| def cell_deformation_gradient(args: CellArg, s: Sample): | |
| return wp.diag(args.cell_size) | |
| def cell_inverse_deformation_gradient(args: CellArg, s: Sample): | |
| return wp.diag(wp.cw_div(wp.vec3(1.0), args.cell_size)) | |
| def cell_lookup(args: CellArg, pos: wp.vec3): | |
| loc_pos = wp.cw_div(pos - args.origin, args.cell_size) | |
| x = wp.clamp(loc_pos[0], 0.0, float(args.res[0])) | |
| y = wp.clamp(loc_pos[1], 0.0, float(args.res[1])) | |
| z = wp.clamp(loc_pos[2], 0.0, float(args.res[2])) | |
| x_cell = wp.min(wp.floor(x), float(args.res[0]) - 1.0) | |
| y_cell = wp.min(wp.floor(y), float(args.res[1]) - 1.0) | |
| z_cell = wp.min(wp.floor(z), float(args.res[2]) - 1.0) | |
| coords = Coords(x - x_cell, y - y_cell, z - z_cell) | |
| cell_index = Grid3D.cell_index(args.res, Grid3D.Cell(int(x_cell), int(y_cell), int(z_cell))) | |
| return make_free_sample(cell_index, coords) | |
| def cell_lookup(args: CellArg, pos: wp.vec3, guess: Sample): | |
| return Grid3D.cell_lookup(args, pos) | |
| def cell_measure(args: CellArg, s: Sample): | |
| return args.cell_size[0] * args.cell_size[1] * args.cell_size[2] | |
| def cell_normal(args: CellArg, s: Sample): | |
| return wp.vec3(0.0) | |
| def cell_transform_reference_gradient(args: CellArg, cell_index: ElementIndex, coords: Coords, ref_grad: wp.vec3): | |
| return wp.cw_div(ref_grad, args.cell_size) | |
| def side_arg_value(self, device) -> SideArg: | |
| args = self.SideArg() | |
| axis_dims = wp.vec3i( | |
| self.res[1] * self.res[2], | |
| self.res[2] * self.res[0], | |
| self.res[0] * self.res[1], | |
| ) | |
| args.axis_offsets = wp.vec3i( | |
| 0, | |
| axis_dims[0], | |
| axis_dims[0] + axis_dims[1], | |
| ) | |
| args.cell_count = self.cell_count() | |
| args.cell_arg = self.cell_arg_value(device) | |
| return args | |
| def side_index_arg_value(self, device) -> SideIndexArg: | |
| return self.side_arg_value(device) | |
| def boundary_side_index(args: SideArg, boundary_side_index: int): | |
| """Boundary side to side index""" | |
| axis_side_index = boundary_side_index // 2 | |
| border = boundary_side_index - 2 * axis_side_index | |
| if axis_side_index < args.axis_offsets[1]: | |
| axis = 0 | |
| elif axis_side_index < args.axis_offsets[2]: | |
| axis = 1 | |
| else: | |
| axis = 2 | |
| lat_long = axis_side_index - args.axis_offsets[axis] | |
| latitude_res = args.cell_arg.res[Grid3D.LOC_TO_WORLD[axis, 2]] | |
| longitude = lat_long // latitude_res | |
| latitude = lat_long - longitude * latitude_res | |
| altitude = border * args.cell_arg.res[axis] | |
| side = Grid3D.Side(axis, wp.vec3i(altitude, longitude, latitude)) | |
| return Grid3D.side_index(args, side) | |
| def side_position(args: SideArg, s: Sample): | |
| side = Grid3D.get_side(args, s.element_index) | |
| coord0 = wp.select(side.origin[0] == 0, s.element_coords[0], 1.0 - s.element_coords[0]) | |
| local_pos = wp.vec3( | |
| float(side.origin[0]), | |
| float(side.origin[1]) + coord0, | |
| float(side.origin[2]) + s.element_coords[1], | |
| ) | |
| pos = args.cell_arg.origin + wp.cw_mul(Grid3D._local_to_world(side.axis, local_pos), args.cell_arg.cell_size) | |
| return pos | |
| def side_deformation_gradient(args: SideArg, s: Sample): | |
| side = Grid3D.get_side(args, s.element_index) | |
| sign = wp.select(side.origin[0] == 0, 1.0, -1.0) | |
| return _mat32( | |
| wp.cw_mul(Grid3D._local_to_world(side.axis, wp.vec3(0.0, sign, 0.0)), args.cell_arg.cell_size), | |
| wp.cw_mul(Grid3D._local_to_world(side.axis, wp.vec3(0.0, 0.0, 1.0)), args.cell_arg.cell_size), | |
| ) | |
| def side_inner_inverse_deformation_gradient(args: SideArg, s: Sample): | |
| return Grid3D.cell_inverse_deformation_gradient(args.cell_arg, s) | |
| def side_outer_inverse_deformation_gradient(args: SideArg, s: Sample): | |
| return Grid3D.cell_inverse_deformation_gradient(args.cell_arg, s) | |
| def side_measure(args: SideArg, s: Sample): | |
| side = Grid3D.get_side(args, s.element_index) | |
| long_axis = Grid3D.LOC_TO_WORLD[side.axis, 1] | |
| lat_axis = Grid3D.LOC_TO_WORLD[side.axis, 2] | |
| return args.cell_arg.cell_size[long_axis] * args.cell_arg.cell_size[lat_axis] | |
| def side_measure_ratio(args: SideArg, s: Sample): | |
| side = Grid3D.get_side(args, s.element_index) | |
| alt_axis = Grid3D.LOC_TO_WORLD[side.axis, 0] | |
| return 1.0 / args.cell_arg.cell_size[alt_axis] | |
| def side_normal(args: SideArg, s: Sample): | |
| side = Grid3D.get_side(args, s.element_index) | |
| sign = wp.select(side.origin[0] == 0, 1.0, -1.0) | |
| local_n = wp.vec3(sign, 0.0, 0.0) | |
| return Grid3D._local_to_world(side.axis, local_n) | |
| def side_inner_cell_index(arg: SideArg, side_index: ElementIndex): | |
| side = Grid3D.get_side(arg, side_index) | |
| inner_alt = wp.select(side.origin[0] == 0, side.origin[0] - 1, 0) | |
| inner_origin = wp.vec3i(inner_alt, side.origin[1], side.origin[2]) | |
| cell = Grid3D._local_to_world(side.axis, inner_origin) | |
| return Grid3D.cell_index(arg.cell_arg.res, cell) | |
| def side_outer_cell_index(arg: SideArg, side_index: ElementIndex): | |
| side = Grid3D.get_side(arg, side_index) | |
| alt_axis = Grid3D.LOC_TO_WORLD[side.axis, 0] | |
| outer_alt = wp.select( | |
| side.origin[0] == arg.cell_arg.res[alt_axis], side.origin[0], arg.cell_arg.res[alt_axis] - 1 | |
| ) | |
| outer_origin = wp.vec3i(outer_alt, side.origin[1], side.origin[2]) | |
| cell = Grid3D._local_to_world(side.axis, outer_origin) | |
| return Grid3D.cell_index(arg.cell_arg.res, cell) | |
| def side_inner_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords): | |
| side = Grid3D.get_side(args, side_index) | |
| inner_alt = wp.select(side.origin[0] == 0, 1.0, 0.0) | |
| side_coord0 = wp.select(side.origin[0] == 0, side_coords[0], 1.0 - side_coords[0]) | |
| return Grid3D._local_to_world(side.axis, wp.vec3(inner_alt, side_coord0, side_coords[1])) | |
| def side_outer_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords): | |
| side = Grid3D.get_side(args, side_index) | |
| alt_axis = Grid3D.LOC_TO_WORLD[side.axis, 0] | |
| outer_alt = wp.select(side.origin[0] == args.cell_arg.res[alt_axis], 0.0, 1.0) | |
| side_coord0 = wp.select(side.origin[0] == 0, side_coords[0], 1.0 - side_coords[0]) | |
| return Grid3D._local_to_world(side.axis, wp.vec3(outer_alt, side_coord0, side_coords[1])) | |
| def side_from_cell_coords( | |
| args: SideArg, | |
| side_index: ElementIndex, | |
| element_index: ElementIndex, | |
| element_coords: Coords, | |
| ): | |
| side = Grid3D.get_side(args, side_index) | |
| cell = Grid3D.get_cell(args.cell_arg.res, element_index) | |
| if float(side.origin[0] - cell[side.axis]) == element_coords[side.axis]: | |
| long_axis = Grid3D.LOC_TO_WORLD[side.axis, 1] | |
| lat_axis = Grid3D.LOC_TO_WORLD[side.axis, 2] | |
| long_coord = element_coords[long_axis] | |
| long_coord = wp.select(side.origin[0] == 0, long_coord, 1.0 - long_coord) | |
| return Coords(long_coord, element_coords[lat_axis], 0.0) | |
| return Coords(OUTSIDE) | |
| def side_to_cell_arg(side_arg: SideArg): | |
| return side_arg.cell_arg | |