0xZohar commited on
Commit
681b5d5
·
verified ·
1 Parent(s): 8fc4fa7

Add code/cube3d/train.py

Browse files
Files changed (1) hide show
  1. code/cube3d/train.py +250 -0
code/cube3d/train.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import numpy as np
4
+ from accelerate import Accelerator
5
+ import torch
6
+ import trimesh
7
+
8
+ torch.autograd.set_detect_anomaly(True)
9
+
10
+ from cube3d.training.trainer import Trainer
11
+ from cube3d.training.bert_infer import Infer
12
+ from cube3d.training.engine import Engine, EngineFast
13
+ from cube3d.training.utils import normalize_bbox, select_device
14
+ from cube3d.training.dataset import CubeDataset, LegosDataset, LegosTestDataset
15
+
16
# Fraction of the unit cube that normalized meshes should occupy.
MESH_SCALE = 0.96

# TensorBoard is an optional dependency; record its availability so logging
# can be skipped gracefully when the package is missing.
try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    TENSORBOARD_FOUND = False
else:
    TENSORBOARD_FOUND = True
23
+
24
+
25
def rescale(vertices: np.ndarray, mesh_scale: float = MESH_SCALE) -> np.ndarray:
    """Center vertices at the origin and fit them into a cube.

    The longest side of the axis-aligned bounding box is scaled to
    ``2 * mesh_scale``, so with ``mesh_scale=1.0`` the result fits
    [-1, 1]^3.

    Args:
        vertices: (N, 3) array of vertex positions.
        mesh_scale: Fraction of the unit cube to fill. Defaults to MESH_SCALE.

    Returns:
        (N, 3) array of centered, rescaled vertices.

    Raises:
        ValueError: If the bounding box has zero extent (all points
            coincide along every axis), which would otherwise divide by zero.
    """
    bbmin = vertices.min(0)
    bbmax = vertices.max(0)
    extent = (bbmax - bbmin).max()
    # Guard the degenerate case instead of silently producing inf/nan.
    if extent <= 0:
        raise ValueError("Cannot rescale: mesh bounding box has zero extent")
    center = (bbmin + bbmax) * 0.5
    scale = 2.0 * mesh_scale / extent
    return (vertices - center) * scale
34
+
35
+
36
def load_scaled_mesh(file_path: str) -> trimesh.Trimesh:
    """Load a mesh, clean it, and rescale its vertices into the unit cube.

    Cleaning removes infinite values, degenerate faces, duplicate faces,
    and unreferenced vertices before rescaling via ``rescale``.

    Parameters:
        file_path (str): Path to a mesh file readable by ``trimesh.load``.

    Returns:
        mesh (trimesh.Trimesh): The cleaned, rescaled mesh.

    Raises:
        ValueError: If the mesh has no vertices or faces after cleaning.
    """
    # force="mesh" asks trimesh to coerce scene-like inputs into one mesh.
    mesh: trimesh.Trimesh = trimesh.load(file_path, force="mesh")
    mesh.remove_infinite_values()
    mesh.update_faces(mesh.nondegenerate_faces())
    mesh.update_faces(mesh.unique_faces())
    mesh.remove_unreferenced_vertices()
    if len(mesh.vertices) == 0 or len(mesh.faces) == 0:
        raise ValueError("Mesh has no vertices or faces after cleaning")
    mesh.vertices = rescale(mesh.vertices)
    return mesh
54
+
55
+
56
def load_and_process_mesh(file_path: str, n_samples: int = 8192):
    """Sample an oriented point cloud from a mesh file.

    Loads and normalizes the mesh at ``file_path``, draws ``n_samples``
    points from its surface, and pairs each point with the normal of the
    face it was sampled from.

    Args:
        file_path (str): The file path to the 3D mesh file.
        n_samples (int, optional): Number of surface samples. Defaults to 8192.

    Returns:
        torch.Tensor: Float tensor of shape (1, n_samples, 6); each row is
        the point position (x, y, z) followed by its normal (nx, ny, nz).
    """
    mesh = load_scaled_mesh(file_path)
    samples, sampled_faces = trimesh.sample.sample_surface(mesh, n_samples)
    sample_normals = mesh.face_normals[sampled_faces]
    cloud = np.concatenate([samples, sample_normals], axis=1)  # (n_samples, 6)
    # Add a leading batch dimension and convert to float32 for the encoder.
    return torch.from_numpy(cloud[None, ...]).float()
