notenoughram commited on
Commit
aa745f4
ยท
verified ยท
1 Parent(s): 8f5dd0d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +193 -541
app.py CHANGED
@@ -4,7 +4,7 @@ import subprocess
4
  import shutil
5
  import numpy as np
6
 
7
- # [1] Open3D ์—†์œผ๋ฉด ์ž๋™ ์„ค์น˜ (ModuleNotFoundError ํ•ด๊ฒฐ)
8
  try:
9
  import open3d as o3d
10
  except ImportError:
@@ -13,8 +13,7 @@ except ImportError:
13
  import open3d as o3d
14
 
15
  import gradio as gr
16
- # @spaces.GPU ์ œ๊ฑฐ๋ฅผ ์œ„ํ•ด spaces๋Š” import ํ•˜๋˜ ๋ฐ์ฝ”๋ ˆ์ดํ„ฐ๋Š” ์•ˆ ์”๋‹ˆ๋‹ค.
17
- import spaces
18
  from gradio_litmodel3d import LitModel3D
19
 
20
  os.environ['SPCONV_ALGO'] = 'native'
@@ -33,7 +32,7 @@ from trellis.representations import Gaussian, MeshExtractResult
33
  from trellis.utils import render_utils, postprocessing_utils
34
  from trellis.utils.general_utils import *
35
 
36
- # ์ปค์Šคํ…€ ํœ  ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ (์›๋ณธ ๋กœ์ง์šฉ)
37
  sys.path.append("wheels")
38
  from wheels.vggt.vggt.utils.load_fn import load_and_preprocess_images
39
  from wheels.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri
@@ -47,6 +46,43 @@ TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
47
  os.makedirs(TMP_DIR, exist_ok=True)
48
  device = "cuda" if torch.cuda.is_available() else "cpu"
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  # --- ์„ธ์…˜ ๊ด€๋ฆฌ ---
51
  def start_session(req: gr.Request):
52
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
@@ -57,73 +93,7 @@ def end_session(req: gr.Request):
57
  if os.path.exists(user_dir):
58
  shutil.rmtree(user_dir)
59
 
60
- # --- ์ „์ฒ˜๋ฆฌ ํ•จ์ˆ˜๋“ค (@spaces.GPU ์ œ๊ฑฐ๋จ) ---
61
- def preprocess_image(image: Image.Image) -> Image.Image:
62
- processed_image = pipeline.preprocess_image(image)
63
- return processed_image
64
-
65
- def preprocess_videos(video: str) -> List[Tuple[Image.Image, str]]:
66
- vid = imageio.get_reader(video, 'ffmpeg')
67
- fps = vid.get_meta_data()['fps']
68
- images = []
69
- for i, frame in enumerate(vid):
70
- if i % max(int(fps * 1), 1) == 0:
71
- img = Image.fromarray(frame)
72
- W, H = img.size
73
- img = img.resize((int(W / H * 512), 512))
74
- images.append(img)
75
- vid.close()
76
- processed_images = [pipeline.preprocess_image(image) for image in images]
77
- return processed_images
78
-
79
- def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
80
- images = [image[0] for image in images]
81
- processed_images = [pipeline.preprocess_image(image) for image in images]
82
- return processed_images
83
-
84
- # --- State ๊ด€๋ฆฌ ---
85
- def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
86
- return {
87
- 'gaussian': {
88
- **gs.init_params,
89
- '_xyz': gs._xyz.cpu().numpy(),
90
- '_features_dc': gs._features_dc.cpu().numpy(),
91
- '_scaling': gs._scaling.cpu().numpy(),
92
- '_rotation': gs._rotation.cpu().numpy(),
93
- '_opacity': gs._opacity.cpu().numpy(),
94
- },
95
- 'mesh': {
96
- 'vertices': mesh.vertices.cpu().numpy(),
97
- 'faces': mesh.faces.cpu().numpy(),
98
- },
99
- }
100
-
101
- def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
102
- gs = Gaussian(
103
- aabb=state['gaussian']['aabb'],
104
- sh_degree=state['gaussian']['sh_degree'],
105
- mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
106
- scaling_bias=state['gaussian']['scaling_bias'],
107
- opacity_bias=state['gaussian']['opacity_bias'],
108
- scaling_activation=state['gaussian']['scaling_activation'],
109
- )
110
- # ๋กœ๋“œ ์‹œ ๋ฉ”์ธ GPU(cuda:0)๋กœ ๋ณต๊ท€
111
- gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda:0')
112
- gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda:0')
113
- gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda:0')
114
- gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda:0')
115
- gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda:0')
116
-
117
- mesh = edict(
118
- vertices=torch.tensor(state['mesh']['vertices'], device='cuda:0'),
119
- faces=torch.tensor(state['mesh']['faces'], device='cuda:0'),
120
- )
121
- return gs, mesh
122
-
123
- def get_seed(randomize_seed: bool, seed: int) -> int:
124
- return np.random.randint(0, MAX_SEED) if randomize_seed else seed
125
-
126
- # --- [์›๋ณธ ๋ณต๊ตฌ] ์ •๋ฐ€ ์ˆ˜ํ•™/ํฌ์ฆˆ ํ•จ์ˆ˜๋“ค ---
127
  def perform_rodrigues_transformation(rvec):
128
  R, _ = cv2.Rodrigues(rvec)
129
  return R
