0xZohar committed on
Commit
bae26d1
·
verified ·
1 Parent(s): c1fc5b2

Upload code/cube3d/generate.py

Browse files
Files changed (1) hide show
  1. code/cube3d/generate.py +245 -0
code/cube3d/generate.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+
4
+ import torch
5
+ import trimesh
6
+
7
+ from cube3d.inference.engine import Engine, EngineFast
8
+ from cube3d.inference.utils import normalize_bbox, select_device
9
+ from cube3d.mesh_utils.postprocessing import (
10
+ PYMESHLAB_AVAILABLE,
11
+ create_pymeshset,
12
+ postprocess_mesh,
13
+ save_mesh,
14
+ )
15
+ from cube3d.renderer import renderer
16
+ from cube3d.training.dataset import LegosTestDataset, LegosDataset
17
+ from torch.utils.data.dataloader import DataLoader
18
+ from cube3d.training.utils import normalize_bboxs
19
+
20
def generate_mesh(
    engine,
    prompt,
    output_dir,
    output_name,
    resolution_base=8.0,
    disable_postprocess=False,
    top_p=None,
    bounding_box_xyz=None,
):
    """Generate a mesh for ``prompt`` and write it to ``<output_dir>/<output_name>.obj``.

    Args:
        engine: Object exposing ``t2s``; called with a one-element prompt list
            and ``use_kv_cache=True``.
        prompt: Text prompt; wrapped in a list before being passed to the engine.
        output_dir: Directory the ``.obj`` file is written into.
        output_name: File stem of the ``.obj`` file.
        resolution_base: Forwarded unchanged to ``engine.t2s``.
        disable_postprocess: When true, skip ``postprocess_mesh`` (only
            relevant when pymeshlab is available).
        top_p: Forwarded unchanged to ``engine.t2s``.
        bounding_box_xyz: Forwarded unchanged to ``engine.t2s``.

    Returns:
        The path of the written ``.obj`` file.
    """
    generated = engine.t2s(
        [prompt],
        use_kv_cache=True,
        resolution_base=resolution_base,
        top_p=top_p,
        bounding_box_xyz=bounding_box_xyz,
    )
    # Only the first generated item is used; its first two entries are the
    # vertex and face arrays.
    vertices = generated[0][0]
    faces = generated[0][1]
    obj_path = os.path.join(output_dir, f"{output_name}.obj")

    if not PYMESHLAB_AVAILABLE:
        # Fallback path: plain trimesh export, no post-processing.
        print(
            "WARNING: pymeshlab is not available, using trimesh to export obj and skipping optional post processing."
        )
        trimesh.Trimesh(vertices, faces).export(obj_path)
        return obj_path

    ms = create_pymeshset(vertices, faces)
    if not disable_postprocess:
        # Target is 10% of the generated face count, but never below 10000.
        target_face_num = max(10000, int(faces.shape[0] * 0.1))
        print(f"Postprocessing mesh to {target_face_num} faces")
        postprocess_mesh(ms, target_face_num, obj_path)

    # NOTE(review): saved whether or not post-processing ran — confirm this
    # nesting matches the intended behavior.
    save_mesh(ms, obj_path)
    return obj_path
56
+
57
def generate_ldr(
    engine,
    prompt,
    inputs_ids,
    output_dir,
    output_name,
    resolution_base=8.0,
    disable_postprocess=False,
    top_p=None,
    bounding_box_xyz=None,
    idx=None,
):
    """Run ``engine.t2l`` for one sample and return its result unchanged.

    This is a thin pass-through: nothing is written to disk by this function
    itself, and the engine's return value is handed back as-is.

    Args:
        engine: Object exposing ``t2l``; called with ``use_kv_cache=True``.
        prompt: Passed to ``engine.t2l`` as its first positional argument,
            as-is (not wrapped in a list, unlike ``generate_mesh``).
        inputs_ids: Forwarded as the ``inputs_ids`` keyword argument.
        output_dir: Unused here; kept so existing positional callers keep working.
        output_name: Unused here; kept so existing positional callers keep working.
        resolution_base: Forwarded unchanged to ``engine.t2l``.
        disable_postprocess: Unused here; kept so existing positional callers
            keep working.
        top_p: Forwarded unchanged to ``engine.t2l``.
        bounding_box_xyz: Forwarded unchanged to ``engine.t2l``.
        idx: Forwarded unchanged to ``engine.t2l`` as ``idx``.

    Returns:
        Whatever ``engine.t2l`` returns.
    """
    return engine.t2l(
        prompt,
        inputs_ids=inputs_ids,
        use_kv_cache=True,
        resolution_base=resolution_base,
        top_p=top_p,
        bounding_box_xyz=bounding_box_xyz,
        idx=idx,
    )
