0xZohar commited on
Commit
43f661a
·
verified ·
1 Parent(s): 0a921ba

Add code/cube3d/model/autoencoder/grid.py

Browse files
code/cube3d/model/autoencoder/grid.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Literal, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import warp as wp
6
+
7
+
8
+ def generate_dense_grid_points(
9
+ bbox_min: np.ndarray,
10
+ bbox_max: np.ndarray,
11
+ resolution_base: float,
12
+ indexing: Literal["xy", "ij"] = "ij",
13
+ ) -> tuple[np.ndarray, list[int], np.ndarray]:
14
+ """
15
+ Generate a dense grid of points within a bounding box.
16
+
17
+ Parameters:
18
+ bbox_min (np.ndarray): The minimum coordinates of the bounding box (3D).
19
+ bbox_max (np.ndarray): The maximum coordinates of the bounding box (3D).
20
+ resolution_base (float): The base resolution for the grid. The number of cells along each axis will be 2^resolution_base.
21
+ indexing (Literal["xy", "ij"], optional): The indexing convention for the grid. "xy" for Cartesian indexing, "ij" for matrix indexing. Default is "ij".
22
+ Returns:
23
+ tuple: A tuple containing:
24
+ - xyz (np.ndarray): A 2D array of shape (N, 3) where N is the total number of grid points. Each row represents the (x, y, z) coordinates of a grid point.
25
+ - grid_size (list): A list of three integers representing the number of grid points along each axis.
26
+ - length (np.ndarray): The length of the bounding box along each axis.
27
+ """
28
+ length = bbox_max - bbox_min
29
+ num_cells = np.exp2(resolution_base)
30
+ x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
31
+ y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
32
+ z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
33
+ [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
34
+ xyz = np.stack((xs, ys, zs), axis=-1)
35
+ xyz = xyz.reshape(-1, 3)
36
+ grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
37
+
38
+ return xyz, grid_size, length
39
+
40
+
41
+ def marching_cubes_with_warp(
42
+ grid_logits: torch.Tensor,
43
+ level: float,
44
+ device: Union[str, torch.device] = "cuda",
45
+ max_verts: int = 3_000_000,
46
+ max_tris: int = 3_000_000,
47
+ ) -> tuple[np.ndarray, np.ndarray]:
48
+ """
49
+ Perform the marching cubes algorithm on a 3D grid with warp support.
50
+ Args:
51
+ grid_logits (torch.Tensor): A 3D tensor containing the grid logits.
52
+ level (float): The threshold level for the isosurface.
53
+ device (Union[str, torch.device], optional): The device to perform the computation on. Defaults to "cuda".
54
+ max_verts (int, optional): The maximum number of vertices. Defaults to 3,000,000.
55
+ max_tris (int, optional): The maximum number of triangles. Defaults to 3,000,000.
56
+ Returns:
57
+ Tuple[np.ndarray, np.ndarray]: A tuple containing the vertices and faces of the isosurface.
58
+ """
59
+ if isinstance(device, torch.device):
60
+ device = str(device)
61
+
62
+ assert grid_logits.ndim == 3
63
+ if "cuda" in device:
64
+ assert wp.is_cuda_available()
65
+ else:
66
+ raise ValueError(
67
+ f"Device {device} is not supported for marching_cubes_with_warp"
68
+ )
69
+
70
+ dim = grid_logits.shape[0]
71
+ field = wp.from_torch(grid_logits)
72
+
73
+ iso = wp.MarchingCubes(
74
+ nx=dim,
75
+ ny=dim,
76
+ nz=dim,
77
+ max_verts=int(max_verts),
78
+ max_tris=int(max_tris),
79
+ device=device,
80
+ )
81
+ iso.surface(field=field, threshold=level)
82
+ vertices = iso.verts.numpy()
83
+ faces = iso.indices.numpy().reshape(-1, 3)
84
+ return vertices, faces