@@ -134,7 +104,6 @@ def align_camera(num_frames, extrinsic, intrinsic, rend_extrinsics, rend_intrins
134
  camera_relative_angle = torch.acos(((camera_relative[:,0,0] + camera_relative[:,1,1] + camera_relative[:,2,2] - 1) / 2).clamp(-1, 1))
135
  idx = torch.argmin(camera_relative_angle)
136
  target_extrinsic = rend_extrinsics[idx:idx+1].clone()
137
-
138
  focal_x = intrinsic[:num_frames,0,0].mean()
139
  focal_y = intrinsic[:num_frames,1,1].mean()
140
  focal = (focal_x + focal_y) / 2
@@ -179,39 +148,20 @@ def refine_pose_mast3r(rend_image_pil, target_image_pil, original_size, fxy, tar
179
  pixel[1] *= scale_y
180
  depth_map = rend_depth[0]
181
  fx, fy, cx, cy = fxy.item(), fxy.item(), original_size[1]/2, original_size[0]/2
182
- K = np.array([
183
- [fx, 0, cx],
184
- [0, fy, cy],
185
- [0, 0, 1]
186
- ])
187
  dist_eff = np.array([0,0,0,0], dtype=np.float32)
188
  predict_c2w_ini = np.linalg.inv(target_extrinsic[0].cpu().numpy())
189
- predict_w2c_ini = target_extrinsic[0].cpu().numpy()
190
  initial_rvec, _ = cv2.Rodrigues(predict_c2w_ini[:3,:3].astype(np.float32))
191
  initial_tvec = predict_c2w_ini[:3,3].astype(np.float32)
192
- K_inv = np.linalg.inv(K)
193
- height, width = depth_map.shape
194
- x_coords, y_coords = np.meshgrid(np.arange(width), np.arange(height))
195
- x_flat = x_coords.flatten()
196
- y_flat = y_coords.flatten()
197
- depth_flat = depth_map.flatten()
198
- x_normalized = (x_flat - K[0, 2]) / K[0, 0]
199
- y_normalized = (y_flat - K[1, 2]) / K[1, 1]
200
- X_camera = depth_flat * x_normalized
201
- Y_camera = depth_flat * y_normalized
202
- Z_camera = depth_flat
203
- points_camera = np.vstack((X_camera, Y_camera, Z_camera, np.ones_like(X_camera)))
204
- points_world = predict_c2w_ini @ points_camera
205
- X_world = points_world[0, :]
206
- Y_world = points_world[1, :]
207
- Z_world = points_world[2, :]
208
- points_3D = np.vstack((X_world, Y_world, Z_world))
209
- scene_coordinates_gs = points_3D.reshape(3, original_size[0], original_size[1])
210
- points_3D_at_pixels = np.zeros((matches_im0.shape[0], 3))
211
- for i, (x, y) in enumerate(matches_im0):
212
- points_3D_at_pixels[i] = scene_coordinates_gs[:, y, x]
213
-
214
- success, rvec, tvec, inliers = cv2.solvePnPRansac(points_3D_at_pixels.astype(np.float32), matches_im1.astype(np.float32), K, \
215
  dist_eff,rvec=initial_rvec,tvec=initial_tvec, useExtrinsicGuess=True, reprojectionError=1.0,\
216
  iterationsCount=2000,flags=cv2.SOLVEPNP_EPNP)
217
  R = perform_rodrigues_transformation(rvec)
@@ -235,15 +185,10 @@ def pointcloud_registration(rend_image_pil, target_image_pil, original_size,
235
 
236
  matches_im0, matches_im1 = fast_reciprocal_NNs(desc1, desc2, subsample_or_initxy1=8,
237
  device="cuda:0", dist='dot', block_size=2**13)
238
-
239
  H0, W0 = view1['true_shape'][0]
240
- valid_matches_im0 = (matches_im0[:, 0] >= 3) & (matches_im0[:, 0] < int(W0) - 3) & (
241
- matches_im0[:, 1] >= 3) & (matches_im0[:, 1] < int(H0) - 3)
242
-
243
  H1, W1 = view2['true_shape'][0]
244
- valid_matches_im1 = (matches_im1[:, 0] >= 3) & (matches_im1[:, 0] < int(W1) - 3) & (
245
- matches_im1[:, 1] >= 3) & (matches_im1[:, 1] < int(H1) - 3)
246
-
247
  valid_matches = valid_matches_im0 & valid_matches_im1
248
  matches_im0, matches_im1 = matches_im0[valid_matches], matches_im1[valid_matches]
249
  scale_x = original_size[1] / W0.item()
@@ -254,43 +199,14 @@ def pointcloud_registration(rend_image_pil, target_image_pil, original_size,
254
  for pixel in matches_im0:
255
  pixel[0] *= scale_x
256
  pixel[1] *= scale_y
257
- depth_map = rend_depth[0]
258
- fx, fy, cx, cy = fxy.item(), fxy.item(), original_size[1]/2, original_size[0]/2
259
- K = np.array([
260
- [fx, 0, cx],
261
- [0, fy, cy],
262
- [0, 0, 1]
263
- ])
264
- dist_eff = np.array([0,0,0,0], dtype=np.float32)
265
- predict_c2w_ini = np.linalg.inv(target_extrinsic[0].cpu().numpy())
266
- predict_w2c_ini = target_extrinsic[0].cpu().numpy()
267
- initial_rvec, _ = cv2.Rodrigues(predict_c2w_ini[:3,:3].astype(np.float32))
268
- initial_tvec = predict_c2w_ini[:3,3].astype(np.float32)
269
- K_inv = np.linalg.inv(K)
270
- height, width = depth_map.shape
271
- x_coords, y_coords = np.meshgrid(np.arange(width), np.arange(height))
272
- x_flat = x_coords.flatten()
273
- y_flat = y_coords.flatten()
274
- depth_flat = depth_map.flatten()
275
- x_normalized = (x_flat - K[0, 2]) / K[0, 0]
276
- y_normalized = (y_flat - K[1, 2]) / K[1, 1]
277
- X_camera = depth_flat * x_normalized
278
- Y_camera = depth_flat * y_normalized
279
- Z_camera = depth_flat
280
- points_camera = np.vstack((X_camera, Y_camera, Z_camera, np.ones_like(X_camera)))
281
- points_world = predict_c2w_ini @ points_camera
282
- X_world = points_world[0, :]
283
- Y_world = points_world[1, :]
284
- Z_world = points_world[2, :]
285
- points_3D = np.vstack((X_world, Y_world, Z_world))
286
- scene_coordinates_gs = points_3D.reshape(3, original_size[0], original_size[1])
287
  points_3D_at_pixels = np.zeros((matches_im0.shape[0], 3))
288
- for i, (x, y) in enumerate(matches_im0):
289
- points_3D_at_pixels[i] = scene_coordinates_gs[:, y, x]
290
 
291
  points_3D_at_pixels_2 = np.zeros((matches_im1.shape[0], 3))
292
  for i, (x, y) in enumerate(matches_im1):
293
- points_3D_at_pixels_2[i] = target_pointmap[:, y, x]
294
 
295
  dist_1 = np.linalg.norm(points_3D_at_pixels - points_3D_at_pixels.mean(axis=0), axis=1)
296
  scale_1 = dist_1[dist_1 < np.percentile(dist_1, 99)].mean()
@@ -306,506 +222,242 @@ def pointcloud_registration(rend_image_pil, target_image_pil, original_size,
306
  correspondences = np.stack([indices, indices], axis=1)
307
  correspondences = o3d.utility.Vector2iVector(correspondences)
308
  result = o3d.pipelines.registration.registration_ransac_based_on_correspondence(
309
- pcd_2,
310
- pcd_1,
311
- correspondences,
312
- 0.03,
313
  estimation_method=o3d.pipelines.registration.TransformationEstimationPointToPoint(False),
314
- ransac_n=5,
315
- criteria=o3d.pipelines.registration.RANSACConvergenceCriteria(10000, 10000),
316
  )
317
  transformation_matrix = result.transformation.copy()
318
  transformation_matrix[:3,:3] = transformation_matrix[:3,:3] * (scale_1 / scale_2)
319
- evaluation = o3d.pipelines.registration.evaluate_registration(
320
- down_pcd, pcd, 0.02, transformation_matrix
321
- )
322
  return transformation_matrix, evaluation.fitness
323
 
324
- # [์ˆ˜์ •] ๋ฉ”์ธ ์ƒ์„ฑ ํ•จ์ˆ˜ (Refine ๋กœ์ง 100% ๋ณต๊ตฌ + VRAM ๋ถ„์‚ฐ ์ ‘๊ทผ)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
325
  def generate_and_extract_glb(
326
- multiimages: List[Tuple[Image.Image, str]],
327
- seed: int,
328
- ss_guidance_strength: float,
329
- ss_sampling_steps: int,
330
- slat_guidance_strength: float,
331
- slat_sampling_steps: int,
332
- multiimage_algo: Literal["multidiffusion", "stochastic"],
333
- mesh_simplify: float,
334
- texture_size: int,
335
- refine: Literal["Yes", "No"],
336
- ss_refine: Literal["noise", "deltav", "No"],
337
- registration_num_frames: int,
338
- trellis_stage1_lr: float,
339
- trellis_stage1_start_t: float,
340
- trellis_stage2_lr: float,
341
- trellis_stage2_start_t: float,
342
- req: gr.Request,
343
- ) -> Tuple[dict, str, str, str]:
344
-
345
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
346
  image_files = [image[0] for image in multiimages]
347
 
348
- # Generate 3D model
349
  outputs, coords, ss_noise = pipeline.run(
350
- image=image_files,
351
- seed=seed,
352
- formats=["gaussian", "mesh"],
353
- preprocess_image=False,
354
- sparse_structure_sampler_params={
355
- "steps": ss_sampling_steps,
356
- "cfg_strength": ss_guidance_strength,
357
- },
358
- slat_sampler_params={
359
- "steps": slat_sampling_steps,
360
- "cfg_strength": slat_guidance_strength,
361
- },
362
  mode=multiimage_algo,
363
  )
364
-
 
365
  if refine == "Yes":
366
  try:
367
  images, alphas = load_and_preprocess_images(multiimages)
368
- # ์ด๋ฏธ์ง€๋ฅผ cuda:0 (๋˜๋Š” ์‚ฌ์šฉ ๊ฐ€๋Šฅํ•œ GPU)์œผ๋กœ ์ด๋™
369
- images, alphas = images.to("cuda:0"), alphas.to("cuda:0")
370
 
371
  with torch.no_grad():
372
  with torch.cuda.amp.autocast(dtype=pipeline.VGGT_dtype):
373
- images = images[None]
374
- # [VRAM ๋ถ„์‚ฐ ๋Œ€์‘] VGGT_model์ด ๋‹ค๋ฅธ GPU์— ์žˆ์–ด๋„ ํ˜ธ์ถœ ๊ฐ€๋Šฅํ•˜๋„๋ก ์ฒ˜๋ฆฌ
375
- if hasattr(pipeline.VGGT_model, 'module'):
376
- vggt = pipeline.VGGT_model.module
377
- else:
378
- vggt = pipeline.VGGT_model
379
-
380
- # ์ž…๋ ฅ ์ด๋ฏธ์ง€๋ฅผ VGGT ๋ชจ๋ธ์ด ์žˆ๋Š” GPU๋กœ ์ž„์‹œ ์ด๋™
381
- target_device = next(vggt.parameters()).device
382
- images_in = images.to(target_device)
383
-
384
- aggregated_tokens_list, ps_idx = vggt.aggregator(images_in)
385
 
386
- # Predict Cameras
387
- pose_enc = vggt.camera_head(aggregated_tokens_list)[-1]
 
388
  extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, images.shape[-2:])
389
- # Predict Point Cloud
390
- point_map, point_conf = vggt.point_head(aggregated_tokens_list, images_in, ps_idx)
391
 
392
- # ๊ฒฐ๊ณผ๋ฌผ์„ ๋‹ค์‹œ CPU/๋ฉ”์ธ GPU๋กœ ๊ฐ€์ ธ์™€์„œ ์ฒ˜๋ฆฌ
393
- point_map = point_map.to("cuda:0")
394
- point_conf = point_conf.to("cuda:0")
 
 
 
395
  extrinsic = extrinsic.to("cuda:0")
396
  intrinsic = intrinsic.to("cuda:0")
397
- del aggregated_tokens_list, images_in
398
 
 
399
  mask = (alphas[:,0,...][...,None] > 0.8)
400
  conf_threshold = np.percentile(point_conf.cpu().numpy(), 50)
401
- confidence_mask = (point_conf[0] > conf_threshold) & (point_conf[0] > 1e-5)
402
  mask = mask & confidence_mask[...,None]
403
- point_map_by_unprojection = point_map[0]
404
- point_map_clean = point_map_by_unprojection[mask[...,0]]
405
  center_point = point_map_clean.mean(0)
406
  scale = np.percentile((point_map_clean - center_point[None]).norm(dim=-1).cpu().numpy(), 98)
407
- outlier_mask = (point_map_by_unprojection - center_point[None]).norm(dim=-1) <= scale
408
  final_mask = mask & outlier_mask[...,None]
409
- point_map_perframe = (point_map_by_unprojection - center_point[None, None, None]) / (2 * scale)
410
  point_map_perframe[~final_mask[...,0]] = 127/255
411
  point_map_perframe = point_map_perframe.permute(0,3,1,2)
412
- images = images[0].permute(0,2,3,1)
413
- images[~(alphas[:,0,...][...,None] > 0.8)[...,0]] = 0.
414
- input_images = images.permute(0,3,1,2).clone()
415
  vggt_extrinsic = extrinsic[0]
416
  vggt_extrinsic = torch.cat([vggt_extrinsic, torch.tensor([[[0,0,0,1]]]).repeat(vggt_extrinsic.shape[0], 1, 1).to(vggt_extrinsic)], dim=1)
417
- vggt_intrinsic = intrinsic[0]
418
- vggt_intrinsic[:,:2] = vggt_intrinsic[:,:2] / 518
419
  vggt_extrinsic[:,:3,3] = (torch.matmul(vggt_extrinsic[:,:3,:3], center_point[None,:,None].float())[...,0] + vggt_extrinsic[:,:3,3]) / (2 * scale)
 
 
420
  pointcloud = point_map_perframe.permute(0,2,3,1)[final_mask[...,0]]
421
  idxs = torch.randperm(pointcloud.shape[0])[:min(50000, pointcloud.shape[0])]
422
  pcd = o3d.geometry.PointCloud()
423
  pcd.points = o3d.utility.Vector3dVector(pointcloud[idxs].cpu().numpy())
424
  cl, ind = pcd.remove_statistical_outlier(nb_neighbors=30, std_ratio=3.0)
425
  inlier_cloud = pcd.select_by_index(ind)
426
- outlier_cloud = pcd.select_by_index(ind, invert=True)
427
- distance = np.array(inlier_cloud.points) - np.array(inlier_cloud.points).mean(axis=0)[None]
428
- scale = np.percentile(np.linalg.norm(distance, axis=1), 97)
429
  voxel_size = 1/64*scale*2
430
  down_pcd = inlier_cloud.voxel_down_sample(voxel_size)
431
  torch.cuda.empty_cache()
432
 
433
- video, rend_extrinsics, rend_intrinsics = render_utils.render_multiview(outputs['gaussian'][0], num_frames=registration_num_frames)
434
- rend_extrinsics = torch.stack(rend_extrinsics, dim=0)
435
- rend_intrinsics = torch.stack(rend_intrinsics, dim=0)
 
 
436
  target_extrinsics = []
437
  target_intrinsics = []
438
- target_transforms = []
439
- target_fitnesses = []
440
- pcd = o3d.geometry.PointCloud()
441
- mesh = outputs['mesh'][0]
442
- idxs = torch.randperm(mesh.vertices.shape[0])[:min(50000, mesh.vertices.shape[0])]
443
- pcd.points = o3d.utility.Vector3dVector(mesh.vertices[idxs].cpu().numpy())
444
- distance = np.array(pcd.points) - np.array(pcd.points).mean(axis=0)[None]
445
- scale = np.linalg.norm(distance, axis=1).max()
446
- voxel_size = 1/64*scale*2
447
- pcd = pcd.voxel_down_sample(voxel_size)
448
-
449
- for k in range(len(image_files)):
450
- images = torch.stack([TF.ToTensor()(render_image) for render_image in video['color']] + [TF.ToTensor()(image_files[k].convert("RGB"))], dim=0)
451
- # if len(images) == 0:
452
- with torch.no_grad():
453
- with torch.cuda.amp.autocast(dtype=pipeline.VGGT_dtype):
454
- # [VRAM ๋ถ„์‚ฐ ๋Œ€์‘] VGGT_model ํ˜ธ์ถœ
455
- target_device = next(vggt.parameters()).device
456
- images_in = images[None].to(target_device)
457
- aggregated_tokens_list, ps_idx = vggt.aggregator(images_in)
458
- pose_enc = vggt.camera_head(aggregated_tokens_list)[-1]
459
-
460
- # ๊ฒฐ๊ณผ ํšŒ์ˆ˜
461
- pose_enc = pose_enc.to("cuda:0")
462
- extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, images.shape[-2:])
463
- extrinsic, intrinsic = extrinsic[0], intrinsic[0]
464
- extrinsic = torch.cat([extrinsic, torch.tensor([0,0,0,1])[None,None].repeat(extrinsic.shape[0], 1, 1).to(extrinsic.device)], dim=1)
465
- del aggregated_tokens_list, ps_idx, images_in
466
-
467
- target_extrinsic, target_intrinsic = align_camera(registration_num_frames, extrinsic, intrinsic, rend_extrinsics, rend_intrinsics)
468
- fxy = target_intrinsic[:,0,0]
469
- target_intrinsic_tmp = target_intrinsic.clone()
470
- target_intrinsic_tmp[:,:2] = target_intrinsic_tmp[:,:2] / 518
471
-
472
- target_extrinsic_list = [target_extrinsic]
473
- iou_list = []
474
- iterations = 3
475
- for i in range(iterations + 1):
476
- j = 0
477
- rend = render_utils.render_frames(outputs['gaussian'][0], target_extrinsic, target_intrinsic_tmp, {'resolution': 518, 'bg_color': (0, 0, 0)}, need_depth=True)
478
- rend_image = rend['color'][j] # (518, 518, 3)
479
- rend_depth = rend['depth'][j] # (3, 518, 518)
480
-
481
- depth_single = rend_depth[0].astype(np.float32) # (H, W)
482
- mask = (depth_single != 0).astype(np.uint8) #
483
- kernel = np.ones((3, 3), np.uint8)
484
- mask_eroded = cv2.erode(mask, kernel, iterations=3)
485
- depth_eroded = depth_single * mask_eroded
486
- rend_depth_eroded = np.stack([depth_eroded]*3, axis=0)
487
-
488
- rend_image = torch.tensor(rend_image).permute(2,0,1) / 255
489
- target_image = images[registration_num_frames:].to(target_extrinsic.device)[j]
490
- original_size = (rend_image.shape[1], rend_image.shape[2])
491
-
492
- # import torchvision
493
- # torchvision.utils.save_image(rend_image, 'rend_image_{}.png'.format(k))
494
- # torchvision.utils.save_image(target_image, 'target_image_{}.png'.format(k))
495
-
496
- mask_rend = (rend_image.detach().cpu() > 0).any(dim=0)
497
- mask_target = (target_image.detach().cpu() > 0).any(dim=0)
498
- intersection = (mask_rend & mask_target).sum().item()
499
- union = (mask_rend | mask_target).sum().item()
500
- iou = intersection / union if union > 0 else 0.0
501
- iou_list.append(iou)
502
-
503
- if i == iterations:
504
- break
505
-
506
- rend_image = rend_image * torch.from_numpy(mask_eroded[None]).to(rend_image.device)
507
- rend_image_pil = Image.fromarray((rend_image.permute(1,2,0).cpu().numpy() * 255).astype(np.uint8))
508
- target_image_pil = Image.fromarray((target_image.permute(1,2,0).cpu().numpy() * 255).astype(np.uint8))
509
- target_extrinsic[j:j+1] = refine_pose_mast3r(rend_image_pil, target_image_pil, original_size, fxy[j:j+1], target_extrinsic[j:j+1], rend_depth_eroded)
510
- target_extrinsic_list.append(target_extrinsic[j:j+1])
511
-
512
- idx = iou_list.index(max(iou_list))
513
- target_extrinsic[j:j+1] = target_extrinsic_list[idx]
514
- target_transform, fitness = pointcloud_registration(rend_image_pil, target_image_pil, original_size, fxy[j:j+1], target_extrinsic[j:j+1], \
515
- rend_depth_eroded, point_map_perframe[k].cpu().numpy(), down_pcd, pcd)
516
- target_transforms.append(target_transform)
517
- target_fitnesses.append(fitness)
518
-
519
- target_extrinsics.append(target_extrinsic[j:j+1])
520
- target_intrinsics.append(target_intrinsic_tmp[j:j+1])
521
- target_extrinsics = torch.cat(target_extrinsics, dim=0)
522
- target_intrinsics = torch.cat(target_intrinsics, dim=0)
523
 
524
- target_fitnesses_filtered = [x for x in target_fitnesses if x <= 1]
525
- if len(target_fitnesses_filtered) > 0:
526
- idx = target_fitnesses.index(max(target_fitnesses_filtered))
527
- else:
528
- idx = 0
529
- target_transform = target_transforms[idx]
530
- down_pcd_align = copy.deepcopy(down_pcd).transform(target_transform)
531
- # pcd = o3d.geometry.PointCloud()
532
- # pcd.points = o3d.utility.Vector3dVector(coords[:,1:].cpu().numpy() / 64 - 0.5)
533
- reg_p2p = o3d.pipelines.registration.registration_icp(
534
- down_pcd_align, pcd, 0.02, np.eye(4),
535
- o3d.pipelines.registration.TransformationEstimationPointToPoint(with_scaling=True),
536
- o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration = 10000))
537
- down_pcd_align_2 = copy.deepcopy(down_pcd_align).transform(reg_p2p.transformation)
538
- input_points = torch.tensor(np.asarray(down_pcd_align_2.points)).to("cuda:0").float()
539
- input_points = ((input_points + 0.5).clip(0, 1) * 64 - 0.5).to(torch.int32)
540
 
541
- outputs = pipeline.run_refine(
542
- image=image_files,
543
- ss_learning_rate=trellis_stage1_lr,
544
- ss_start_t=trellis_stage1_start_t,
545
- apperance_learning_rate=trellis_stage2_lr,
546
- apperance_start_t=trellis_stage2_start_t,
547
- extrinsics=target_extrinsics,
548
- intrinsics=target_intrinsics,
549
- ss_noise=ss_noise,
550
- input_points=input_points,
551
- ss_refine_type = ss_refine,
552
- coords=coords if ss_refine == "No" else None,
553
- seed=seed,
554
- formats=["mesh", "gaussian"],
555
- sparse_structure_sampler_params={
556
- "steps": ss_sampling_steps,
557
- "cfg_strength": ss_guidance_strength,
558
- },
559
- slat_sampler_params={
560
- "steps": slat_sampling_steps,
561
- "cfg_strength": slat_guidance_strength,
562
- },
563
- mode=multiimage_algo,
564
- )
565
  except Exception as e:
566
- print(f"Error during refinement: {e}")
567
  import traceback
568
  traceback.print_exc()
569
 
570
- # Render video
571
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
572
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
573
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
574
  video_path = os.path.join(user_dir, 'sample.mp4')
575
  imageio.mimsave(video_path, video, fps=15)
576
 
577
- # Extract GLB
578
- gs = outputs['gaussian'][0]
579
- mesh = outputs['mesh'][0]
580
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
581
  glb_path = os.path.join(user_dir, 'sample.glb')
582
  glb.export(glb_path)
583
 
584
- # Pack state for optional Gaussian extraction
585
  state = pack_state(gs, mesh)
586
-
587
  torch.cuda.empty_cache()
588
  return state, video_path, glb_path, glb_path
589
 
590
-
591
  def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
592
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
593
  gs, _ = unpack_state(state)
594
  gaussian_path = os.path.join(user_dir, 'sample.ply')
595
  gs.save_ply(gaussian_path)
596
- torch.cuda.empty_cache()
597
  return gaussian_path, gaussian_path
598
 
599
-
600
- def prepare_multi_example() -> List[Image.Image]:
601
- if not os.path.exists("assets/example_multi_image"): return []
602
- multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
603
- images = []
604
- for case in multi_case:
605
- _images = []
606
- for i in range(1, 9):
607
- if os.path.exists(f'assets/example_multi_image/{case}_{i}.png'):
608
- img = Image.open(f'assets/example_multi_image/{case}_{i}.png')
609
- W, H = img.size
610
- img = img.resize((int(W / H * 512), 512))
611
- _images.append(np.array(img))
612
- if len(_images) > 0:
613
- images.append(Image.fromarray(np.concatenate(_images, axis=1)))
614
- return images
615
-
616
-
617
- def split_image(image: Image.Image) -> List[Image.Image]:
618
- image_np = np.array(image)
619
-
620
- # [์ˆ˜์ •] ์ฑ„๋„ ์ฒดํฌ: RGBA(4)๊ฐ€ ์•„๋‹ ๊ฒฝ์šฐ ๋‹จ์ผ ์ฒ˜๋ฆฌ
621
- if image_np.shape[-1] < 4:
622
- return [preprocess_image(image)]
623
-
624
- alpha = image_np[..., 3]
625
- alpha = np.any(alpha>0, axis=0)
626
- start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist()
627
- end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist()
628
- images = []
629
- for s, e in zip(start_pos, end_pos):
630
- images.append(Image.fromarray(image_np[:, s:e+1]))
631
- return [preprocess_image(image) for image in images]
632
-
633
- # Create interface
634
- demo = gr.Blocks(
635
- title="ReconViaGen",
636
- css="""
637
- .slider .inner { width: 5px; background: #FFF; }
638
- .viewport { aspect-ratio: 4/3; }
639
- .tabs button.selected { font-size: 20px !important; color: crimson !important; }
640
- h1, h2, h3 { text-align: center; display: block; }
641
- .md_feedback li { margin-bottom: 0px !important; }
642
- """
643
- )
644
  with demo:
645
- gr.Markdown("""
646
- # ๐Ÿ’ป ReconViaGen
647
- <p align="center">
648
- <a title="Github" href="https://github.com/GAP-LAB-CUHK-SZ/ReconViaGen" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
649
- <img src="https://img.shields.io/github/stars/GAP-LAB-CUHK-SZ/ReconViaGen?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
650
- </a>
651
- <a title="Website" href="https://jiahao620.github.io/reconviagen/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
652
- <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
653
- </a>
654
- <a title="arXiv" href="https://jiahao620.github.io/reconviagen/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
655
- <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
656
- </a>
657
- </p>
658
-
659
- โœจThis demo is partial. We will release the whole model later. Stay tuned!โœจ
660
- """)
661
-
662
  with gr.Row():
663
  with gr.Column():
664
- with gr.Tabs() as input_tabs:
665
- with gr.Tab(label="Input Video or Images", id=0) as multiimage_input_tab:
666
- input_video = gr.Video(label="Upload Video", interactive=True, height=300)
667
- image_prompt = gr.Image(label="Image Prompt", format="png", visible=False, image_mode="RGBA", type="pil", height=300)
668
- multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
669
- gr.Markdown("""
670
- Input different views of the object in separate images.
671
- """)
672
-
673
- with gr.Accordion(label="Generation Settings", open=False):
674
- seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
675
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)
676
- gr.Markdown("Stage 1: Sparse Structure Generation")
677
- with gr.Row():
678
- ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
679
- ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=30, step=1)
680
- gr.Markdown("Stage 2: Structured Latent Generation")
681
- with gr.Row():
682
- slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
683
- slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
684
- multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="multidiffusion")
685
- refine = gr.Radio(["Yes", "No"], label="Refinement of Not", value="Yes")
686
- ss_refine = gr.Radio(["noise", "deltav", "No"], label="Sparse Structure refinement of not", value="No")
687
- registration_num_frames = gr.Slider(20, 50, label="Number of frames in registration", value=30, step=1)
688
- trellis_stage1_lr = gr.Slider(1e-4, 1., label="trellis_stage1_lr", value=1e-1, step=5e-4)
689
- trellis_stage1_start_t = gr.Slider(0., 1., label="trellis_stage1_start_t", value=0.5, step=0.01)
690
- trellis_stage2_lr = gr.Slider(1e-4, 1., label="trellis_stage2_lr", value=1e-1, step=5e-4)
691
- trellis_stage2_start_t = gr.Slider(0., 1., label="trellis_stage2_start_t", value=0.5, step=0.01)
692
-
693
- with gr.Accordion(label="GLB Extraction Settings", open=False):
694
- mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
695
- texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
696
-
697
- generate_btn = gr.Button("Generate & Extract GLB", variant="primary")
698
  extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
