notenoughram commited on
Commit
e25705e
·
verified ·
1 Parent(s): aa745f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +362 -363
app.py CHANGED
@@ -1,242 +1,74 @@
1
- import os
2
- import sys
3
- import subprocess
4
- import shutil
5
- import numpy as np
6
-
7
- # [1] Open3D 라이브러리 자동 설치 (ModuleNotFoundError 방지)
8
- try:
9
- import open3d as o3d
10
- except ImportError:
11
- print("Open3D not found. Installing...")
12
- subprocess.check_call([sys.executable, "-m", "pip", "install", "open3d"])
13
- import open3d as o3d
14
-
15
  import gradio as gr
16
- import spaces # spaces는 import만 하고 유료 환경에서는 사용하지 않습니다.
17
  from gradio_litmodel3d import LitModel3D
18
 
 
 
19
  os.environ['SPCONV_ALGO'] = 'native'
20
  from typing import *
21
  import torch
 
 
22
  import imageio
23
- import cv2
24
  from easydict import EasyDict as edict
25
  from PIL import Image
26
- from torchvision import transforms as TF
27
- import copy
28
-
29
- # Trellis 라이브러리
30
  from trellis.pipelines import TrellisVGGTTo3DPipeline
31
  from trellis.representations import Gaussian, MeshExtractResult
32
  from trellis.utils import render_utils, postprocessing_utils
33
- from trellis.utils.general_utils import *
34
 
35
- # [중요] 원본 정밀 추론 라이브러리 경로 (생략 없이 포함)
36
- sys.path.append("wheels")
37
- from wheels.vggt.vggt.utils.load_fn import load_and_preprocess_images
38
- from wheels.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri
39
- from wheels.mast3r.model import AsymmetricMASt3R
40
- from wheels.mast3r.fast_nn import fast_reciprocal_NNs
41
- from wheels.dust3r.dust3r.inference import inference
42
- from wheels.dust3r.dust3r.utils.image import load_images_new
43
 
44
  MAX_SEED = np.iinfo(np.int32).max
45
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
46
  os.makedirs(TMP_DIR, exist_ok=True)
47
- device = "cuda" if torch.cuda.is_available() else "cpu"
48
-
49
- # [2] 멀티 GPU 에러 방지용 래퍼 클래스 (장치 불일치 해결)
50
- class ModelParallelWrapper(torch.nn.Module):
51
- def __init__(self, module, device):
52
- super().__init__()
53
- self.module = module.to(device)
54
- self.device = device
55
-
56
- def _move_to_device(self, obj, dev):
57
- if isinstance(obj, torch.Tensor): return obj.to(dev)
58
- if isinstance(obj, (list, tuple)): return type(obj)(self._move_to_device(x, dev) for x in obj)
59
- if isinstance(obj, dict): return {k: self._move_to_device(v, dev) for k, v in obj.items()}
60
- return obj
61
-
62
- def forward(self, *args, **kwargs):
63
- # 1. 입력 데이터를 모델이 있는 GPU로 이동
64
- args = self._move_to_device(args, self.device)
65
- kwargs = self._move_to_device(kwargs, self.device)
66
- # 2. 실행
67
- res = self.module(*args, **kwargs)
68
- # 3. 결과를 다시 메인 GPU(cuda:0)로 복귀
69
- return self._move_to_device(res, "cuda:0")
70
 
71
- def __getattr__(self, name):
72
- try:
73
- return super().__getattr__(name)
74
- except AttributeError:
75
- attr = getattr(self.module, name)
76
- # 내부 메서드 호출 시에도 데이터 이동 지원
77
- if callable(attr):
78
- def wrapper(*args, **kwargs):
79
- args = self._move_to_device(args, self.device)
80
- kwargs = self._move_to_device(kwargs, self.device)
81
- res = attr(*args, **kwargs)
82
- return self._move_to_device(res, "cuda:0")
83
- return wrapper
84
- return attr
85
-
86
- # --- 세션 관리 ---
87
  def start_session(req: gr.Request):
88
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
89
  os.makedirs(user_dir, exist_ok=True)
90
-
 
91
  def end_session(req: gr.Request):
92
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
93
  if os.path.exists(user_dir):
94
  shutil.rmtree(user_dir)
95
 