76
+
77
if __name__ == "__main__":
    # CLI entry point: encodes a reference mesh into shape-token indices,
    # then either trains (Trainer) or runs inference (Infer) depending on `mode`.
    parser = argparse.ArgumentParser(description="cube shape generation script")
    parser.add_argument(
        "--config-path",
        type=str,
        default="cube3d/configs/open_model_v0.5.yaml",
        help="Path to the configuration YAML file.",
    )
    parser.add_argument(
        "--mesh-path",
        type=str,
        required=True,
        help="Path to the input mesh file.",
    )
    parser.add_argument(
        "--data-dir",
        type=str,
        required=True,
        help="Path to the input dataset file.",
    )
    parser.add_argument(
        "--gpt-ckpt-path",
        type=str,
        required=True,
        help="Path to the main GPT checkpoint file.",
    )
    parser.add_argument(
        "--save-gpt-ckpt-path",
        type=str,
        required=True,
        help="Path to the save main GPT checkpoint file.",
    )
    parser.add_argument(
        "--shape-ckpt-path",
        type=str,
        required=True,
        help="Path to the shape encoder/decoder checkpoint file.",
    )
    parser.add_argument(
        "--expname",
        type=str,
        required=True,
        help="Path to the tensorboard file.",
    )
    parser.add_argument(
        "--fast-training",
        help="Use optimized training with cuda graphs",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        required=True,
        help="Text prompt for generating a 3D mesh",
    )
    parser.add_argument(
        "--top-p",
        type=float,
        default=None,
        help="Float < 1: Keep smallest set of tokens with cumulative probability ≥ top_p. Default None: deterministic generation.",
    )
    parser.add_argument(
        "--bounding-box-xyz",
        nargs=3,
        type=float,
        help="Three float values for x, y, z bounding box",
        default=None,
        required=False,
    )
    parser.add_argument(
        "--render-gif",
        help="Render a turntable gif of the mesh",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--disable-postprocessing",
        help="Disable postprocessing on the mesh. This will result in a mesh with more faces.",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--resolution-base",
        type=float,
        default=8.0,
        help="Resolution base for the shape decoder.",
    )
    args = parser.parse_args()
    # Create Tensorboard writer; logs under runs/<expname> when available.
    tb_writer = None
    if TENSORBOARD_FOUND:
        tb_writer = SummaryWriter(log_dir=os.path.join('runs', args.expname))
    else:
        print("Tensorboard not available: not logging progress")

    device = select_device()
    print(f"Using device: {device}")

    # NOTE(review): mode is hard-coded — 'test' runs inference, anything
    # else trains. Consider exposing this as a CLI flag.
    mode = 'test'

    accelerator = Accelerator()
    # Initialize engine based on fast_training flag
    if args.fast_training:
        print(
            "Using cuda graphs, this will take some time to warmup and capture the graph."
        )
        # NOTE(review): EngineFast is built on accelerator.device, while the
        # point cloud below is moved to select_device()'s result — confirm
        # the two resolve to the same device.
        engine = EngineFast(
            args.config_path, args.gpt_ckpt_path, args.shape_ckpt_path, args.save_gpt_ckpt_path, device=accelerator.device, mode=mode
        )
        print("Compiled the graph.")
    else:
        # NOTE(review): Engine is constructed without save_gpt_ckpt_path and
        # mode, unlike EngineFast above — verify both signatures are intended.
        engine = Engine(
            args.config_path, args.gpt_ckpt_path, args.shape_ckpt_path, device=device
        )

    # Normalize the optional user-provided bounding box before use.
    if args.bounding_box_xyz is not None:
        args.bounding_box_xyz = normalize_bbox(tuple(args.bounding_box_xyz))

    # Encode the reference mesh into discrete shape-token indices.
    point_cloud = load_and_process_mesh(args.mesh_path)
    output = engine.shape_model.encode(point_cloud.to(device))

    # The encoder's 4th return element carries the token indices.
    indices = output[3]["indices"]
    print("Got the following shape indices:")
    print(indices)
    print("Indices shape: ", indices.shape)

    # Training hyperparameters on top of the project defaults.
    train_config = Trainer.get_default_config()
    train_config.learning_rate = 5e-4  # many possible options, see the file
    train_config.max_iters = 40000
    train_config.batch_size = 1 if mode=='test' else 28
    train_config.save_interval = 1000

    train_dataset = LegosDataset(args)
    test_dataset = LegosTestDataset(args)

    dataset = test_dataset if mode=='test' else train_dataset

    if mode!='test':
        # Training path: fine-tune the GPT conditioned on prompt + indices.
        trainer = Trainer(
            config=train_config,
            engine=engine,
            accelerator=accelerator,
            tb=tb_writer,
            prompt=args.prompt,
            train_dataset=dataset,
            indices=indices,
            resolution_base=args.resolution_base,
            disable_postprocessing=args.disable_postprocessing,
            top_p=args.top_p,
            bounding_box_xyz=args.bounding_box_xyz,
            save_gpt_ckpt_path=args.save_gpt_ckpt_path,
            mode = mode
        )
        trainer.run()
    else:
        # Inference path: identical wiring, driven by Infer instead.
        infer = Infer(
            config=train_config,
            engine=engine,
            accelerator=accelerator,
            tb=tb_writer,
            prompt=args.prompt,
            train_dataset=dataset,
            indices=indices,
            resolution_base=args.resolution_base,
            disable_postprocessing=args.disable_postprocessing,
            top_p=args.top_p,
            bounding_box_xyz=args.bounding_box_xyz,
            save_gpt_ckpt_path=args.save_gpt_ckpt_path,
            mode = mode
        )
        infer.run()