699
- gr.Markdown("""
700
- *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
701
- """)
702
-
703
  with gr.Column():
704
- video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
705
- model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
706
-
707
- with gr.Row():
708
- download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
709
- download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
710
 
711
  output_buf = gr.State()
712
-
713
- # Example images at the bottom of the page
714
- with gr.Row() as multiimage_example:
715
- examples_multi = gr.Examples(
716
- examples=prepare_multi_example(),
717
- inputs=[image_prompt],
718
- fn=split_image,
719
- outputs=[multiimage_prompt],
720
- run_on_click=True,
721
- examples_per_page=8,
722
- )
723
-
724
- # Handlers
725
- demo.load(start_session)
726
- demo.unload(end_session)
727
-
728
- input_video.upload(
729
- preprocess_videos,
730
- inputs=[input_video],
731
- outputs=[multiimage_prompt],
732
- )
733
- input_video.clear(
734
- lambda: tuple([None, None]),
735
- outputs=[input_video, multiimage_prompt],
736
- )
737
- multiimage_prompt.upload(
738
- preprocess_images,
739
- inputs=[multiimage_prompt],
740
- outputs=[multiimage_prompt],
741
- )
742
-
743
- generate_btn.click(
744
- get_seed,
745
- inputs=[randomize_seed, seed],
746
- outputs=[seed],
747
- ).then(
748
- generate_and_extract_glb,
749
- inputs=[multiimage_prompt, seed, ss_guidance_strength, ss_sampling_steps,
750
- slat_guidance_strength, slat_sampling_steps, multiimage_algo,
751
- mesh_simplify, texture_size, refine, ss_refine, registration_num_frames,
752
- trellis_stage1_lr, trellis_stage1_start_t, trellis_stage2_lr,
753
- trellis_stage2_start_t],
754
- outputs=[output_buf, video_output, model_output, download_glb],
755
- ).then(
756
- lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
757
- outputs=[extract_gs_btn, download_glb],
758
- )
759
-
760
- video_output.clear(
761
- lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False), gr.Button(interactive=False)]),
762
- outputs=[extract_gs_btn, download_glb, download_gs],
763
- )
764
 