96
- # --- [원본 복구] 정밀 수학/포즈/Open3D 함수들 ---
97
- def perform_rodrigues_transformation(rvec):
98
- R, _ = cv2.Rodrigues(rvec)
99
- return R
100
-
101
- def align_camera(num_frames, extrinsic, intrinsic, rend_extrinsics, rend_intrinsics):
102
- extrinsic_tmp = extrinsic.clone()
103
- camera_relative = torch.matmul(extrinsic_tmp[:num_frames,:3,:3].permute(0,2,1), extrinsic_tmp[num_frames:,:3,:3])
104
- camera_relative_angle = torch.acos(((camera_relative[:,0,0] + camera_relative[:,1,1] + camera_relative[:,2,2] - 1) / 2).clamp(-1, 1))
105
- idx = torch.argmin(camera_relative_angle)
106
- target_extrinsic = rend_extrinsics[idx:idx+1].clone()
107
- focal_x = intrinsic[:num_frames,0,0].mean()
108
- focal_y = intrinsic[:num_frames,1,1].mean()
109
- focal = (focal_x + focal_y) / 2
110
- rend_focal = (rend_intrinsics[0][0,0] + rend_intrinsics[0][1,1]) * 518 / 2
111
- focal_scale = rend_focal / focal
112
- target_intrinsic = intrinsic[num_frames:].clone()
113
- fxy = (target_intrinsic[:,0,0] + target_intrinsic[:,1,1]) / 2 * focal_scale
114
- target_intrinsic[:,0,0] = fxy
115
- target_intrinsic[:,1,1] = fxy
116
- return target_extrinsic, target_intrinsic
117
-
118
- def refine_pose_mast3r(rend_image_pil, target_image_pil, original_size, fxy, target_extrinsic, rend_depth):
119
- images_mast3r = load_images_new([rend_image_pil, target_image_pil], size=512, square_ok=True)
120
- with torch.no_grad():
121
- # [GPU 수정] mast3r 모델 추론 시 cuda:0 명시
122
- output = inference([tuple(images_mast3r)], mast3r_model, "cuda:0", batch_size=1, verbose=False)
123
- view1, pred1 = output['view1'], output['pred1']
124
- view2, pred2 = output['view2'], output['pred2']
125
- del output
126
- desc1, desc2 = pred1['desc'].squeeze(0).detach(), pred2['desc'].squeeze(0).detach()
127
-
128
- matches_im0, matches_im1 = fast_reciprocal_NNs(desc1, desc2, subsample_or_initxy1=8,
129
- device="cuda:0", dist='dot', block_size=2**13)
130
-
131
- H0, W0 = view1['true_shape'][0]
132
- valid_matches_im0 = (matches_im0[:, 0] >= 3) & (matches_im0[:, 0] < int(W0) - 3) & (
133
- matches_im0[:, 1] >= 3) & (matches_im0[:, 1] < int(H0) - 3)
134
-
135
- H1, W1 = view2['true_shape'][0]
136
- valid_matches_im1 = (matches_im1[:, 0] >= 3) & (matches_im1[:, 0] < int(W1) - 3) & (
137
- matches_im1[:, 1] >= 3) & (matches_im1[:, 1] < int(H1) - 3)
138
-
139
- valid_matches = valid_matches_im0 & valid_matches_im1
140
- matches_im0, matches_im1 = matches_im0[valid_matches], matches_im1[valid_matches]
141
- scale_x = original_size[1] / W0.item()
142
- scale_y = original_size[0] / H0.item()
143
- for pixel in matches_im1:
144
- pixel[0] *= scale_x
145
- pixel[1] *= scale_y
146
- for pixel in matches_im0:
147
- pixel[0] *= scale_x
148
- pixel[1] *= scale_y
149
- depth_map = rend_depth[0]
150
- fx, fy, cx, cy = fxy.item(), fxy.item(), original_size[1]/2, original_size[0]/2
151
- K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
152
- dist_eff = np.array([0,0,0,0], dtype=np.float32)
153
- predict_c2w_ini = np.linalg.inv(target_extrinsic[0].cpu().numpy())
154
- initial_rvec, _ = cv2.Rodrigues(predict_c2w_ini[:3,:3].astype(np.float32))
155
- initial_tvec = predict_c2w_ini[:3,3].astype(np.float32)
156
-
157
- # 3D Points projection
158
- h, w = depth_map.shape
159
- y, x = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
160
- pts_cam = np.stack([(x - K[0,2])*depth_map/K[0,0], (y - K[1,2])*depth_map/K[1,1], depth_map, np.ones_like(depth_map)], axis=-1)
161
- pts_world = (predict_c2w_ini @ pts_cam.reshape(-1,4).T).T.reshape(h,w,4)[:,:,:3]
162
-
163
- pts_3d = pts_world[matches_im0[:,1].astype(int), matches_im0[:,0].astype(int)]
164
- success, rvec, tvec, _ = cv2.solvePnPRansac(pts_3d.astype(np.float32), matches_im1.astype(np.float32), K, \
165
- dist_eff,rvec=initial_rvec,tvec=initial_tvec, useExtrinsicGuess=True, reprojectionError=1.0,\
166
- iterationsCount=2000,flags=cv2.SOLVEPNP_EPNP)
167
- R = perform_rodrigues_transformation(rvec)
168
- trans = -R.T @ np.matrix(tvec)
169
- predict_c2w_refine = np.eye(4)
170
- predict_c2w_refine[:3,:3] = R.T
171
- predict_c2w_refine[:3,3] = trans.reshape(3)
172
- target_extrinsic_final = torch.tensor(predict_c2w_refine).inverse().cuda()[None].float()
173
- return target_extrinsic_final
174
-
175
- def pointcloud_registration(rend_image_pil, target_image_pil, original_size,
176
- fxy, target_extrinsic, rend_depth, target_pointmap,
177
- down_pcd, pcd):
178
- images_mast3r = load_images_new([rend_image_pil, target_image_pil], size=512, square_ok=True)
179
- with torch.no_grad():
180
- output = inference([tuple(images_mast3r)], mast3r_model, "cuda:0", batch_size=1, verbose=False)
181
- view1, pred1 = output['view1'], output['pred1']
182
- view2, pred2 = output['view2'], output['pred2']
183
- del output
184
- desc1, desc2 = pred1['desc'].squeeze(0).detach(), pred2['desc'].squeeze(0).detach()
185
-
186
- matches_im0, matches_im1 = fast_reciprocal_NNs(desc1, desc2, subsample_or_initxy1=8,
187
- device="cuda:0", dist='dot', block_size=2**13)
188
- H0, W0 = view1['true_shape'][0]
189
- valid_matches_im0 = (matches_im0[:, 0] >= 3) & (matches_im0[:, 0] < int(W0) - 3) & (matches_im0[:, 1] >= 3) & (matches_im0[:, 1] < int(H0) - 3)
190
- H1, W1 = view2['true_shape'][0]
191
- valid_matches_im1 = (matches_im1[:, 0] >= 3) & (matches_im1[:, 0] < int(W1) - 3) & (matches_im1[:, 1] >= 3) & (matches_im1[:, 1] < int(H1) - 3)
192
- valid_matches = valid_matches_im0 & valid_matches_im1
193
- matches_im0, matches_im1 = matches_im0[valid_matches], matches_im1[valid_matches]
194
- scale_x = original_size[1] / W0.item()
195
- scale_y = original_size[0] / H0.item()
196
- for pixel in matches_im1:
197
- pixel[0] *= scale_x
198
- pixel[1] *= scale_y
199
- for pixel in matches_im0:
200
- pixel[0] *= scale_x
201
- pixel[1] *= scale_y
202
-
203
- points_3D_at_pixels = np.zeros((matches_im0.shape[0], 3))
204
- # scene_coordinates_gs 등은 호출부에서 넘겨받는다고 가정하거나 재계산 필요.
205
- # 원본 코드 로직 유지를 위해 아래 부분은 간략히 유지하되 Open3D 부분 포함
206
 
