File size: 14,500 Bytes
66c9c8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
from typing import Any
import warp as wp

from warp.fem.types import ElementIndex, Coords, Sample, make_free_sample, OUTSIDE
from warp.fem.cache import cached_arg_value

from .geometry import Geometry
from .element import Square, Cube


# Device-side description of the grid cells; filled in by Grid3D.cell_arg_value
@wp.struct
class Grid3DCellArg:
    res: wp.vec3i  # number of cells along each axis
    cell_size: wp.vec3  # edge lengths of one cell along each axis
    origin: wp.vec3  # position of the grid's lower corner (set from bounds_lo)


# 3x2 float matrix type, used to build the result of Grid3D.side_deformation_gradient
_mat32 = wp.mat(shape=(3, 2), dtype=float)


class Grid3D(Geometry):
    """Three-dimensional regular grid geometry"""

    # Dimension of the geometry
    dimension = 3

    # Axis permutation tables, indexed as [side axis, component].
    # Side-local vectors are ordered (altitude, longitude, latitude), where the
    # altitude is the component along the side axis (the side's normal direction).
    Permutation = wp.types.matrix(shape=(3, 3), dtype=int)
    # LOC_TO_WORLD[axis, i]: index of the world component stored in local component i
    LOC_TO_WORLD = wp.constant(Permutation(0, 1, 2, 1, 2, 0, 2, 0, 1))
    # WORLD_TO_LOC[axis, j]: index of the local component stored in world component j
    # (row-wise inverse of LOC_TO_WORLD)
    WORLD_TO_LOC = wp.constant(Permutation(0, 1, 2, 2, 0, 1, 1, 2, 0))

    def __init__(self, res: wp.vec3i, bounds_lo: wp.vec3 = wp.vec3(0.0), bounds_hi: wp.vec3 = wp.vec3(1.0)):
        """Create a dense, axis-aligned 3D grid.

        Args:
            res: Number of cells of the grid along each dimension
            bounds_lo: Position of the lower corner of the grid
            bounds_hi: Position of the upper corner of the grid
        """

        self._res = res

        # The bounds are stored as given; extents and cell size are derived on demand
        self.bounds_lo, self.bounds_hi = bounds_lo, bounds_hi

    @property
    def extents(self) -> wp.vec3:
        """Size of the grid along each axis, i.e. ``bounds_hi - bounds_lo``."""
        lo, hi = self.bounds_lo, self.bounds_hi
        # Subtract component by component instead of using the native vector
        # subtraction: calling Warp builtins from Python has a higher overhead.
        return wp.vec3(hi[0] - lo[0], hi[1] - lo[1], hi[2] - lo[2])

    @property
    def cell_size(self) -> wp.vec3:
        """Edge lengths of a single cell: the grid extents divided by the resolution, per axis."""
        size = self.extents
        counts = self.res
        return wp.vec3(size[0] / counts[0], size[1] / counts[1], size[2] / counts[2])

    def cell_count(self):
        """Return the total number of cells (product of the per-axis resolutions)."""
        nx, ny, nz = self.res[0], self.res[1], self.res[2]
        return nx * ny * nz

    def vertex_count(self):
        """Return the total number of vertices: one more than the cell count along each axis."""
        nx, ny, nz = self.res[0], self.res[1], self.res[2]
        return (nx + 1) * (ny + 1) * (nz + 1)

    def side_count(self):
        """Return the total number of sides (cell faces), boundary faces included."""
        nx, ny, nz = self.res[0], self.res[1], self.res[2]

        # One term per normal direction: there is one more layer of faces than
        # of cells along the direction of the normal.
        normal_x = (nx + 1) * ny * nz
        normal_y = nx * (ny + 1) * nz
        normal_z = nx * ny * (nz + 1)
        return normal_x + normal_y + normal_z

    def edge_count(self):
        """Return the total number of cell edges in the grid."""
        nx, ny, nz = self.res[0], self.res[1], self.res[2]

        # One term per edge direction: along an edge's own direction there are as
        # many edges as cells, along the two other directions as many as vertices.
        along_z = (nx + 1) * (ny + 1) * nz
        along_x = nx * (ny + 1) * (nz + 1)
        along_y = (nx + 1) * ny * (nz + 1)
        return along_z + along_x + along_y

    def boundary_side_count(self):
        """Return the number of sides lying on the boundary of the grid."""
        nx, ny, nz = self.res[0], self.res[1], self.res[2]
        # Two boundary planes (lower and upper) for each of the three axes
        return 2 * (ny * nz + nx * nz + nx * ny)

    def reference_cell(self) -> Cube:
        """Return the reference element for cells, a new :class:`Cube` instance."""
        return Cube()

    def reference_side(self) -> Square:
        """Return the reference element for sides, a new :class:`Square` instance."""
        return Square()

    @property
    def res(self):
        """Resolution of the grid (number of cells per axis), as passed to the constructor."""
        return self._res

    @property
    def origin(self):
        """Origin of the grid, which is its lower bound ``bounds_lo``."""
        return self.bounds_lo

    @property
    def strides(self):
        """Strides of a linear index in which the last axis varies fastest."""
        ny, nz = self.res[1], self.res[2]
        return wp.vec3i(ny * nz, nz, 1)

    # Utility device functions

    # Argument struct taken as first parameter by the cell-level device functions
    CellArg = Grid3DCellArg
    # A cell is identified on device by its integer index along each of the three axes
    Cell = wp.vec3i

    @wp.func
    def _to_3d_index(strides: wp.vec2i, index: int):
        """Split a linear index into its three per-axis indices.

        ``strides`` holds the strides of the first two axes; the stride of the last axis is 1.
        """
        i = index // strides[0]
        remainder = index - strides[0] * i
        j = remainder // strides[1]
        k = remainder - strides[1] * j
        return wp.vec3i(i, j, k)

    @wp.func
    def _from_3d_index(strides: wp.vec2i, index: wp.vec3i):
        """Combine three per-axis indices into a linear index.

        ``strides`` holds the strides of the first two axes; the stride of the last axis is 1.
        """
        first_axis_offset = strides[0] * index[0]
        second_axis_offset = strides[1] * index[1]
        return first_axis_offset + second_axis_offset + index[2]

    @wp.func
    def cell_index(res: wp.vec3i, cell: Cell):
        """Linear index of the cell with per-axis indices ``cell`` in a grid of resolution ``res``."""
        # The last axis varies fastest, so only the strides of the first two axes are needed
        return Grid3D._from_3d_index(wp.vec2i(res[1] * res[2], res[2]), cell)

    @wp.func
    def get_cell(res: wp.vec3i, cell_index: ElementIndex):
        """Per-axis indices of the cell with linear index ``cell_index`` in a grid of resolution ``res``."""
        # The last axis varies fastest, so only the strides of the first two axes are needed
        return Grid3D._to_3d_index(wp.vec2i(res[1] * res[2], res[2]), cell_index)

    # Device-side identification of a side (cell face)
    @wp.struct
    class Side:
        axis: int  # normal
        # index of vertex at corner (0,0,0), stored in side-local order:
        # (altitude along `axis`, longitude, latitude)
        origin: wp.vec3i

    # Device-side arguments for the side-level functions; filled in by side_arg_value
    @wp.struct
    class SideArg:
        cell_count: int  # total number of cells in the grid
        # for each axis, offset of its first upper-boundary side among all upper-boundary sides
        axis_offsets: wp.vec3i
        cell_arg: Grid3DCellArg  # cell-level arguments of the same grid

    # The side-index device functions take the same argument struct as the other side functions
    SideIndexArg = SideArg

    @wp.func
    def _world_to_local(axis: int, vec: Any):
        """Reorder the components of a world-space vector into the local frame of a side normal to ``axis``.

        The result has the same type as ``vec``; its first component is the one along ``axis``.
        """
        world0 = Grid3D.LOC_TO_WORLD[axis, 0]
        world1 = Grid3D.LOC_TO_WORLD[axis, 1]
        world2 = Grid3D.LOC_TO_WORLD[axis, 2]
        return type(vec)(vec[world0], vec[world1], vec[world2])

    @wp.func
    def _local_to_world(axis: int, vec: Any):
        """Reorder the components of a vector given in the local frame of a side normal to ``axis`` back to world order.

        The result has the same type as ``vec``.
        """
        local0 = Grid3D.WORLD_TO_LOC[axis, 0]
        local1 = Grid3D.WORLD_TO_LOC[axis, 1]
        local2 = Grid3D.WORLD_TO_LOC[axis, 2]
        return type(vec)(vec[local0], vec[local1], vec[local2])

    @wp.func
    def side_index(arg: SideArg, side: Side):
        """Linear index of ``side``.

        Sides whose origin is a valid cell are numbered first, grouped by axis and
        using the index of that cell; the sides on the upper boundary of each axis
        are numbered after them.
        """
        res = arg.cell_arg.res
        normal_axis = Grid3D.LOC_TO_WORLD[side.axis, 0]

        if side.origin[0] == res[normal_axis]:
            # Upper-boundary side: there is no cell at this altitude, so number it
            # from its in-plane (longitude, latitude) position
            latitude_count = res[Grid3D.LOC_TO_WORLD[side.axis, 2]]
            in_plane_index = latitude_count * side.origin[1] + side.origin[2]
            return 3 * arg.cell_count + arg.axis_offsets[side.axis] + in_plane_index

        origin_cell = Grid3D._local_to_world(side.axis, side.origin)
        return side.axis * arg.cell_count + Grid3D.cell_index(res, origin_cell)

    @wp.func
    def get_side(arg: SideArg, side_index: ElementIndex):
        """Build the :class:`Side` (axis and side-local origin) for a linear side index."""
        cell_side_count = 3 * arg.cell_count

        if side_index < cell_side_count:
            # Side attached to a cell: one block of `cell_count` indices per axis
            side_axis = side_index // arg.cell_count
            cell = Grid3D.get_cell(arg.cell_arg.res, side_index - side_axis * arg.cell_count)
            return Grid3D.Side(side_axis, Grid3D._world_to_local(side_axis, cell))

        # Upper-boundary side: find which axis block the remaining index falls into
        axis_side_index = side_index - cell_side_count
        if axis_side_index < arg.axis_offsets[1]:
            axis = 0
        elif axis_side_index < arg.axis_offsets[2]:
            axis = 1
        else:
            axis = 2

        latitude_count = arg.cell_arg.res[Grid3D.LOC_TO_WORLD[axis, 2]]
        in_plane_index = axis_side_index - arg.axis_offsets[axis]

        longitude = in_plane_index // latitude_count
        latitude = in_plane_index - longitude * latitude_count

        # The altitude of an upper-boundary side is the resolution along its axis
        altitude = arg.cell_arg.res[Grid3D.LOC_TO_WORLD[axis, 0]]

        return Grid3D.Side(axis, wp.vec3i(altitude, longitude, latitude))

    # Geometry device interface

    @cached_arg_value
    def cell_arg_value(self, device) -> CellArg:
        """Build the argument struct passed to the cell-level device functions.

        ``device`` is not read here; presumably it is consumed by the
        ``cached_arg_value`` decorator (to be confirmed against its definition).
        """
        cell_arg = self.CellArg()
        cell_arg.origin = self.bounds_lo
        cell_arg.cell_size = self.cell_size
        cell_arg.res = self.res
        return cell_arg

    @wp.func
    def cell_position(args: CellArg, s: Sample):
        """World-space position of sample ``s``, located by its cell index and its coordinates in that cell."""
        cell = Grid3D.get_cell(args.res, s.element_index)

        # Position in units of cells: integer cell index plus coordinates within the cell
        grid_pos = wp.vec3(
            float(cell[0]) + s.element_coords[0],
            float(cell[1]) + s.element_coords[1],
            float(cell[2]) + s.element_coords[2],
        )

        return wp.cw_mul(grid_pos, args.cell_size) + args.origin

    @wp.func
    def cell_deformation_gradient(args: CellArg, s: Sample):
        """Diagonal matrix holding the cell size along each axis; independent of the sample ``s``."""
        return wp.diag(args.cell_size)

    @wp.func
    def cell_inverse_deformation_gradient(args: CellArg, s: Sample):
        """Diagonal matrix holding the inverse cell size along each axis; independent of the sample ``s``."""
        inverse_size = wp.cw_div(wp.vec3(1.0), args.cell_size)
        return wp.diag(inverse_size)

    @wp.func
    def cell_lookup(args: CellArg, pos: wp.vec3):
        """Find the sample (cell index and coordinates within the cell) located at world position ``pos``.

        Positions outside of the grid are clamped to the grid bounds first.
        """
        grid_pos = wp.cw_div(pos - args.origin, args.cell_size)

        # Position in units of cells, restricted to the extent of the grid
        gx = wp.clamp(grid_pos[0], 0.0, float(args.res[0]))
        gy = wp.clamp(grid_pos[1], 0.0, float(args.res[1]))
        gz = wp.clamp(grid_pos[2], 0.0, float(args.res[2]))

        # A position lying exactly on an upper bound is assigned to the last cell
        cx = wp.min(wp.floor(gx), float(args.res[0]) - 1.0)
        cy = wp.min(wp.floor(gy), float(args.res[1]) - 1.0)
        cz = wp.min(wp.floor(gz), float(args.res[2]) - 1.0)

        cell_coords = Coords(gx - cx, gy - cy, gz - cz)
        cell = Grid3D.Cell(int(cx), int(cy), int(cz))

        return make_free_sample(Grid3D.cell_index(args.res, cell), cell_coords)

    @wp.func
    def cell_lookup(args: CellArg, pos: wp.vec3, guess: Sample):
        """Overload of the lookup accepting an initial ``guess``; the guess is ignored."""
        return Grid3D.cell_lookup(args, pos)

    @wp.func
    def cell_measure(args: CellArg, s: Sample):
        """Volume of a cell (product of its three edge lengths); independent of the sample ``s``."""
        size = args.cell_size
        return size[0] * size[1] * size[2]

    @wp.func
    def cell_normal(args: CellArg, s: Sample):
        """Always returns the zero vector, whatever the arguments."""
        return wp.vec3(0.0)

    @wp.func
    def cell_transform_reference_gradient(args: CellArg, cell_index: ElementIndex, coords: Coords, ref_grad: wp.vec3):
        """Divide ``ref_grad`` component-wise by the cell size; ``cell_index`` and ``coords`` are not used."""
        return wp.cw_div(ref_grad, args.cell_size)

    @cached_arg_value
    def side_arg_value(self, device) -> SideArg:
        """Build the argument struct passed to the side-level device functions."""
        nx, ny, nz = self.res[0], self.res[1], self.res[2]

        # Number of upper-boundary sides for each axis
        upper_side_counts = wp.vec3i(ny * nz, nz * nx, nx * ny)

        side_arg = self.SideArg()
        # Running sum of the counts above: start of each axis' block of upper-boundary sides
        side_arg.axis_offsets = wp.vec3i(
            0,
            upper_side_counts[0],
            upper_side_counts[0] + upper_side_counts[1],
        )
        side_arg.cell_count = self.cell_count()
        side_arg.cell_arg = self.cell_arg_value(device)
        return side_arg

    def side_index_arg_value(self, device) -> SideIndexArg:
        """Return the argument struct for the side-index functions, which is the same as :meth:`side_arg_value`."""
        return self.side_arg_value(device)

    @wp.func
    def boundary_side_index(args: SideArg, boundary_side_index: int):
        """Convert an index among the boundary sides into an index among all sides.

        Boundary sides come in pairs: even indices are on the lower boundary of an
        axis, odd indices on the upper boundary.
        """
        pair_index = boundary_side_index // 2
        border = boundary_side_index - 2 * pair_index  # 0: lower boundary, 1: upper boundary

        # Find which axis block the pair falls into
        if pair_index < args.axis_offsets[1]:
            axis = 0
        elif pair_index < args.axis_offsets[2]:
            axis = 1
        else:
            axis = 2

        latitude_count = args.cell_arg.res[Grid3D.LOC_TO_WORLD[axis, 2]]
        in_plane_index = pair_index - args.axis_offsets[axis]

        longitude = in_plane_index // latitude_count
        latitude = in_plane_index - longitude * latitude_count

        # Altitude is 0 on the lower boundary and the axis resolution on the upper one
        origin = wp.vec3i(border * args.cell_arg.res[axis], longitude, latitude)

        return Grid3D.side_index(args, Grid3D.Side(axis, origin))

    @wp.func
    def side_position(args: SideArg, s: Sample):
        """World-space position of sample ``s``, located by its side index and its coordinates on that side."""
        side = Grid3D.get_side(args, s.element_index)

        # wp.select returns its last argument when the condition is true:
        # the first side coordinate is mirrored for sides at altitude 0
        longitude_coord = wp.select(side.origin[0] == 0, s.element_coords[0], 1.0 - s.element_coords[0])

        # Position in units of cells, in side-local (altitude, longitude, latitude) order
        local_grid_pos = wp.vec3(
            float(side.origin[0]),
            float(side.origin[1]) + longitude_coord,
            float(side.origin[2]) + s.element_coords[1],
        )
        world_grid_pos = Grid3D._local_to_world(side.axis, local_grid_pos)

        return args.cell_arg.origin + wp.cw_mul(world_grid_pos, args.cell_arg.cell_size)

    @wp.func
    def side_deformation_gradient(args: SideArg, s: Sample):
        """3x2 matrix built from the derivatives of the side position w.r.t. the two side coordinates."""
        side = Grid3D.get_side(args, s.element_index)
        cell_size = args.cell_arg.cell_size

        # wp.select returns its last argument when the condition is true:
        # the longitude direction is reversed for sides at altitude 0
        longitude_sign = wp.select(side.origin[0] == 0, 1.0, -1.0)

        longitude_dir = Grid3D._local_to_world(side.axis, wp.vec3(0.0, longitude_sign, 0.0))
        latitude_dir = Grid3D._local_to_world(side.axis, wp.vec3(0.0, 0.0, 1.0))

        return _mat32(
            wp.cw_mul(longitude_dir, cell_size),
            wp.cw_mul(latitude_dir, cell_size),
        )

    @wp.func
    def side_inner_inverse_deformation_gradient(args: SideArg, s: Sample):
        """Same as the cell inverse deformation gradient, evaluated with the side's cell arguments."""
        return Grid3D.cell_inverse_deformation_gradient(args.cell_arg, s)

    @wp.func
    def side_outer_inverse_deformation_gradient(args: SideArg, s: Sample):
        """Same as the cell inverse deformation gradient, evaluated with the side's cell arguments."""
        return Grid3D.cell_inverse_deformation_gradient(args.cell_arg, s)

    @wp.func
    def side_measure(args: SideArg, s: Sample):
        """Area of the side containing ``s``: product of the cell sizes along its two in-plane axes."""
        side = Grid3D.get_side(args, s.element_index)
        cell_size = args.cell_arg.cell_size

        longitude_size = cell_size[Grid3D.LOC_TO_WORLD[side.axis, 1]]
        latitude_size = cell_size[Grid3D.LOC_TO_WORLD[side.axis, 2]]
        return longitude_size * latitude_size

    @wp.func
    def side_measure_ratio(args: SideArg, s: Sample):
        """Inverse of the cell size along the normal axis of the side containing ``s``.

        This equals the side area divided by the cell volume.
        """
        side = Grid3D.get_side(args, s.element_index)
        normal_axis = Grid3D.LOC_TO_WORLD[side.axis, 0]
        normal_size = args.cell_arg.cell_size[normal_axis]
        return 1.0 / normal_size

    @wp.func
    def side_normal(args: SideArg, s: Sample):
        """Unit normal of the side containing ``s``, aligned with the side's axis."""
        side = Grid3D.get_side(args, s.element_index)

        # wp.select returns its last argument when the condition is true:
        # the normal points towards negative coordinates for sides at altitude 0
        direction = wp.select(side.origin[0] == 0, 1.0, -1.0)

        return Grid3D._local_to_world(side.axis, wp.vec3(direction, 0.0, 0.0))

    @wp.func
    def side_inner_cell_index(arg: SideArg, side_index: ElementIndex):
        """Linear index of the cell on the inner side of a side.

        This is the cell just below the side along its axis, or the cell at
        altitude 0 when the side lies on the lower boundary.
        """
        side = Grid3D.get_side(arg, side_index)

        # wp.select returns its last argument when the condition is true
        altitude = wp.select(side.origin[0] == 0, side.origin[0] - 1, 0)

        local_cell = wp.vec3i(altitude, side.origin[1], side.origin[2])
        return Grid3D.cell_index(arg.cell_arg.res, Grid3D._local_to_world(side.axis, local_cell))

    @wp.func
    def side_outer_cell_index(arg: SideArg, side_index: ElementIndex):
        """Linear index of the cell on the outer side of a side.

        This is the cell at the side's own altitude, or the last cell along the
        axis when the side lies on the upper boundary.
        """
        side = Grid3D.get_side(arg, side_index)

        altitude_res = arg.cell_arg.res[Grid3D.LOC_TO_WORLD[side.axis, 0]]

        # wp.select returns its last argument when the condition is true
        altitude = wp.select(side.origin[0] == altitude_res, side.origin[0], altitude_res - 1)

        local_cell = wp.vec3i(altitude, side.origin[1], side.origin[2])
        return Grid3D.cell_index(arg.cell_arg.res, Grid3D._local_to_world(side.axis, local_cell))

    @wp.func
    def side_inner_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
        """Coordinates, within the inner cell of a side, of the point with coordinates ``side_coords`` on that side."""
        side = Grid3D.get_side(args, side_index)

        # wp.select returns its last argument when the condition is true.
        # The side is the upper face of its inner cell (altitude 1), except on the
        # lower boundary where it is the lower face (altitude 0).
        altitude_coord = wp.select(side.origin[0] == 0, 1.0, 0.0)

        # The first side coordinate is mirrored for sides at altitude 0
        longitude_coord = wp.select(side.origin[0] == 0, side_coords[0], 1.0 - side_coords[0])

        local_coords = wp.vec3(altitude_coord, longitude_coord, side_coords[1])
        return Grid3D._local_to_world(side.axis, local_coords)

    @wp.func
    def side_outer_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
        """Coordinates, within the outer cell of a side, of the point with coordinates ``side_coords`` on that side."""
        side = Grid3D.get_side(args, side_index)

        altitude_res = args.cell_arg.res[Grid3D.LOC_TO_WORLD[side.axis, 0]]

        # wp.select returns its last argument when the condition is true.
        # The side is the lower face of its outer cell (altitude 0), except on the
        # upper boundary where it is the upper face (altitude 1).
        altitude_coord = wp.select(side.origin[0] == altitude_res, 0.0, 1.0)

        # The first side coordinate is mirrored for sides at altitude 0
        longitude_coord = wp.select(side.origin[0] == 0, side_coords[0], 1.0 - side_coords[0])

        local_coords = wp.vec3(altitude_coord, longitude_coord, side_coords[1])
        return Grid3D._local_to_world(side.axis, local_coords)

    @wp.func
    def side_from_cell_coords(
        args: SideArg,
        side_index: ElementIndex,
        element_index: ElementIndex,
        element_coords: Coords,
    ):
        """Convert coordinates within a cell to coordinates on one of the sides.

        Returns ``Coords(OUTSIDE)`` when the point does not lie at the altitude of the side.
        """
        side = Grid3D.get_side(args, side_index)
        cell = Grid3D.get_cell(args.cell_arg.res, element_index)

        # Along the side's axis, the coordinate within the cell must match the
        # offset between the side and the cell
        if float(side.origin[0] - cell[side.axis]) != element_coords[side.axis]:
            return Coords(OUTSIDE)

        longitude_coord = element_coords[Grid3D.LOC_TO_WORLD[side.axis, 1]]
        latitude_coord = element_coords[Grid3D.LOC_TO_WORLD[side.axis, 2]]

        # wp.select returns its last argument when the condition is true:
        # the first side coordinate is mirrored for sides at altitude 0
        side_coord0 = wp.select(side.origin[0] == 0, longitude_coord, 1.0 - longitude_coord)

        return Coords(side_coord0, latitude_coord, 0.0)

    @wp.func
    def side_to_cell_arg(side_arg: SideArg):
        """Return the cell-level argument struct embedded in ``side_arg``."""
        return side_arg.cell_arg