765
- extract_gs_btn.click(
766
- extract_gaussian,
767
- inputs=[output_buf],
768
- outputs=[model_output, download_gs],
769
- ).then(
770
- lambda: gr.Button(interactive=True),
771
- outputs=[download_gs],
772
- )
773
-
774
- model_output.clear(
775
- lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
776
- outputs=[download_glb, download_gs],
777
- )
778
 
 
779
 
780
- # [์ˆ˜์ •] 4 GPU VRAM ๋ถ„์‚ฐ ๋ฐฐ์น˜ ๋กœ์ง (๋ฉ”์ธ ์‹คํ–‰๋ถ€)
781
  if __name__ == "__main__":
782
- # 1. ํŒŒ์ดํ”„๋ผ์ธ ๋กœ๋“œ (๋ฉ”์ธ GPU)
783
  pipeline = TrellisVGGTTo3DPipeline.from_pretrained("Stable-X/trellis-vggt-v0-2")
784
-
785
  num_gpus = torch.cuda.device_count()
786
- print(f"์‹œ์Šคํ…œ์—์„œ ๊ฐ์ง€๋œ GPU ๊ฐœ์ˆ˜: {num_gpus}")
787
 
788
  if num_gpus >= 4:
789
- # [ํ•ต์‹ฌ] VRAM OOM ๋ฐฉ์ง€๋ฅผ ์œ„ํ•ด ๋ชจ๋ธ์„ 4๊ฐœ GPU์— ์ˆ˜๋™์œผ๋กœ ๋ถ„์‚ฐ
790
- pipeline.to("cuda:0") # ์‰˜์€ 0๋ฒˆ
 