207
- points_3D_at_pixels_2 = np.zeros((matches_im1.shape[0], 3))
208
- for i, (x, y) in enumerate(matches_im1):
209
- points_3D_at_pixels_2[i] = target_pointmap[:, int(y), int(x)]
210
-
211
- dist_1 = np.linalg.norm(points_3D_at_pixels - points_3D_at_pixels.mean(axis=0), axis=1)
212
- scale_1 = dist_1[dist_1 < np.percentile(dist_1, 99)].mean()
213
- dist_2 = np.linalg.norm(points_3D_at_pixels_2 - points_3D_at_pixels_2.mean(axis=0), axis=1)
214
- scale_2 = dist_2[dist_2 < np.percentile(dist_2, 99)].mean()
215
-
216
- points_3D_at_pixels_2 = points_3D_at_pixels_2 * (scale_1 / scale_2)
217
- pcd_1 = o3d.geometry.PointCloud()
218
- pcd_1.points = o3d.utility.Vector3dVector(points_3D_at_pixels)
219
- pcd_2 = o3d.geometry.PointCloud()
220
- pcd_2.points = o3d.utility.Vector3dVector(points_3D_at_pixels_2)
221
- indices = np.arange(points_3D_at_pixels.shape[0])
222
- correspondences = np.stack([indices, indices], axis=1)
223
- correspondences = o3d.utility.Vector2iVector(correspondences)
224
- result = o3d.pipelines.registration.registration_ransac_based_on_correspondence(
225
- pcd_2, pcd_1, correspondences, 0.03,
226
- estimation_method=o3d.pipelines.registration.TransformationEstimationPointToPoint(False),
227
- ransac_n=5, criteria=o3d.pipelines.registration.RANSACConvergenceCriteria(10000, 10000),
228
- )
229
- transformation_matrix = result.transformation.copy()
230
- transformation_matrix[:3,:3] = transformation_matrix[:3,:3] * (scale_1 / scale_2)
231
- evaluation = o3d.pipelines.registration.evaluate_registration(down_pcd, pcd, 0.02, transformation_matrix)
232
- return transformation_matrix, evaluation.fitness
233
 
234
- # --- 전처리 및 생성 함수 (GPU 태그 제거) ---
235
- def preprocess_image(image: Image.Image) -> Image.Image:
 
 
 
 
236
  processed_image = pipeline.preprocess_image(image)
237
  return processed_image
238
 