98
+
99
+
100
if __name__ == "__main__":
    # Script entry point: parse CLI options, build the inference engine, then
    # run ``generate_ldr`` once per batch of the LegosDataset.
    parser = argparse.ArgumentParser(description="cube shape generation script")
    parser.add_argument(
        "--config-path",
        type=str,
        default="cube3d/configs/open_model_v0.5.yaml",
        help="Path to the configuration YAML file.",
    )
    parser.add_argument(
        "--data-dir",
        type=str,
        required=True,
        help="Path to the input dataset file.",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="outputs/",
        help="Path to the output directory to store .obj and .gif files",
    )
    parser.add_argument(
        "--gpt-ckpt-path",
        type=str,
        required=True,
        help="Path to the main GPT checkpoint file.",
    )
    parser.add_argument(
        "--shape-ckpt-path",
        type=str,
        required=True,
        help="Path to the shape encoder/decoder checkpoint file.",
    )
    parser.add_argument(
        "--save-gpt-ckpt-path",
        type=str,
        required=True,
        help="Path to the save adaption GPT checkpoint file.",
    )
    parser.add_argument(
        "--fast-inference",
        help="Use optimized inference",
        default=False,
        action="store_true",
    )
    # NOTE(review): --prompt and --render-gif are parsed but never read
    # directly by this script; the whole ``args`` namespace is handed to
    # LegosDataset, which may or may not use them — confirm before removing.
    parser.add_argument(
        "--prompt",
        type=str,
        required=True,
        help="Text prompt for generating a 3D mesh",
    )
    parser.add_argument(
        "--top-p",
        type=float,
        default=None,
        help="Float < 1: Keep smallest set of tokens with cumulative probability ≥ top_p. Default None: deterministic generation.",
    )
    parser.add_argument(
        "--bounding-box-xyz",
        nargs=3,
        type=float,
        help="Three float values for x, y, z bounding box",
        default=None,
        required=False,
    )
    parser.add_argument(
        "--render-gif",
        help="Render a turntable gif of the mesh",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--disable-postprocessing",
        help="Disable postprocessing on the mesh. This will result in a mesh with more faces.",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--resolution-base",
        type=float,
        default=8.0,
        help="Resolution base for the shape decoder.",
    )
    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    device = select_device()
    print(f"Using device: {device}")

    # Initialize engine based on fast_inference flag
    if args.fast_inference:
        print(
            "Using cuda graphs, this will take some time to warmup and capture the graph."
        )
        # NOTE(review): unlike Engine below, EngineFast is not given
        # args.save_gpt_ckpt_path — confirm this is intended.
        engine = EngineFast(
            args.config_path, args.gpt_ckpt_path, args.shape_ckpt_path, device=device
        )
        print("Compiled the graph.")
    else:
        engine = Engine(
            args.config_path,
            args.gpt_ckpt_path,
            args.shape_ckpt_path,
            args.save_gpt_ckpt_path,
            device=device,
        )

    # The normalized CLI bounding box is stored back on ``args``; the loop
    # below uses the per-sample bbox from the dataset instead.
    if args.bounding_box_xyz is not None:
        args.bounding_box_xyz = normalize_bbox(tuple(args.bounding_box_xyz))

    test_dataset = LegosDataset(args)
    batch_size = 1
    # Presumably the number of grid positions along each axis; each value
    # minus one is passed to normalize_bboxs — verify against that helper.
    x_num = 213
    y_num = 217
    z_num = 529

    # setup the dataloader
    data_loader = DataLoader(
        test_dataset,
        shuffle=False,
        batch_size=batch_size,
    )

    # Iterate the loader itself rather than range(len(dataset)) + next():
    # the loader yields len(dataset) / batch_size batches, so the manual
    # form raises StopIteration as soon as batch_size is not 1.
    for idx, batch in enumerate(data_loader):
        prompt = batch["prompt"]
        targets = batch["target"].to(device)
        box = batch["bbox"]
        # The return value is not used by this script.
        generate_ldr(
            engine,
            prompt,
            targets,
            args.output_dir,
            "output",
            args.resolution_base,
            args.disable_postprocessing,
            args.top_p,
            normalize_bboxs(box.float(), [x_num - 1, y_num - 1, z_num - 1]),
            idx,
        )