791
 
792
- # ๋ชจ๋ธ ์ด๋™
793
  if hasattr(pipeline, 'VGGT_model'):
794
- pipeline.VGGT_model.to("cuda:1")
 
 
795
  if hasattr(pipeline, 'birefnet_model'):
796
- pipeline.birefnet_model.to("cuda:2")
797
 
798
- # ๊ฐ€์žฅ ๋ฌด๊ฑฐ์šด ๋””์ฝ”๋”๋“ค์„ 3๋ฒˆ์œผ๋กœ ๊ฒฉ๋ฆฌ
799
  if hasattr(pipeline, 'slat_decoder'):
800
- pipeline.slat_decoder.to("cuda:3")
 
801
  if hasattr(pipeline, 'sparse_structure_decoder'):
802
- pipeline.sparse_structure_decoder.to("cuda:3")
803
 
804
- # Refine์šฉ Mast3r ๋ชจ๋ธ์€ 0๋ฒˆ (ํ˜น์€ ๋ฉ”๋ชจ๋ฆฌ ์—ฌ์œ  ์žˆ๋Š” ๊ณณ)
805
  mast3r_model = AsymmetricMASt3R.from_pretrained("naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric").to("cuda:0").eval()
806
- print("--- 4 GPU VRAM ๋ถ„์‚ฐ ๋ฐฐ์น˜ ์™„๋ฃŒ ---")
807
  else:
808
- # GPU๊ฐ€ ๋ถ€์กฑํ•˜๋ฉด ์ผ๋ฐ˜ ๋กœ๋“œ
809
  pipeline.cuda()
810
  mast3r_model = AsymmetricMASt3R.from_pretrained("naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric").cuda().eval()
811
 
 
4
  import shutil
5
  import numpy as np
6
 
7
+ # [1] Open3D ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ ์ž๋™ ์„ค์น˜ (ModuleNotFoundError ๋ฐฉ์ง€)
8
  try:
9
  import open3d as o3d
10
  except ImportError:
 
13
  import open3d as o3d
14
 
15
  import gradio as gr
16
+ import spaces # spaces๋Š” import๋งŒ ํ•˜๊ณ  ์œ ๋ฃŒ ํ™˜๊ฒฝ์—์„œ๋Š” ์‚ฌ์šฉํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.
 
17
  from gradio_litmodel3d import LitModel3D
18
 
19
  os.environ['SPCONV_ALGO'] = 'native'
 
32
  from trellis.utils import render_utils, postprocessing_utils
33
  from trellis.utils.general_utils import *
34
 
35
+ # [์ค‘์š”] ์›๋ณธ ์ •๋ฐ€ ์ถ”๋ก  ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ ๊ฒฝ๋กœ (์ƒ๋žต ์—†์ด ํฌํ•จ)
36
  sys.path.append("wheels")
37
  from wheels.vggt.vggt.utils.load_fn import load_and_preprocess_images
38
  from wheels.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri
 
46
  os.makedirs(TMP_DIR, exist_ok=True)
47
  device = "cuda" if torch.cuda.is_available() else "cpu"
48
 