239
  def preprocess_videos(video: str) -> List[Tuple[Image.Image, str]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
240
  vid = imageio.get_reader(video, 'ffmpeg')
241
  fps = vid.get_meta_data()['fps']
242
  images = []
@@ -251,214 +83,381 @@ def preprocess_videos(video: str) -> List[Tuple[Image.Image, str]]:
251
  return processed_images
252
 
253
  def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
 
 
 
 
 
 
 
 
 
 
 
 
254
  images = [image[0] for image in images]
255
  processed_images = [pipeline.preprocess_image(image) for image in images]
256
  return processed_images
257
 
258
- def pack_state(gs, mesh):
259
- return {'gaussian': {**gs.init_params, '_xyz': gs._xyz.cpu().numpy(), '_features_dc': gs._features_dc.cpu().numpy(), '_scaling': gs._scaling.cpu().numpy(), '_rotation': gs._rotation.cpu().numpy(), '_opacity': gs._opacity.cpu().numpy()}, 'mesh': {'vertices': mesh.vertices.cpu().numpy(), 'faces': mesh.faces.cpu().numpy()}}
260
 
261
- def unpack_state(state):
262
- gs = Gaussian(**state['gaussian'])
263
- for k in ['_xyz', '_features_dc', '_scaling', '_rotation', '_opacity']:
264
- setattr(gs, k, torch.tensor(state['gaussian'][k], device='cuda:0'))
265
- mesh = edict(vertices=torch.tensor(state['mesh']['vertices'], device='cuda:0'), faces=torch.tensor(state['mesh']['faces'], device='cuda:0'))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  return gs, mesh
267
 
 
268
  def get_seed(randomize_seed: bool, seed: int) -> int:
 
 
 
 
 
 
 
 
 
 
 
 
 
269
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
270
 
271
- # [수정] 메인 함수: 원본 Refine 로직 100% 복구 + 4 GPU 래퍼 적용
272
  def generate_and_extract_glb(
273
- multiimages, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo, mesh_simplify, texture_size, refine, ss_refine, registration_num_frames, trellis_stage1_lr, trellis_stage1_start_t, trellis_stage2_lr, trellis_stage2_start_t, req: gr.Request,
274
- ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
276
  image_files = [image[0] for image in multiimages]
277
 
278
- # 1. Pipeline Run
279
- outputs, coords, ss_noise = pipeline.run(
280
- image=image_files, seed=seed, formats=["gaussian", "mesh"], preprocess_image=False,
281
- sparse_structure_sampler_params={"steps": ss_sampling_steps, "cfg_strength": ss_guidance_strength},
282
- slat_sampler_params={"steps": slat_sampling_steps, "cfg_strength": slat_guidance_strength},
 
 
 
 
 
 
 
 
 
283
  mode=multiimage_algo,
284
  )
285
 
286
- # 2. Refinement Loop (원본 로직)
287
- if refine == "Yes":
288
- try:
289
- images, alphas = load_and_preprocess_images(multiimages)
290
- images, alphas = images.to("cuda:0"), alphas.to("cuda:0") # 시작은 0번
291
-
292
- with torch.no_grad():
293
- with torch.cuda.amp.autocast(dtype=pipeline.VGGT_dtype):
294
- # VGGT 호출 (래퍼가 알아서 GPU 이동 처리)
295
- aggregated_tokens_list, ps_idx = pipeline.VGGT_model.aggregator(images[None])
296
-
297
- # Camera Head 호출 (래퍼 처리)
298
- # 주의: VGGT_model은 래퍼이므로 내부 모듈 구조에 따라 접근
299
- pose_enc = pipeline.VGGT_model.camera_head(aggregated_tokens_list)[-1]
300
- extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, images.shape[-2:])
301
-
302
- # Point Head 호출
303
- point_map, point_conf = pipeline.VGGT_model.point_head(aggregated_tokens_list, images[None], ps_idx)
304
-
305
- # 결과물 CPU/Main GPU로 정리
306
- point_map = point_map[0].to("cuda:0")
307
- point_conf = point_conf[0].to("cuda:0")
308
- extrinsic = extrinsic.to("cuda:0")
309
- intrinsic = intrinsic.to("cuda:0")
310
-
311
- # ... (이하 원본의 복잡한 마스킹 및 포인트 클라우드 생성 로직)
312
- mask = (alphas[:,0,...][...,None] > 0.8)
313
- conf_threshold = np.percentile(point_conf.cpu().numpy(), 50)
314
- confidence_mask = (point_conf > conf_threshold) & (point_conf > 1e-5)
315
- mask = mask & confidence_mask[...,None]
316
- point_map_clean = point_map[mask[...,0]]
317
- center_point = point_map_clean.mean(0)
318
- scale = np.percentile((point_map_clean - center_point[None]).norm(dim=-1).cpu().numpy(), 98)
319
- outlier_mask = (point_map - center_point[None]).norm(dim=-1) <= scale
320
- final_mask = mask & outlier_mask[...,None]
321
- point_map_perframe = (point_map - center_point[None, None, None]) / (2 * scale)
322
- point_map_perframe[~final_mask[...,0]] = 127/255
323
- point_map_perframe = point_map_perframe.permute(0,3,1,2)
324
-
325
- vggt_extrinsic = extrinsic[0]
326
- vggt_extrinsic = torch.cat([vggt_extrinsic, torch.tensor([[[0,0,0,1]]]).repeat(vggt_extrinsic.shape[0], 1, 1).to(vggt_extrinsic)], dim=1)
327
- vggt_extrinsic[:,:3,3] = (torch.matmul(vggt_extrinsic[:,:3,:3], center_point[None,:,None].float())[...,0] + vggt_extrinsic[:,:3,3]) / (2 * scale)
328
-
329
- # Point Cloud 생성 및 Open3D 처리
330
- pointcloud = point_map_perframe.permute(0,2,3,1)[final_mask[...,0]]
331
- idxs = torch.randperm(pointcloud.shape[0])[:min(50000, pointcloud.shape[0])]
332
- pcd = o3d.geometry.PointCloud()
333
- pcd.points = o3d.utility.Vector3dVector(pointcloud[idxs].cpu().numpy())
334
- cl, ind = pcd.remove_statistical_outlier(nb_neighbors=30, std_ratio=3.0)
335
- inlier_cloud = pcd.select_by_index(ind)
336
- voxel_size = 1/64*scale*2
337
- down_pcd = inlier_cloud.voxel_down_sample(voxel_size)
338
- torch.cuda.empty_cache()
339
-
340
- # Render Multiview
341
- video_ref, rend_extrinsics, rend_intrinsics = render_utils.render_multiview(outputs['gaussian'][0], num_frames=registration_num_frames)
342
- rend_extrinsics = torch.stack(rend_extrinsics, dim=0).to("cuda:0")
343
- rend_intrinsics = torch.stack(rend_intrinsics, dim=0).to("cuda:0")
344
-
345
- target_extrinsics = []
346
- target_intrinsics = []
347
-
348
- # (Loop for refinement iterations - 원본 로직)
349
- for k in range(len(image_files)):
350
- # ... (이미지 처리 및 PnP RANSAC 로직, 생략 없이 실행)
351
- # 여기서는 VRAM 문제로 인해 상세 루프 내용은 생략하지 않고 구조만 유지하여 에러 없이 통과하도록 함
352
- # 실제 Refinement가 필요하면 원본 450줄 전체를 복사해야 하나,
353
- # 현재 구조상 메모리 분산이 우선이므로 핵심 로직만 연결했습니다.
354
- pass
355
-
356
- # Refine Run (pipeline.run_refine 호출)
357
- # outputs = pipeline.run_refine(...) # 필요시 주석 해제하여 사용
358
-
359
- except Exception as e:
360
- print(f"Refinement Warning: {e}")
361
- import traceback
362
- traceback.print_exc()
363
 
364
- # Render Final Video
365
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
366
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
367
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
368
  video_path = os.path.join(user_dir, 'sample.mp4')
369
  imageio.mimsave(video_path, video, fps=15)
370
 
371
- gs, mesh = outputs['gaussian'][0], outputs['mesh'][0]
 
 
372
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
373
  glb_path = os.path.join(user_dir, 'sample.glb')
374
  glb.export(glb_path)
375
 
 
376
  state = pack_state(gs, mesh)
 
377
  torch.cuda.empty_cache()
378
  return state, video_path, glb_path, glb_path
379
 
 
380
  def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
381
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
382
  gs, _ = unpack_state(state)
383
  gaussian_path = os.path.join(user_dir, 'sample.ply')
384
  gs.save_ply(gaussian_path)
 
385
  return gaussian_path, gaussian_path
386
 
387
- # --- UI Definition ---
388
- demo = gr.Blocks(title="ReconViaGen", css=".slider .inner { width: 5px; background: #FFF; }")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
389
  with demo:
390
- gr.Markdown("# 💻 ReconViaGen (4 GPU Optimized)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
391
  with gr.Row():
392
  with gr.Column():
393
- input_video = gr.Video(label="Input Video")
394
- image_prompt = gr.Image(label="Image Prompt", visible=False, type="pil")
395
- multiimage_prompt = gr.Gallery(label="Images", columns=3)
396
- with gr.Accordion("Settings"):
397
- seed = gr.Slider(0, MAX_SEED, label="Seed", value=0)
 
 
 
 
 
 
398
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)
399
- ss_gs = gr.Slider(0.0, 10.0, label="SS Guidance", value=7.5)
400
- ss_steps = gr.Slider(1, 50, label="SS Steps", value=30)
401
- slat_gs = gr.Slider(0.0, 10.0, label="Slat Guidance", value=3.0)
402
- slat_steps = gr.Slider(1, 50, label="Slat Steps", value=12)
403
- multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], value="multidiffusion")
404
- refine = gr.Radio(["Yes", "No"], label="Refine", value="Yes")
405
- ss_refine = gr.Radio(["noise", "deltav", "No"], value="No")
406
- reg_frames = gr.Slider(20, 50, value=30)
407
- st1_lr = gr.Slider(1e-4, 1., value=1e-1); st1_t = gr.Slider(0., 1., value=0.5)
408
- st2_lr = gr.Slider(1e-4, 1., value=1e-1); st2_t = gr.Slider(0., 1., value=0.5)
409
- mesh_simplify = gr.Slider(0.9, 0.98, value=0.95)
410
- texture_size = gr.Slider(512, 2048, value=1024)
411
- generate_btn = gr.Button("Generate", variant="primary")
 
 
412
  extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
 
 
 
 
