0xZohar commited on
Commit
8fc4fa7
·
verified ·
1 Parent(s): 25cd1c4

Add code/cube3d/vq_vae_encode_decode.py

Browse files
Files changed (1) hide show
  1. code/cube3d/vq_vae_encode_decode.py +164 -0
code/cube3d/vq_vae_encode_decode.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import json
4
+ import os
5
+
6
+ import numpy as np
7
+ import torch
8
+ import trimesh
9
+
10
+ from cube3d.inference.utils import load_config, load_model_weights, parse_structured, select_device
11
+ from cube3d.model.autoencoder.one_d_autoencoder import OneDAutoEncoder
12
+
13
+ MESH_SCALE = 0.96
14
+
15
+
16
+ def rescale(vertices: np.ndarray, mesh_scale: float = MESH_SCALE) -> np.ndarray:
17
+ """Rescale the vertices to a cube, e.g., [-1, -1, -1] to [1, 1, 1] when mesh_scale=1.0"""
18
+ #import ipdb; ipdb.set_trace()
19
+ vertices = vertices
20
+ bbmin = vertices.min(0)
21
+ bbmax = vertices.max(0)
22
+ center = (bbmin + bbmax) * 0.5
23
+ scale = 2.0 * mesh_scale / (bbmax - bbmin).max()
24
+ vertices = (vertices - center) * scale
25
+ return vertices
26
+
27
+
28
+ def load_scaled_mesh(file_path: str) -> trimesh.Trimesh:
29
+ """
30
+ Load a mesh and scale it to a unit cube, and clean the mesh.
31
+ Parameters:
32
+ file_obj: str | IO
33
+ file_type: str
34
+ Returns:
35
+ mesh: trimesh.Trimesh
36
+ """
37
+ mesh: trimesh.Trimesh = trimesh.load(file_path, force="mesh")
38
+ mesh.remove_infinite_values()
39
+ mesh.update_faces(mesh.nondegenerate_faces())
40
+ mesh.update_faces(mesh.unique_faces())
41
+ mesh.remove_unreferenced_vertices()
42
+ if len(mesh.vertices) == 0 or len(mesh.faces) == 0:
43
+ raise ValueError("Mesh has no vertices or faces after cleaning")
44
+ mesh.vertices = rescale(mesh.vertices)
45
+ return mesh
46
+
47
+
48
+ def load_and_process_mesh(file_path: str, n_samples: int = 8192):
49
+ """
50
+ Loads a 3D mesh from the specified file path, samples points from its surface,
51
+ and processes the sampled points into a point cloud with normals.
52
+ Args:
53
+ file_path (str): The file path to the 3D mesh file.
54
+ n_samples (int, optional): The number of points to sample from the mesh surface. Defaults to 8192.
55
+ Returns:
56
+ torch.Tensor: A tensor of shape (1, n_samples, 6) containing the processed point cloud.
57
+ Each point consists of its 3D position (x, y, z) and its normal vector (nx, ny, nz).
58
+ """
59
+
60
+ mesh = load_scaled_mesh(file_path)
61
+ positions, face_indices = trimesh.sample.sample_surface(mesh, n_samples)
62
+ normals = mesh.face_normals[face_indices]
63
+ point_cloud = np.concatenate(
64
+ [positions, normals], axis=1
65
+ ) # Shape: (num_samples, 6)
66
+ point_cloud = torch.from_numpy(point_cloud.reshape(1, -1, 6)).float()
67
+ return point_cloud, mesh
68
+
69
+
70
+ @torch.inference_mode()
71
+ def run_shape_decode(
72
+ shape_model: OneDAutoEncoder,
73
+ output_ids: torch.Tensor,
74
+ resolution_base: float = 8.0,
75
+ chunk_size: int = 100_000,
76
+ ):
77
+ """
78
+ Decodes the shape from the given output IDs and extracts the geometry.
79
+ Args:
80
+ shape_model (OneDAutoEncoder): The shape model.
81
+ output_ids (torch.Tensor): The tensor containing the output IDs.
82
+ resolution_base (float, optional): The base resolution for geometry extraction. Defaults to 8.43.
83
+ chunk_size (int, optional): The chunk size for processing. Defaults to 100,000.
84
+ Returns:
85
+ tuple: A tuple containing the vertices and faces of the mesh.
86
+ """
87
+ shape_ids = (
88
+ output_ids[:, : shape_model.cfg.num_encoder_latents, ...]
89
+ .clamp_(0, shape_model.cfg.num_codes - 1)
90
+ .view(-1, shape_model.cfg.num_encoder_latents)
91
+ )
92
+ latents = shape_model.decode_indices(shape_ids)
93
+ mesh_v_f, _ = shape_model.extract_geometry(
94
+ latents,
95
+ resolution_base=resolution_base,
96
+ chunk_size=chunk_size,
97
+ use_warp=True,
98
+ )
99
+ return mesh_v_f
100
+
101
+
102
+ if __name__ == "__main__":
103
+ parser = argparse.ArgumentParser(
104
+ description="cube shape encode and decode example script"
105
+ )
106
+ parser.add_argument(
107
+ "--mesh-path",
108
+ type=str,
109
+ required=True,
110
+ help="Path to the input mesh file.",
111
+ )
112
+ parser.add_argument(
113
+ "--config-path",
114
+ type=str,
115
+ default="cube3d/configs/open_model.yaml",
116
+ help="Path to the configuration YAML file.",
117
+ )
118
+ parser.add_argument(
119
+ "--shape-ckpt-path",
120
+ type=str,
121
+ required=True,
122
+ help="Path to the shape encoder/decoder checkpoint file.",
123
+ )
124
+ parser.add_argument(
125
+ "--recovered-mesh-path",
126
+ type=str,
127
+ default="recovered_mesh.obj",
128
+ help="Path to save the recovered mesh file.",
129
+ )
130
+ args = parser.parse_args()
131
+ device = select_device()
132
+ logging.info(f"Using device: {device}")
133
+
134
+ cfg = load_config(args.config_path)
135
+
136
+ #import ipdb; ipdb.set_trace()
137
+ shape_model = OneDAutoEncoder(
138
+ parse_structured(OneDAutoEncoder.Config, cfg.shape_model)
139
+ )
140
+ load_model_weights(
141
+ shape_model,
142
+ args.shape_ckpt_path,
143
+ )
144
+ shape_model = shape_model.eval().to(device)
145
+ point_cloud, rescale_mesh = load_and_process_mesh(args.mesh_path)
146
+ rescale_mesh.export(args.recovered_mesh_path.replace('recovered', 'rescaled'))
147
+
148
+ output = shape_model.encode(point_cloud.to(device))
149
+ indices = output[3]["indices"]
150
+ # print("Got the following shape indices:")
151
+ # print(indices)
152
+ # print("Indices shape: ", indices.shape)
153
+
154
+ indices_list = indices.detach().cpu().view(-1).tolist()
155
+ print(json.dumps({
156
+ "mesh": os.path.basename(args.mesh_path),
157
+ "latent_ids": indices_list
158
+ }), flush=True)
159
+
160
+ mesh_v_f = run_shape_decode(shape_model, indices)
161
+ vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
162
+ mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
163
+ mesh.export(args.recovered_mesh_path)
164
+