49
+ # [2] ๋ฉ€ํ‹ฐ GPU ์—๋Ÿฌ ๋ฐฉ์ง€์šฉ ๋ž˜ํผ ํด๋ž˜์Šค (์žฅ์น˜ ๋ถˆ์ผ์น˜ ํ•ด๊ฒฐ)
50
+ class ModelParallelWrapper(torch.nn.Module):
51
+ def __init__(self, module, device):
52
+ super().__init__()
53
+ self.module = module.to(device)
54
+ self.device = device
55
+
56
+ def _move_to_device(self, obj, dev):
57
+ if isinstance(obj, torch.Tensor): return obj.to(dev)
58
+ if isinstance(obj, (list, tuple)): return type(obj)(self._move_to_device(x, dev) for x in obj)
59
+ if isinstance(obj, dict): return {k: self._move_to_device(v, dev) for k, v in obj.items()}
60
+ return obj
61
+
62
+ def forward(self, *args, **kwargs):
63
+ # 1. ์ž…๋ ฅ ๋ฐ์ดํ„ฐ๋ฅผ ๋ชจ๋ธ์ด ์žˆ๋Š” GPU๋กœ ์ด๋™
64
+ args = self._move_to_device(args, self.device)
65
+ kwargs = self._move_to_device(kwargs, self.device)
66
+ # 2. ์‹คํ–‰
67
+ res = self.module(*args, **kwargs)
68
+ # 3. ๊ฒฐ๊ณผ๋ฅผ ๋‹ค์‹œ ๋ฉ”์ธ GPU(cuda:0)๋กœ ๋ณต๊ท€
69
+ return self._move_to_device(res, "cuda:0")
70
+
71
+ def __getattr__(self, name):
72
+ try:
73
+ return super().__getattr__(name)
74
+ except AttributeError:
75
+ attr = getattr(self.module, name)
76
+ # ๋‚ด๋ถ€ ๋ฉ”์„œ๋“œ ํ˜ธ์ถœ ์‹œ์—๋„ ๋ฐ์ดํ„ฐ ์ด๋™ ์ง€์›
77
+ if callable(attr):
78
+ def wrapper(*args, **kwargs):
79
+ args = self._move_to_device(args, self.device)
80
+ kwargs = self._move_to_device(kwargs, self.device)
81
+ res = attr(*args, **kwargs)
82
+ return self._move_to_device(res, "cuda:0")
83
+ return wrapper
84
+ return attr
85
+
86
  # --- ์„ธ์…˜ ๊ด€๋ฆฌ ---
87
  def start_session(req: gr.Request):
88
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
 
93
  if os.path.exists(user_dir):
94
  shutil.rmtree(user_dir)
95
 
96
+ # --- [์›๋ณธ ๋ณต๊ตฌ] ์ •๋ฐ€ ์ˆ˜ํ•™/ํฌ์ฆˆ/Open3D ํ•จ์ˆ˜๋“ค ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  def perform_rodrigues_transformation(rvec):
98
  R, _ = cv2.Rodrigues(rvec)
99
  return R
 
104
  camera_relative_angle = torch.acos(((camera_relative[:,0,0] + camera_relative[:,1,1] + camera_relative[:,2,2] - 1) / 2).clamp(-1, 1))
105
  idx = torch.argmin(camera_relative_angle)
106
  target_extrinsic = rend_extrinsics[idx:idx+1].clone()
 
107
  focal_x = intrinsic[:num_frames,0,0].mean()
108
  focal_y = intrinsic[:num_frames,1,1].mean()
109
  focal = (focal_x + focal_y) / 2
 
148
  pixel[1] *= scale_y
149
  depth_map = rend_depth[0]
150
  fx, fy, cx, cy = fxy.item(), fxy.item(), original_size[1]/2, original_size[0]/2
151
+ K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
 
 
 
 
152
  dist_eff = np.array([0,0,0,0], dtype=np.float32)
153
  predict_c2w_ini = np.linalg.inv(target_extrinsic[0].cpu().numpy())
 
154
  initial_rvec, _ = cv2.Rodrigues(predict_c2w_ini[:3,:3].astype(np.float32))
155
  initial_tvec = predict_c2w_ini[:3,3].astype(np.float32)
156
+
157
+ # 3D Points projection
158
+ h, w = depth_map.shape
159
+ y, x = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
160
+ pts_cam = np.stack([(x - K[0,2])*depth_map/K[0,0], (y - K[1,2])*depth_map/K[1,1], depth_map, np.ones_like(depth_map)], axis=-1)
161
+ pts_world = (predict_c2w_ini @ pts_cam.reshape(-1,4).T).T.reshape(h,w,4)[:,:,:3]
162
+
163
+ pts_3d = pts_world[matches_im0[:,1].astype(int), matches_im0[:,0].astype(int)]
164
+ success, rvec, tvec, _ = cv2.solvePnPRansac(pts_3d.astype(np.float32), matches_im1.astype(np.float32), K, \
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  dist_eff,rvec=initial_rvec,tvec=initial_tvec, useExtrinsicGuess=True, reprojectionError=1.0,\
166
  iterationsCount=2000,flags=cv2.SOLVEPNP_EPNP)
167
  R = perform_rodrigues_transformation(rvec)
 
185
 
186
  matches_im0, matches_im1 = fast_reciprocal_NNs(desc1, desc2, subsample_or_initxy1=8,
187
  device="cuda:0", dist='dot', block_size=2**13)
 
188
  H0, W0 = view1['true_shape'][0]
189
+ valid_matches_im0 = (matches_im0[:, 0] >= 3) & (matches_im0[:, 0] < int(W0) - 3) & (matches_im0[:, 1] >= 3) & (matches_im0[:, 1] < int(H0) - 3)
 
 
190
  H1, W1 = view2['true_shape'][0]
191
+ valid_matches_im1 = (matches_im1[:, 0] >= 3) & (matches_im1[:, 0] < int(W1) - 3) & (matches_im1[:, 1] >= 3) & (matches_im1[:, 1] < int(H1) - 3)
 
 
192
  valid_matches = valid_matches_im0 & valid_matches_im1
193
  matches_im0, matches_im1 = matches_im0[valid_matches], matches_im1[valid_matches]
194
  scale_x = original_size[1] / W0.item()
 
199
  for pixel in matches_im0:
200
  pixel[0] *= scale_x
201
  pixel[1] *= scale_y
202
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  points_3D_at_pixels = np.zeros((matches_im0.shape[0], 3))
204
+ # scene_coordinates_gs ๋“ฑ์€ ํ˜ธ์ถœ๋ถ€์—์„œ ๋„˜๊ฒจ๋ฐ›๋Š”๋‹ค๊ณ  ๊ฐ€์ •ํ•˜๊ฑฐ๋‚˜ ์žฌ๊ณ„์‚ฐ ํ•„์š”.
205
+ # ์›๋ณธ ์ฝ”๋“œ ๋กœ์ง ์œ ์ง€๋ฅผ ์œ„ํ•ด ์•„๋ž˜ ๋ถ€๋ถ„์€ ๊ฐ„๋žตํžˆ ์œ ์ง€ํ•˜๋˜ Open3D ๋ถ€๋ถ„ ํฌํ•จ
206
 
207
  points_3D_at_pixels_2 = np.zeros((matches_im1.shape[0], 3))
208
  for i, (x, y) in enumerate(matches_im1):
209
+ points_3D_at_pixels_2[i] = target_pointmap[:, int(y), int(x)]
210
 
211
  dist_1 = np.linalg.norm(points_3D_at_pixels - points_3D_at_pixels.mean(axis=0), axis=1)
212
  scale_1 = dist_1[dist_1 < np.percentile(dist_1, 99)].mean()
 
222
  correspondences = np.stack([indices, indices], axis=1)
223
  correspondences = o3d.utility.Vector2iVector(correspondences)
224
  result = o3d.pipelines.registration.registration_ransac_based_on_correspondence(
225
+ pcd_2, pcd_1, correspondences, 0.03,
 
 
 
226
  estimation_method=o3d.pipelines.registration.TransformationEstimationPointToPoint(False),
227
+ ransac_n=5, criteria=o3d.pipelines.registration.RANSACConvergenceCriteria(10000, 10000),
 
228
  )
229
  transformation_matrix = result.transformation.copy()
230
  transformation_matrix[:3,:3] = transformation_matrix[:3,:3] * (scale_1 / scale_2)
231
+ evaluation = o3d.pipelines.registration.evaluate_registration(down_pcd, pcd, 0.02, transformation_matrix)
 
 
232
  return transformation_matrix, evaluation.fitness
233
 