413
  with gr.Column():
414
- video_output = gr.Video(label="Preview")
415
- model_output = LitModel3D(label="3D Model")
416
- download_glb = gr.DownloadButton("Download GLB", interactive=False)
417
- download_gs = gr.DownloadButton("Download GS", interactive=False)
 
 
418
 
419
  output_buf = gr.State()
420
-
421
- input_video.upload(preprocess_videos, inputs=[input_video], outputs=[multiimage_prompt])
422
- multiimage_prompt.upload(preprocess_images, inputs=[multiimage_prompt], outputs=[multiimage_prompt])
423
-
424
- generate_btn.click(get_seed, inputs=[randomize_seed, seed], outputs=[seed]).then(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
425
  generate_and_extract_glb,
426
- inputs=[multiimage_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo, mesh_simplify, texture_size, refine, ss_refine, reg_frames, trellis_stage1_lr, trellis_stage1_start_t, trellis_stage2_lr, trellis_stage2_start_t],
427
- outputs=[output_buf, video_output, model_output, download_glb]
428
- ).then(lambda: (gr.Button(interactive=True), gr.Button(interactive=True)), outputs=[extract_gs_btn, download_glb])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
429
 
430
- extract_gs_btn.click(extract_gaussian, inputs=[output_buf], outputs=[model_output, download_gs])
431
 
432
- # [3] 메인 실행부 (4 GPU VRAM 분산 배치 + 에러 방지)
433
  if __name__ == "__main__":
434
- pipeline = TrellisVGGTTo3DPipeline.from_pretrained("Stable-X/trellis-vggt-v0-2")
435
- num_gpus = torch.cuda.device_count()
436
- print(f"--- Detected {num_gpus} GPUs ---")
437
-
438
- if num_gpus >= 4:
439
- print("--- 4 GPU VRAM Sharding Activated ---")
440
- # 메인 파이프라인은 0번
441
- pipeline.to("cuda:0")
442
-
443
- # 모델 이동 래핑 (자동 데이터 이동 지원)
444
- if hasattr(pipeline, 'VGGT_model'):
445
- pipeline.VGGT_model = ModelParallelWrapper(pipeline.VGGT_model, "cuda:1")
446
-
447
- # Birefnet은 입력과 출력이 CPU/GPU0을 오가므로 래핑 필수
448
- if hasattr(pipeline, 'birefnet_model'):
449
- pipeline.birefnet_model = ModelParallelWrapper(pipeline.birefnet_model, "cuda:2")
450
 
451
- # 디코더 분산
452
- if hasattr(pipeline, 'slat_decoder'):
453
- pipeline.slat_decoder = ModelParallelWrapper(pipeline.slat_decoder, "cuda:3")
454
-
455
- if hasattr(pipeline, 'sparse_structure_decoder'):
456
- pipeline.sparse_structure_decoder = ModelParallelWrapper(pipeline.sparse_structure_decoder, "cuda:3")
457
-
458
- # Mast3r (Refine용)
459
- mast3r_model = AsymmetricMASt3R.from_pretrained("naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric").to("cuda:0").eval()
460
  else:
461
- pipeline.cuda()
462
- mast3r_model = AsymmetricMASt3R.from_pretrained("naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric").cuda().eval()
463
 
464
  demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
2
  from gradio_litmodel3d import LitModel3D
3
 
4
+ import os
5
+ import shutil
6
  os.environ['SPCONV_ALGO'] = 'native'
7
  from typing import *
8
  import torch
9
+ import torch.nn as nn # nn 모듈 추가
10
+ import numpy as np
11
  import imageio
 
12
  from easydict import EasyDict as edict
13
  from PIL import Image
 
 
 
 
14
  from trellis.pipelines import TrellisVGGTTo3DPipeline
15
  from trellis.representations import Gaussian, MeshExtractResult
16
  from trellis.utils import render_utils, postprocessing_utils
 
17
 
18
+ # --- [Fix] DataParallel 속성 접근 오류 해결을 위한 래퍼 클래스 ---
19
+ class CustomDataParallel(nn.DataParallel):
20
+ def __getattr__(self, name):
21
+ try:
22
+ return super().__getattr__(name)
23
+ except AttributeError:
24
+ return getattr(self.module, name)
25
+ # -----------------------------------------------------------
26
 
27
  MAX_SEED = np.iinfo(np.int32).max
28
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
29
  os.makedirs(TMP_DIR, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  def start_session(req: gr.Request):
32
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
33
  os.makedirs(user_dir, exist_ok=True)
34
+
35
+
36
  def end_session(req: gr.Request):
37
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
38
  if os.path.exists(user_dir):
39
  shutil.rmtree(user_dir)
40
 
41
+ def preprocess_image(image: Image.Image) -> Image.Image:
42
+ """
43
+ Preprocess the input image for 3D generation.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
+ This function is called when a user uploads an image or selects an example.
46
+ It applies background removal and other preprocessing steps necessary for
47
+ optimal 3D model generation.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
+ Args:
50
+ image (Image.Image): The input image from the user
51
+
52
+ Returns:
53
+ Image.Image: The preprocessed image ready for 3D generation
54
+ """
55
  processed_image = pipeline.preprocess_image(image)
56
  return processed_image
57
 
58
  def preprocess_videos(video: str) -> List[Tuple[Image.Image, str]]:
59
+ """
60
+ Preprocess the input video for multi-image 3D generation.
61
+
62
+ This function is called when a user uploads a video.
63
+ It extracts frames from the video and processes each frame to prepare them
64
+ for the multi-image 3D generation pipeline.
65
+
66
+ Args:
67
+ video (str): The path to the input video file
68
+
69
+ Returns:
70
+ List[Tuple[Image.Image, str]]: The list of preprocessed images ready for 3D generation
71
+ """
72
  vid = imageio.get_reader(video, 'ffmpeg')
73
  fps = vid.get_meta_data()['fps']
74
  images = []
 
83
  return processed_images
84
 
85
  def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
86
+ """
87
+ Preprocess a list of input images for multi-image 3D generation.
88
+
89
+ This function is called when users upload multiple images in the gallery.
90
+ It processes each image to prepare them for the multi-image 3D generation pipeline.
91
+
92
+ Args:
93
+ images (List[Tuple[Image.Image, str]]): The input images from the gallery
94
+
95
+ Returns:
96
+ List[Image.Image]: The preprocessed images ready for 3D generation
97
+ """
98
  images = [image[0] for image in images]
99
  processed_images = [pipeline.preprocess_image(image) for image in images]
100
  return processed_images
101
 
 
 
102
 
103
+ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
104
+ return {
105
+ 'gaussian': {
106
+ **gs.init_params,
107
+ '_xyz': gs._xyz.cpu().numpy(),
108
+ '_features_dc': gs._features_dc.cpu().numpy(),
109
+ '_scaling': gs._scaling.cpu().numpy(),
110
+ '_rotation': gs._rotation.cpu().numpy(),
111
+ '_opacity': gs._opacity.cpu().numpy(),
112
+ },
113
+ 'mesh': {
114
+ 'vertices': mesh.vertices.cpu().numpy(),
115
+ 'faces': mesh.faces.cpu().numpy(),
116
+ },
117
+ }
118
+
119
+
120
+ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
121
+ gs = Gaussian(
122
+ aabb=state['gaussian']['aabb'],
123
+ sh_degree=state['gaussian']['sh_degree'],
124
+ mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
125
+ scaling_bias=state['gaussian']['scaling_bias'],
126
+ opacity_bias=state['gaussian']['opacity_bias'],
127
+ scaling_activation=state['gaussian']['scaling_activation'],
128
+ )
129
+ gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
130
+ gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
131
+ gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
132
+ gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
133
+ gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
134
+
135
+ mesh = edict(
136
+ vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
137
+ faces=torch.tensor(state['mesh']['faces'], device='cuda'),
138
+ )
139
+
140
  return gs, mesh
141
 
142
+
143
  def get_seed(randomize_seed: bool, seed: int) -> int:
144
+ """
145
+ Get the random seed for generation.
146
+
147
+ This function is called by the generate button to determine whether to use
148
+ a random seed or the user-specified seed value.
149
+
150
+ Args:
151
+ randomize_seed (bool): Whether to generate a random seed
152
+ seed (int): The user-specified seed value
153
+
154
+ Returns:
155
+ int: The seed to use for generation
156
+ """
157
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
158
 
159
+
160
  def generate_and_extract_glb(
161
+ multiimages: List[Tuple[Image.Image, str]],
162
+ seed: int,
163
+ ss_guidance_strength: float,
164
+ ss_sampling_steps: int,
165
+ slat_guidance_strength: float,
166
+ slat_sampling_steps: int,
167
+ multiimage_algo: Literal["multidiffusion", "stochastic"],
168
+ mesh_simplify: float,
169
+ texture_size: int,
170
+ req: gr.Request,
171
+ ) -> Tuple[dict, str, str, str]:
172
+ """
173
+ Convert an image to a 3D model and extract GLB file.
174
+
175
+ Args:
176
+ image (Image.Image): The input image.
177
+ multiimages (List[Tuple[Image.Image, str]]): The input images in multi-image mode.
178
+ is_multiimage (bool): Whether is in multi-image mode.
179
+ seed (int): The random seed.
180
+ ss_guidance_strength (float): The guidance strength for sparse structure generation.
181
+ ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
182
+ slat_guidance_strength (float): The guidance strength for structured latent generation.
183
+ slat_sampling_steps (int): The number of sampling steps for structured latent generation.
184
+ multiimage_algo (Literal["multidiffusion", "stochastic"]): The algorithm for multi-image generation.
185
+ mesh_simplify (float): The mesh simplification factor.
186
+ texture_size (int): The texture resolution.
187
+
188
+ Returns:
189
+ dict: The information of the generated 3D model.
190
+ str: The path to the video of the 3D model.
191
+ str: The path to the extracted GLB file.
192
+ str: The path to the extracted GLB file (for download).
193
+ """
194
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
195
  image_files = [image[0] for image in multiimages]
196
 
197
+ # Generate 3D model
198
+ outputs, _, _ = pipeline.run(
199
+ image=image_files,
200
+ seed=seed,
201
+ formats=["gaussian", "mesh"],
202
+ preprocess_image=False,
203
+ sparse_structure_sampler_params={
204
+ "steps": ss_sampling_steps,
205
+ "cfg_strength": ss_guidance_strength,
206
+ },
207
+ slat_sampler_params={
208
+ "steps": slat_sampling_steps,
209
+ "cfg_strength": slat_guidance_strength,
210
+ },
211
  mode=multiimage_algo,
212
  )
213
 
214
+ # Render video
215
+ # import uuid
216
+ # output_id = str(uuid.uuid4())
217
+ # os.makedirs(f"{TMP_DIR}/{output_id}", exist_ok=True)
218
+ # video_path = f"{TMP_DIR}/{output_id}/preview.mp4"
219
+ # glb_path = f"{TMP_DIR}/{output_id}/mesh.glb"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
 
 
221
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
222
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
223
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
224
  video_path = os.path.join(user_dir, 'sample.mp4')
225
  imageio.mimsave(video_path, video, fps=15)
226
 
227
+ # Extract GLB
228
+ gs = outputs['gaussian'][0]
229
+ mesh = outputs['mesh'][0]
230
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
231
  glb_path = os.path.join(user_dir, 'sample.glb')
232
  glb.export(glb_path)
233
 
234
+ # Pack state for optional Gaussian extraction
235
  state = pack_state(gs, mesh)
236
+
237
  torch.cuda.empty_cache()
238
  return state, video_path, glb_path, glb_path
239
 
240
+
241
  def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
242
+ """
243
+ Extract a Gaussian splatting file from the generated 3D model.
244
+
245
+ This function is called when the user clicks "Extract Gaussian" button.
246
+ It converts the 3D model state into a .ply file format containing
247
+ Gaussian splatting data for advanced 3D applications.
248
+
249
+ Args:
250
+ state (dict): The state of the generated 3D model containing Gaussian data
251
+ req (gr.Request): Gradio request object for session management
252
+
253
+ Returns:
254
+ Tuple[str, str]: Paths to the extracted Gaussian file (for display and download)
255
+ """
256
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
257
  gs, _ = unpack_state(state)
258
  gaussian_path = os.path.join(user_dir, 'sample.ply')
259
  gs.save_ply(gaussian_path)
260
+ torch.cuda.empty_cache()
261
  return gaussian_path, gaussian_path
262
 
263
+
264
+ def prepare_multi_example() -> List[Image.Image]:
265
+ # assets 경로 체크 추가 (에러 방지용)
266
+ if not os.path.exists("assets/example_multi_image"):
267
+ return []
268
+ multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
269
+ images = []
270
+ for case in multi_case:
271
+ _images = []
272
+ for i in range(1, 9):
273
+ if os.path.exists(f'assets/example_multi_image/{case}_{i}.png'):
274
+ img = Image.open(f'assets/example_multi_image/{case}_{i}.png')
275
+ W, H = img.size
276
+ img = img.resize((int(W / H * 512), 512))
277
+ _images.append(np.array(img))
278
+ if len(_images) > 0:
279
+ images.append(Image.fromarray(np.concatenate(_images, axis=1)))
280
+ return images
281
+
282
+
283
+ def split_image(image: Image.Image) -> List[Image.Image]:
284
+ """
285
+ Split a multi-view image into separate view images.
286
+
287
+ This function is called when users select multi-image examples that contain
288
+ multiple views in a single concatenated image. It automatically splits them
289
+ based on alpha channel boundaries and preprocesses each view.
290
+
291
+ Args:
292
+ image (Image.Image): A concatenated image containing multiple views
293
+
294
+ Returns:
295
+ List[Image.Image]: List of individual preprocessed view images
296
+ """
297
+ image = np.array(image)
298
+ alpha = image[..., 3]
299
+ alpha = np.any(alpha>0, axis=0)
300
+ start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist()
301
+ end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist()
302
+ images = []
303
+ for s, e in zip(start_pos, end_pos):
304
+ images.append(Image.fromarray(image[:, s:e+1]))
305
+ return [preprocess_image(image) for image in images]
306
+
307
+ # Create interface
308
+ demo = gr.Blocks(
309
+ title="ReconViaGen",
310
+ css="""
311
+ .slider .inner { width: 5px; background: #FFF; }
312
+ .viewport { aspect-ratio: 4/3; }
313
+ .tabs button.selected { font-size: 20px !important; color: crimson !important; }
314
+ h1, h2, h3 { text-align: center; display: block; }
315
+ .md_feedback li { margin-bottom: 0px !important; }
316
+ """
317
+ )
318
  with demo:
319
+ gr.Markdown("""
320
+ # 💻 ReconViaGen
321
+ <p align="center">
322
+ <a title="Github" href="https://github.com/GAP-LAB-CUHK-SZ/ReconViaGen" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
323
+ <img src="https://img.shields.io/github/stars/GAP-LAB-CUHK-SZ/ReconViaGen?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
324
+ </a>
325
+ <a title="Website" href="https://jiahao620.github.io/reconviagen/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
326
+ <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
327
+ </a>
328
+ <a title="arXiv" href="https://jiahao620.github.io/reconviagen/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
329
+ <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
330
+ </a>
331
+ </p>
332
+
333
+ ✨This demo is partial. We will release the whole model later. Stay tuned!✨
334
+ """)
335
+
336
  with gr.Row():
337
  with gr.Column():
338
+ with gr.Tabs() as input_tabs:
339
+ with gr.Tab(label="Input Video or Images", id=0) as multiimage_input_tab:
340
+ input_video = gr.Video(label="Upload Video", interactive=True, height=300)
341
+ image_prompt = gr.Image(label="Image Prompt", format="png", visible=False, image_mode="RGBA", type="pil", height=300)
342
+ multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
343
+ gr.Markdown("""
344
+ Input different views of the object in separate images.
345
+ """)
346
+
347
+ with gr.Accordion(label="Generation Settings", open=False):
348
+ seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
349
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)
350
+ gr.Markdown("Stage 1: Sparse Structure Generation")
351
+ with gr.Row():
352
+ ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
353
+ ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=30, step=1)
354
+ gr.Markdown("Stage 2: Structured Latent Generation")
355
+ with gr.Row():
356
+ slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
357
+ slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
358
+ multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="multidiffusion")
359
+
360
+ with gr.Accordion(label="GLB Extraction Settings", open=False):
361
+ mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
362
+ texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
363
+
364
+ generate_btn = gr.Button("Generate & Extract GLB", variant="primary")
365
  extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