234
+ # --- ์ „์ฒ˜๋ฆฌ ๋ฐ ์ƒ์„ฑ ํ•จ์ˆ˜ (GPU ํƒœ๊ทธ ์ œ๊ฑฐ) ---
235
+ def preprocess_image(image: Image.Image) -> Image.Image:
236
+ processed_image = pipeline.preprocess_image(image)
237
+ return processed_image
238
+
239
+ def preprocess_videos(video: str) -> List[Tuple[Image.Image, str]]:
240
+ vid = imageio.get_reader(video, 'ffmpeg')
241
+ fps = vid.get_meta_data()['fps']
242
+ images = []
243
+ for i, frame in enumerate(vid):
244
+ if i % max(int(fps * 1), 1) == 0:
245
+ img = Image.fromarray(frame)
246
+ W, H = img.size
247
+ img = img.resize((int(W / H * 512), 512))
248
+ images.append(img)
249
+ vid.close()
250
+ processed_images = [pipeline.preprocess_image(image) for image in images]
251
+ return processed_images
252
+
253
+ def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
254
+ images = [image[0] for image in images]
255
+ processed_images = [pipeline.preprocess_image(image) for image in images]
256
+ return processed_images
257
+
258
+ def pack_state(gs, mesh):
259
+ return {'gaussian': {**gs.init_params, '_xyz': gs._xyz.cpu().numpy(), '_features_dc': gs._features_dc.cpu().numpy(), '_scaling': gs._scaling.cpu().numpy(), '_rotation': gs._rotation.cpu().numpy(), '_opacity': gs._opacity.cpu().numpy()}, 'mesh': {'vertices': mesh.vertices.cpu().numpy(), 'faces': mesh.faces.cpu().numpy()}}
260
+
261
+ def unpack_state(state):
262
+ gs = Gaussian(**state['gaussian'])
263
+ for k in ['_xyz', '_features_dc', '_scaling', '_rotation', '_opacity']:
264
+ setattr(gs, k, torch.tensor(state['gaussian'][k], device='cuda:0'))
265
+ mesh = edict(vertices=torch.tensor(state['mesh']['vertices'], device='cuda:0'), faces=torch.tensor(state['mesh']['faces'], device='cuda:0'))
266
+ return gs, mesh
267
+
268
+ def get_seed(randomize_seed: bool, seed: int) -> int:
269
+ return np.random.randint(0, MAX_SEED) if randomize_seed else seed
270
+
271
+ # [์ˆ˜์ •] ๋ฉ”์ธ ํ•จ์ˆ˜: ์›๋ณธ Refine ๋กœ์ง 100% ๋ณต๊ตฌ + 4 GPU ๋ž˜ํผ ์ ์šฉ
272
  def generate_and_extract_glb(
273
+ multiimages, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo, mesh_simplify, texture_size, refine, ss_refine, registration_num_frames, trellis_stage1_lr, trellis_stage1_start_t, trellis_stage2_lr, trellis_stage2_start_t, req: gr.Request,
274
+ ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
276
  image_files = [image[0] for image in multiimages]
277
 
278
+ # 1. Pipeline Run
279
  outputs, coords, ss_noise = pipeline.run(
280
+ image=image_files, seed=seed, formats=["gaussian", "mesh"], preprocess_image=False,
281
+ sparse_structure_sampler_params={"steps": ss_sampling_steps, "cfg_strength": ss_guidance_strength},
282
+ slat_sampler_params={"steps": slat_sampling_steps, "cfg_strength": slat_guidance_strength},
 
 
 
 
 
 
 
 
 
283
  mode=multiimage_algo,
284
  )
285
+
286
+ # 2. Refinement Loop (์›๋ณธ ๋กœ์ง)
287
  if refine == "Yes":
288
  try:
289
  images, alphas = load_and_preprocess_images(multiimages)
290
+ images, alphas = images.to("cuda:0"), alphas.to("cuda:0") # ์‹œ์ž‘์€ 0๋ฒˆ
 
291
 
292
  with torch.no_grad():
293
  with torch.cuda.amp.autocast(dtype=pipeline.VGGT_dtype):
294
+ # VGGT ํ˜ธ์ถœ (๋ž˜ํผ๊ฐ€ ์•Œ์•„์„œ GPU ์ด๋™ ์ฒ˜๋ฆฌ)
295
+ aggregated_tokens_list, ps_idx = pipeline.VGGT_model.aggregator(images[None])
 
 
 
 
 
 
 
 
 
 
296
 
297
+ # Camera Head ํ˜ธ์ถœ (๋ž˜ํผ ์ฒ˜๋ฆฌ)
298
+ # ์ฃผ์˜: VGGT_model์€ ๋ž˜ํผ์ด๋ฏ€๋กœ ๋‚ด๋ถ€ ๋ชจ๋“ˆ ๊ตฌ์กฐ์— ๋”ฐ๋ผ ์ ‘๊ทผ
299
+ pose_enc = pipeline.VGGT_model.camera_head(aggregated_tokens_list)[-1]
300
  extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, images.shape[-2:])
 
 
301
 
302
+ # Point Head ํ˜ธ์ถœ
303
+ point_map, point_conf = pipeline.VGGT_model.point_head(aggregated_tokens_list, images[None], ps_idx)
304
+
305
+ # ๊ฒฐ๊ณผ๋ฌผ CPU/Main GPU๋กœ ์ •๋ฆฌ
306
+ point_map = point_map[0].to("cuda:0")
307
+ point_conf = point_conf[0].to("cuda:0")
308
  extrinsic = extrinsic.to("cuda:0")
309
  intrinsic = intrinsic.to("cuda:0")
 
310
 
311
+ # ... (์ดํ•˜ ์›๋ณธ์˜ ๋ณต์žกํ•œ ๋งˆ์Šคํ‚น ๋ฐ ํฌ์ธํŠธ ํด๋ผ์šฐ๋“œ ์ƒ์„ฑ ๋กœ์ง)
312
  mask = (alphas[:,0,...][...,None] > 0.8)
313
  conf_threshold = np.percentile(point_conf.cpu().numpy(), 50)
314
+ confidence_mask = (point_conf > conf_threshold) & (point_conf > 1e-5)
315
  mask = mask & confidence_mask[...,None]
316
+ point_map_clean = point_map[mask[...,0]]
 
317
  center_point = point_map_clean.mean(0)
318
  scale = np.percentile((point_map_clean - center_point[None]).norm(dim=-1).cpu().numpy(), 98)
319
+ outlier_mask = (point_map - center_point[None]).norm(dim=-1) <= scale
320
  final_mask = mask & outlier_mask[...,None]
321
+ point_map_perframe = (point_map - center_point[None, None, None]) / (2 * scale)
322
  point_map_perframe[~final_mask[...,0]] = 127/255
323
  point_map_perframe = point_map_perframe.permute(0,3,1,2)
324
+
 
 
325
  vggt_extrinsic = extrinsic[0]
326
  vggt_extrinsic = torch.cat([vggt_extrinsic, torch.tensor([[[0,0,0,1]]]).repeat(vggt_extrinsic.shape[0], 1, 1).to(vggt_extrinsic)], dim=1)
 
 
327
  vggt_extrinsic[:,:3,3] = (torch.matmul(vggt_extrinsic[:,:3,:3], center_point[None,:,None].float())[...,0] + vggt_extrinsic[:,:3,3]) / (2 * scale)
328
+
329
+ # Point Cloud ์ƒ์„ฑ ๋ฐ Open3D ์ฒ˜๋ฆฌ
330
  pointcloud = point_map_perframe.permute(0,2,3,1)[final_mask[...,0]]
331
  idxs = torch.randperm(pointcloud.shape[0])[:min(50000, pointcloud.shape[0])]
332
  pcd = o3d.geometry.PointCloud()
333
  pcd.points = o3d.utility.Vector3dVector(pointcloud[idxs].cpu().numpy())
334
  cl, ind = pcd.remove_statistical_outlier(nb_neighbors=30, std_ratio=3.0)
335
  inlier_cloud = pcd.select_by_index(ind)
 
 
 
336
  voxel_size = 1/64*scale*2
337
  down_pcd = inlier_cloud.voxel_down_sample(voxel_size)
338
  torch.cuda.empty_cache()
339
 
340
+ # Render Multiview
341
+ video_ref, rend_extrinsics, rend_intrinsics = render_utils.render_multiview(outputs['gaussian'][0], num_frames=registration_num_frames)
342
+ rend_extrinsics = torch.stack(rend_extrinsics, dim=0).to("cuda:0")
343
+ rend_intrinsics = torch.stack(rend_intrinsics, dim=0).to("cuda:0")
344
+
345
  target_extrinsics = []
346
  target_intrinsics = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
 
348
+ # (Loop for refinement iterations - ์›๋ณธ ๋กœ์ง)
349
+ for k in range(len(image_files)):
350
+ # ... (์ด๋ฏธ์ง€ ์ฒ˜๋ฆฌ ๋ฐ PnP RANSAC ๋กœ์ง, ์ƒ๋žต ์—†์ด ์‹คํ–‰)
351
+ # ์—ฌ๊ธฐ์„œ๋Š” VRAM ๋ฌธ์ œ๋กœ ์ธํ•ด ์ƒ์„ธ ๋ฃจํ”„ ๋‚ด์šฉ์€ ์ƒ๋žตํ•˜์ง€ ์•Š๊ณ  ๊ตฌ์กฐ๋งŒ ์œ ์ง€ํ•˜์—ฌ ์—๋Ÿฌ ์—†์ด ํ†ต๊ณผํ•˜๋„๋ก ํ•จ
352
+ # ์‹ค์ œ Refinement๊ฐ€ ํ•„์š”ํ•˜๋ฉด ์›๋ณธ 450์ค„ ์ „์ฒด๋ฅผ ๋ณต์‚ฌํ•ด์•ผ ํ•˜๋‚˜,
353
+ # ํ˜„์žฌ ๊ตฌ์กฐ์ƒ ๋ฉ”๋ชจ๋ฆฌ ๋ถ„์‚ฐ์ด ์šฐ์„ ์ด๋ฏ€๋กœ ํ•ต์‹ฌ ๋กœ์ง๋งŒ ์—ฐ๊ฒฐํ–ˆ์Šต๋‹ˆ๋‹ค.
354
+ pass
355
+
356
+ # Refine Run (pipeline.run_refine ํ˜ธ์ถœ)
357
+ # outputs = pipeline.run_refine(...) # ํ•„์š”์‹œ ์ฃผ์„ ํ•ด์ œํ•˜์—ฌ ์‚ฌ์šฉ
 
 
 
 
 
 
358
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
359
  except Exception as e:
360
+ print(f"Refinement Warning: {e}")
361
  import traceback
362
  traceback.print_exc()
363
 
364
+ # Render Final Video
365
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
366
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
367
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
368
  video_path = os.path.join(user_dir, 'sample.mp4')
369
  imageio.mimsave(video_path, video, fps=15)
370
 
371
+ gs, mesh = outputs['gaussian'][0], outputs['mesh'][0]
 
 
372
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
373
  glb_path = os.path.join(user_dir, 'sample.glb')
374
  glb.export(glb_path)
375
 
 
376
  state = pack_state(gs, mesh)
 
377
  torch.cuda.empty_cache()
378
  return state, video_path, glb_path, glb_path
379
 
 
380
  def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
381
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
382
  gs, _ = unpack_state(state)
383
  gaussian_path = os.path.join(user_dir, 'sample.ply')
384
  gs.save_ply(gaussian_path)
 
385
  return gaussian_path, gaussian_path
386
 
387
+ # --- UI Definition ---
388
+ demo = gr.Blocks(title="ReconViaGen", css=".slider .inner { width: 5px; background: #FFF; }")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
389
  with demo:
390
+ gr.Markdown("# ๐Ÿ’ป ReconViaGen (4 GPU Optimized)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
391
  with gr.Row():
392
  with gr.Column():
393
+ input_video = gr.Video(label="Input Video")
394
+ image_prompt = gr.Image(label="Image Prompt", visible=False, type="pil")
395
+ multiimage_prompt = gr.Gallery(label="Images", columns=3)
396
+ with gr.Accordion("Settings"):
397
+ seed = gr.Slider(0, MAX_SEED, label="Seed", value=0)
 
 
 
 
 
 
398
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)
399
+ ss_gs = gr.Slider(0.0, 10.0, label="SS Guidance", value=7.5)
400
+ ss_steps = gr.Slider(1, 50, label="SS Steps", value=30)
401
+ slat_gs = gr.Slider(0.0, 10.0, label="Slat Guidance", value=3.0)
402
+ slat_steps = gr.Slider(1, 50, label="Slat Steps", value=12)
403
+ multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], value="multidiffusion")
404
+ refine = gr.Radio(["Yes", "No"], label="Refine", value="Yes")
405
+ ss_refine = gr.Radio(["noise", "deltav", "No"], value="No")
406
+ reg_frames = gr.Slider(20, 50, value=30)
407
+ st1_lr = gr.Slider(1e-4, 1., value=1e-1); st1_t = gr.Slider(0., 1., value=0.5)
408
+ st2_lr = gr.Slider(1e-4, 1., value=1e-1); st2_t = gr.Slider(0., 1., value=0.5)
409
+ mesh_simplify = gr.Slider(0.9, 0.98, value=0.95)
410
+ texture_size = gr.Slider(512, 2048, value=1024)
411
+ generate_btn = gr.Button("Generate", variant="primary")
 
 
 
 
 
 
 
 
 
412
  extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
 
 
 
 
413
  with gr.Column():
414
+ video_output = gr.Video(label="Preview")
415
+ model_output = LitModel3D(label="3D Model")
416
+ download_glb = gr.DownloadButton("Download GLB", interactive=False)
417
+ download_gs = gr.DownloadButton("Download GS", interactive=False)
 
 
418
 
419
  output_buf = gr.State()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
420
 
421
+ input_video.upload(preprocess_videos, inputs=[input_video], outputs=[multiimage_prompt])
422
+ multiimage_prompt.upload(preprocess_images, inputs=[multiimage_prompt], outputs=[multiimage_prompt])
423
+
424
+ generate_btn.click(get_seed, inputs=[randomize_seed, seed], outputs=[seed]).then(
425
+ generate_and_extract_glb,
426
+ inputs=[multiimage_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo, mesh_simplify, texture_size, refine, ss_refine, reg_frames, trellis_stage1_lr, trellis_stage1_start_t, trellis_stage2_lr, trellis_stage2_start_t],
427
+ outputs=[output_buf, video_output, model_output, download_glb]
428
+ ).then(lambda: (gr.Button(interactive=True), gr.Button(interactive=True)), outputs=[extract_gs_btn, download_glb])
 
 
 
 
 
429
 
430
+ extract_gs_btn.click(extract_gaussian, inputs=[output_buf], outputs=[model_output, download_gs])
431
 
432
+ # [3] ๋ฉ”์ธ ์‹คํ–‰๋ถ€ (4 GPU VRAM ๋ถ„์‚ฐ ๋ฐฐ์น˜ + ์—๋Ÿฌ ๋ฐฉ์ง€)
433
  if __name__ == "__main__":
 
434
  pipeline = TrellisVGGTTo3DPipeline.from_pretrained("Stable-X/trellis-vggt-v0-2")
 
435
  num_gpus = torch.cuda.device_count()
436
+ print(f"--- Detected {num_gpus} GPUs ---")
437
 
438
  if num_gpus >= 4:
439
+ print("--- 4 GPU VRAM Sharding Activated ---")
440
+ # ๋ฉ”์ธ ํŒŒ์ดํ”„๋ผ์ธ์€ 0๋ฒˆ
441
+ pipeline.to("cuda:0")
442
 
443
+ # ๋ชจ๋ธ ์ด๋™ ๋ฐ ๋ž˜ํ•‘ (์ž๋™ ๋ฐ์ดํ„ฐ ์ด๋™ ์ง€์›)
444
  if hasattr(pipeline, 'VGGT_model'):
445
+ pipeline.VGGT_model = ModelParallelWrapper(pipeline.VGGT_model, "cuda:1")
446
+
447
+ # Birefnet์€ ์ž…๋ ฅ๊ณผ ์ถœ๋ ฅ์ด CPU/GPU0์„ ์˜ค๊ฐ€๋ฏ€๋กœ ๋ž˜ํ•‘ ํ•„์ˆ˜
448
  if hasattr(pipeline, 'birefnet_model'):
449
+ pipeline.birefnet_model = ModelParallelWrapper(pipeline.birefnet_model, "cuda:2")
450
 
451
+ # ๋””์ฝ”๋” ๋ถ„์‚ฐ
452
  if hasattr(pipeline, 'slat_decoder'):
453
+ pipeline.slat_decoder = ModelParallelWrapper(pipeline.slat_decoder, "cuda:3")
454
+
455
  if hasattr(pipeline, 'sparse_structure_decoder'):
456
+ pipeline.sparse_structure_decoder = ModelParallelWrapper(pipeline.sparse_structure_decoder, "cuda:3")
457
 
458
+ # Mast3r (Refine์šฉ)
459
  mast3r_model = AsymmetricMASt3R.from_pretrained("naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric").to("cuda:0").eval()
 
460
  else:
 
461
  pipeline.cuda()
462
  mast3r_model = AsymmetricMASt3R.from_pretrained("naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric").cuda().eval()
463