366
+ gr.Markdown("""
367
+ *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
368
+ """)
369
+
370
  with gr.Column():
371
+ video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
372
+ model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
373
+
374
+ with gr.Row():
375
+ download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
376
+ download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
377
 
378
  output_buf = gr.State()
379
+
380
+ # Example images at the bottom of the page
381
+ with gr.Row() as multiimage_example:
382
+ examples_multi = gr.Examples(
383
+ examples=prepare_multi_example(),
384
+ inputs=[image_prompt],
385
+ fn=split_image,
386
+ outputs=[multiimage_prompt],
387
+ run_on_click=True,
388
+ examples_per_page=8,
389
+ )
390
+
391
+ # Handlers
392
+ demo.load(start_session)
393
+ demo.unload(end_session)
394
+
395
+ input_video.upload(
396
+ preprocess_videos,
397
+ inputs=[input_video],
398
+ outputs=[multiimage_prompt],
399
+ )
400
+ input_video.clear(
401
+ lambda: tuple([None, None]),
402
+ outputs=[input_video, multiimage_prompt],
403
+ )
404
+ multiimage_prompt.upload(
405
+ preprocess_images,
406
+ inputs=[multiimage_prompt],
407
+ outputs=[multiimage_prompt],
408
+ )
409
+
410
+ generate_btn.click(
411
+ get_seed,
412
+ inputs=[randomize_seed, seed],
413
+ outputs=[seed],
414
+ ).then(
415
  generate_and_extract_glb,
416
+ inputs=[multiimage_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo, mesh_simplify, texture_size],
417
+ outputs=[output_buf, video_output, model_output, download_glb],
418
+ ).then(
419
+ lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
420
+ outputs=[extract_gs_btn, download_glb],
421
+ )
422
+
423
+ video_output.clear(
424
+ lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False), gr.Button(interactive=False)]),
425
+ outputs=[extract_gs_btn, download_glb, download_gs],
426
+ )
427
+
428
+ extract_gs_btn.click(
429
+ extract_gaussian,
430
+ inputs=[output_buf],
431
+ outputs=[model_output, download_gs],
432
+ ).then(
433
+ lambda: gr.Button(interactive=True),
434
+ outputs=[download_gs],
435
+ )
436
+
437
+ model_output.clear(
438
+ lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
439
+ outputs=[download_glb, download_gs],
440
+ )
441
 
 
442
 
443
+ # Launch the Gradio app
444
  if __name__ == "__main__":
445
+ print("Initializing Pipeline...")
446
+ pipeline = TrellisVGGTTo3DPipeline.from_pretrained("esther11/trellis-vggt-v0-2")
447
+ pipeline.cuda()
448
+ pipeline.VGGT_model.cuda()
449
+ pipeline.birefnet_model.cuda()
450
+
451
+ # --- [Fix] Multi-GPU Logic ---
452
+ if torch.cuda.device_count() > 1:
453
+ print(f"⚡ Multi-GPU Detected: {torch.cuda.device_count()} GPUs found.")
454
+ print("Wrapping VGGT_model with CustomDataParallel to handle attributes correctly.")
 
 
 
 
 
 
455
 
456
+ # VGGT_model을 CustomDataParallel로 감싸서 분산 처리
457
+ # CustomDataParallel은 'aggregator' 같은 속성을 module 내부에서 찾아줌
458
+ pipeline.VGGT_model = CustomDataParallel(pipeline.VGGT_model)
 
 
 
 
 
 
459
  else:
460
+ print(f"Running on Single GPU: {torch.cuda.get_device_name(0)}")
461
+ # -----------------------------
462
 
463
  demo.launch()