notenoughram commited on
Commit
9faeb08
·
verified ·
1 Parent(s): ec09383

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +611 -53
app.py CHANGED
@@ -1,7 +1,10 @@
1
  import gradio as gr
 
 
 
 
2
  import os
3
  import shutil
4
-
5
  os.environ['SPCONV_ALGO'] = 'native'
6
  from typing import *
7
  import torch
@@ -12,26 +15,54 @@ from PIL import Image
12
  from trellis.pipelines import TrellisVGGTTo3DPipeline
13
  from trellis.representations import Gaussian, MeshExtractResult
14
  from trellis.utils import render_utils, postprocessing_utils
15
- from gradio_litmodel3d import LitModel3D
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  MAX_SEED = np.iinfo(np.int32).max
18
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
 
 
19
  os.makedirs(TMP_DIR, exist_ok=True)
 
20
 
21
  def start_session(req: gr.Request):
22
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
23
  os.makedirs(user_dir, exist_ok=True)
24
-
 
25
  def end_session(req: gr.Request):
26
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
 
27
  if os.path.exists(user_dir):
28
  shutil.rmtree(user_dir)
29
 
 
30
  def preprocess_image(image: Image.Image) -> Image.Image:
 
 
 
31
  processed_image = pipeline.preprocess_image(image)
32
  return processed_image
33
 
 
34
  def preprocess_videos(video: str) -> List[Tuple[Image.Image, str]]:
 
 
 
35
  vid = imageio.get_reader(video, 'ffmpeg')
36
  fps = vid.get_meta_data()['fps']
37
  images = []
@@ -45,11 +76,16 @@ def preprocess_videos(video: str) -> List[Tuple[Image.Image, str]]:
45
  processed_images = [pipeline.preprocess_image(image) for image in images]
46
  return processed_images
47
 
 
48
  def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
 
 
 
49
  images = [image[0] for image in images]
50
  processed_images = [pipeline.preprocess_image(image) for image in images]
51
  return processed_images
52
 
 
53
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
54
  return {
55
  'gaussian': {
@@ -65,7 +101,8 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
65
  'faces': mesh.faces.cpu().numpy(),
66
  },
67
  }
68
-
 
69
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
70
  gs = Gaussian(
71
  aabb=state['gaussian']['aabb'],
@@ -75,7 +112,7 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
75
  opacity_bias=state['gaussian']['opacity_bias'],
76
  scaling_activation=state['gaussian']['scaling_activation'],
77
  )
78
- # 추론 메인 장치인 cuda:0으로 데이터를 보냅니다.
79
  gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda:0')
80
  gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda:0')
81
  gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda:0')
@@ -86,11 +123,218 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
86
  vertices=torch.tensor(state['mesh']['vertices'], device='cuda:0'),
87
  faces=torch.tensor(state['mesh']['faces'], device='cuda:0'),
88
  )
 
89
  return gs, mesh
90
 
 
91
  def get_seed(randomize_seed: bool, seed: int) -> int:
 
 
 
92
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  def generate_and_extract_glb(
95
  multiimages: List[Tuple[Image.Image, str]],
96
  seed: int,
@@ -101,12 +345,23 @@ def generate_and_extract_glb(
101
  multiimage_algo: Literal["multidiffusion", "stochastic"],
102
  mesh_simplify: float,
103
  texture_size: int,
 
 
 
 
 
 
 
104
  req: gr.Request,
105
  ) -> Tuple[dict, str, str, str]:
 
 
 
106
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
107
  image_files = [image[0] for image in multiimages]
108
 
109
- outputs, _, _ = pipeline.run(
 
110
  image=image_files,
111
  seed=seed,
112
  formats=["gaussian", "mesh"],
@@ -121,24 +376,216 @@ def generate_and_extract_glb(
121
  },
122
  mode=multiimage_algo,
123
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
126
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
127
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
128
  video_path = os.path.join(user_dir, 'sample.mp4')
129
  imageio.mimsave(video_path, video, fps=15)
130
 
 
131
  gs = outputs['gaussian'][0]
132
  mesh = outputs['mesh'][0]
133
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
134
  glb_path = os.path.join(user_dir, 'sample.glb')
135
  glb.export(glb_path)
136
 
 
137
  state = pack_state(gs, mesh)
 
138
  torch.cuda.empty_cache()
139
  return state, video_path, glb_path, glb_path
140
 
 
 
141
  def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
 
 
 
142
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
143
  gs, _ = unpack_state(state)
144
  gaussian_path = os.path.join(user_dir, 'sample.ply')
@@ -146,16 +593,15 @@ def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
146
  torch.cuda.empty_cache()
147
  return gaussian_path, gaussian_path
148
 
 
149
  def prepare_multi_example() -> List[Image.Image]:
150
- if not os.path.exists("assets/example_multi_image"): return []
151
  multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
152
  images = []
153
  for case in multi_case:
154
  _images = []
155
  for i in range(1, 9):
156
- path = f'assets/example_multi_image/{case}_{i}.png'
157
- if os.path.exists(path):
158
- img = Image.open(path)
159
  W, H = img.size
160
  img = img.resize((int(W / H * 512), 512))
161
  _images.append(np.array(img))
@@ -163,83 +609,195 @@ def prepare_multi_example() -> List[Image.Image]:
163
  images.append(Image.fromarray(np.concatenate(_images, axis=1)))
164
  return images
165
 
 
166
  def split_image(image: Image.Image) -> List[Image.Image]:
167
- image = np.array(image)
168
- alpha = image[..., 3]
 
 
 
 
 
 
 
169
  alpha = np.any(alpha>0, axis=0)
170
  start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist()
171
  end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist()
172
  images = []
173
  for s, e in zip(start_pos, end_pos):
174
- images.append(Image.fromarray(image[:, s:e+1]))
175
  return [preprocess_image(image) for image in images]
176
 
177
- # --- Gradio UI ---
178
  demo = gr.Blocks(
179
  title="ReconViaGen",
180
- css=".slider .inner { width: 5px; background: #FFF; } .viewport { aspect-ratio: 4/3; }"
 
 
 
 
 
 
181
  )
182
-
183
  with demo:
184
- gr.Markdown("# 💻 ReconViaGen\n✨This demo is partial. Stay tuned!✨")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  with gr.Row():
186
  with gr.Column():
187
  with gr.Tabs() as input_tabs:
188
- with gr.Tab(label="Input Video or Images", id=0):
189
  input_video = gr.Video(label="Upload Video", interactive=True, height=300)
190
- image_prompt = gr.Image(label="Image Prompt", visible=False, type="pil", height=300)
191
- multiimage_prompt = gr.Gallery(label="Image Prompt", columns=3)
192
- with gr.Accordion(label="Settings", open=False):
 
 
 
 
193
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
194
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)
195
- ss_guidance_strength = gr.Slider(0.0, 10.0, label="SS Guidance", value=7.5)
196
- ss_sampling_steps = gr.Slider(1, 50, label="SS Steps", value=30)
197
- slat_guidance_strength = gr.Slider(0.0, 10.0, label="Slat Guidance", value=3.0)
198
- slat_sampling_steps = gr.Slider(1, 50, label="Slat Steps", value=12)
199
- multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], value="multidiffusion")
200
- mesh_simplify = gr.Slider(0.9, 0.98, value=0.95)
201
- texture_size = gr.Slider(512, 2048, value=1024, step=512)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  generate_btn = gr.Button("Generate & Extract GLB", variant="primary")
203
  extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
 
 
 
 
204
  with gr.Column():
205
- video_output = gr.Video(label="Generated 3D Asset")
206
- model_output = LitModel3D(label="Extracted GLB/Gaussian")
 
207
  with gr.Row():
208
- download_glb = gr.DownloadButton("Download GLB", interactive=False)
209
- download_gs = gr.DownloadButton("Download Gaussian", interactive=False)
210
- output_buf = gr.State()
211
 
212
- with gr.Row():
213
- gr.Examples(examples=prepare_multi_example(), inputs=[image_prompt], fn=split_image, outputs=[multiimage_prompt], run_on_click=True)
 
 
 
 
 
 
 
 
 
 
214
 
 
215
  demo.load(start_session)
216
  demo.unload(end_session)
217
- input_video.upload(preprocess_videos, inputs=[input_video], outputs=[multiimage_prompt])
218
- input_video.clear(lambda: (None, None), outputs=[input_video, multiimage_prompt])
219
- multiimage_prompt.upload(preprocess_images, inputs=[multiimage_prompt], outputs=[multiimage_prompt])
220
- generate_btn.click(get_seed, inputs=[randomize_seed, seed], outputs=[seed]).then(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  generate_and_extract_glb,
222
- inputs=[multiimage_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo, mesh_simplify, texture_size],
 
 
 
 
223
  outputs=[output_buf, video_output, model_output, download_glb],
224
- ).then(lambda: (gr.Button(interactive=True), gr.Button(interactive=True)), outputs=[extract_gs_btn, download_glb])
225
- extract_gs_btn.click(extract_gaussian, inputs=[output_buf], outputs=[model_output, download_gs])
 
 
226
 
227
- # --- 메인 실행부 (VRAM 분산 최적화) ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  if __name__ == "__main__":
229
- # 모델 로드
230
  pipeline = TrellisVGGTTo3DPipeline.from_pretrained("Stable-X/trellis-vggt-v0-2")
231
 
232
  num_gpus = torch.cuda.device_count()
233
  if num_gpus >= 4:
234
- print(f"--- 4 GPUs Detected: Splitting Models to prevent VRAM Error ---")
235
- # 모델의 각 부분을 서로 다른 GPU 메모리에 적재하여 1개 GPU의 부담을 줄임
236
- pipeline.to("cuda:0")
237
- if hasattr(pipeline, 'VGGT_model'): pipeline.VGGT_model.to("cuda:1")
238
- if hasattr(pipeline, 'birefnet_model'): pipeline.birefnet_model.to("cuda:2")
239
- # 가장 무거운 디코더들은 3번 GPU로 격리
240
- if hasattr(pipeline, 'slat_decoder'): pipeline.slat_decoder.to("cuda:3")
241
- if hasattr(pipeline, 'sparse_structure_decoder'): pipeline.sparse_structure_decoder.to("cuda:3")
 
 
 
 
 
 
 
242
  else:
243
  pipeline.cuda()
244
-
 
245
  demo.launch()
 
1
  import gradio as gr
2
+ import spaces
3
+ # 유료 환경에서는 spaces import가 있어도 @spaces.GPU 데코레이터만 안 쓰면 됩니다.
4
+ from gradio_litmodel3d import LitModel3D
5
+
6
  import os
7
  import shutil
 
8
  os.environ['SPCONV_ALGO'] = 'native'
9
  from typing import *
10
  import torch
 
15
  from trellis.pipelines import TrellisVGGTTo3DPipeline
16
  from trellis.representations import Gaussian, MeshExtractResult
17
  from trellis.utils import render_utils, postprocessing_utils
18
+
19
+ from wheels.vggt.vggt.utils.load_fn import load_and_preprocess_images
20
+ from wheels.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri
21
+ import open3d as o3d
22
+ from torchvision import transforms as TF
23
+ from PIL import Image
24
+ import sys
25
+ # sys.path.append("wheels") # 필요시 경로 수정
26
+ import cv2 # cv2가 누락되어 있을 수 있어 추가했습니다.
27
+
28
+ from wheels.mast3r.model import AsymmetricMASt3R
29
+ from wheels.mast3r.fast_nn import fast_reciprocal_NNs
30
+ from wheels.dust3r.dust3r.inference import inference
31
+ from wheels.dust3r.dust3r.utils.image import load_images_new
32
+ from trellis.utils.general_utils import *
33
+ import copy
34
 
35
  MAX_SEED = np.iinfo(np.int32).max
36
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
37
+ # TMP_DIR = "tmp/Trellis-demo"
38
+ # os.environ['GRADIO_TEMP_DIR'] = 'tmp'
39
  os.makedirs(TMP_DIR, exist_ok=True)
40
+ device = "cuda" if torch.cuda.is_available() else "cpu"
41
 
42
  def start_session(req: gr.Request):
43
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
44
  os.makedirs(user_dir, exist_ok=True)
45
+
46
+
47
  def end_session(req: gr.Request):
48
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
49
+ # [수정] 폴더가 없으면 삭제하지 않도록 예외 처리 추가
50
  if os.path.exists(user_dir):
51
  shutil.rmtree(user_dir)
52
 
53
+ # [수정] 유료 4 GPU 사용 시 충돌 방지를 위해 @spaces.GPU 제거
54
  def preprocess_image(image: Image.Image) -> Image.Image:
55
+ """
56
+ Preprocess the input image for 3D generation.
57
+ """
58
  processed_image = pipeline.preprocess_image(image)
59
  return processed_image
60
 
61
+ # [수정] @spaces.GPU 제거
62
  def preprocess_videos(video: str) -> List[Tuple[Image.Image, str]]:
63
+ """
64
+ Preprocess the input video for multi-image 3D generation.
65
+ """
66
  vid = imageio.get_reader(video, 'ffmpeg')
67
  fps = vid.get_meta_data()['fps']
68
  images = []
 
76
  processed_images = [pipeline.preprocess_image(image) for image in images]
77
  return processed_images
78
 
79
+ # [수정] @spaces.GPU 제거
80
  def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
81
+ """
82
+ Preprocess a list of input images for multi-image 3D generation.
83
+ """
84
  images = [image[0] for image in images]
85
  processed_images = [pipeline.preprocess_image(image) for image in images]
86
  return processed_images
87
 
88
+
89
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
90
  return {
91
  'gaussian': {
 
101
  'faces': mesh.faces.cpu().numpy(),
102
  },
103
  }
104
+
105
+
106
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
107
  gs = Gaussian(
108
  aabb=state['gaussian']['aabb'],
 
112
  opacity_bias=state['gaussian']['opacity_bias'],
113
  scaling_activation=state['gaussian']['scaling_activation'],
114
  )
115
+ # [수정] 데이터를 로드할 때 메인 GPU(cuda:0)으로 보냄
116
  gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda:0')
117
  gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda:0')
118
  gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda:0')
 
123
  vertices=torch.tensor(state['mesh']['vertices'], device='cuda:0'),
124
  faces=torch.tensor(state['mesh']['faces'], device='cuda:0'),
125
  )
126
+
127
  return gs, mesh
128
 
129
+
130
  def get_seed(randomize_seed: bool, seed: int) -> int:
131
+ """
132
+ Get the random seed for generation.
133
+ """
134
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
135
 
136
+ def align_camera(num_frames, extrinsic, intrinsic, rend_extrinsics, rend_intrinsics):
137
+
138
+ extrinsic_tmp = extrinsic.clone()
139
+ camera_relative = torch.matmul(extrinsic_tmp[:num_frames,:3,:3].permute(0,2,1), extrinsic_tmp[num_frames:,:3,:3])
140
+ camera_relative_angle = torch.acos(((camera_relative[:,0,0] + camera_relative[:,1,1] + camera_relative[:,2,2] - 1) / 2).clamp(-1, 1))
141
+ idx = torch.argmin(camera_relative_angle)
142
+ target_extrinsic = rend_extrinsics[idx:idx+1].clone()
143
+
144
+ focal_x = intrinsic[:num_frames,0,0].mean()
145
+ focal_y = intrinsic[:num_frames,1,1].mean()
146
+ focal = (focal_x + focal_y) / 2
147
+ rend_focal = (rend_intrinsics[0][0,0] + rend_intrinsics[0][1,1]) * 518 / 2
148
+ focal_scale = rend_focal / focal
149
+ target_intrinsic = intrinsic[num_frames:].clone()
150
+ fxy = (target_intrinsic[:,0,0] + target_intrinsic[:,1,1]) / 2 * focal_scale
151
+ target_intrinsic[:,0,0] = fxy
152
+ target_intrinsic[:,1,1] = fxy
153
+ return target_extrinsic, target_intrinsic
154
+
155
+ def refine_pose_mast3r(rend_image_pil, target_image_pil, original_size, fxy, target_extrinsic, rend_depth):
156
+ images_mast3r = load_images_new([rend_image_pil, target_image_pil], size=512, square_ok=True)
157
+ with torch.no_grad():
158
+ # [수정] mast3r_model 추론 시 cuda:0 명시 (또는 할당된 device)
159
+ output = inference([tuple(images_mast3r)], mast3r_model, "cuda:0", batch_size=1, verbose=False)
160
+ view1, pred1 = output['view1'], output['pred1']
161
+ view2, pred2 = output['view2'], output['pred2']
162
+ del output
163
+ desc1, desc2 = pred1['desc'].squeeze(0).detach(), pred2['desc'].squeeze(0).detach()
164
+
165
+ # find 2D-2D matches between the two images
166
+ matches_im0, matches_im1 = fast_reciprocal_NNs(desc1, desc2, subsample_or_initxy1=8,
167
+ device="cuda:0", dist='dot', block_size=2**13)
168
+
169
+ # ignore small border around the edge
170
+ H0, W0 = view1['true_shape'][0]
171
+
172
+ valid_matches_im0 = (matches_im0[:, 0] >= 3) & (matches_im0[:, 0] < int(W0) - 3) & (
173
+ matches_im0[:, 1] >= 3) & (matches_im0[:, 1] < int(H0) - 3)
174
+
175
+ H1, W1 = view2['true_shape'][0]
176
+ valid_matches_im1 = (matches_im1[:, 0] >= 3) & (matches_im1[:, 0] < int(W1) - 3) & (
177
+ matches_im1[:, 1] >= 3) & (matches_im1[:, 1] < int(H1) - 3)
178
+
179
+ valid_matches = valid_matches_im0 & valid_matches_im1
180
+ matches_im0, matches_im1 = matches_im0[valid_matches], matches_im1[valid_matches]
181
+ scale_x = original_size[1] / W0.item()
182
+ scale_y = original_size[0] / H0.item()
183
+ for pixel in matches_im1:
184
+ pixel[0] *= scale_x
185
+ pixel[1] *= scale_y
186
+ for pixel in matches_im0:
187
+ pixel[0] *= scale_x
188
+ pixel[1] *= scale_y
189
+ depth_map = rend_depth[0]
190
+ fx, fy, cx, cy = fxy.item(), fxy.item(), original_size[1]/2, original_size[0]/2 # Example values for focal lengths and principal point
191
+ K = np.array([
192
+ [fx, 0, cx],
193
+ [0, fy, cy],
194
+ [0, 0, 1]
195
+ ])
196
+ dist_eff = np.array([0,0,0,0], dtype=np.float32)
197
+ predict_c2w_ini = np.linalg.inv(target_extrinsic[0].cpu().numpy())
198
+ predict_w2c_ini = target_extrinsic[0].cpu().numpy()
199
+ initial_rvec, _ = cv2.Rodrigues(predict_c2w_ini[:3,:3].astype(np.float32))
200
+ initial_tvec = predict_c2w_ini[:3,3].astype(np.float32)
201
+ K_inv = np.linalg.inv(K)
202
+ height, width = depth_map.shape
203
+ x_coords, y_coords = np.meshgrid(np.arange(width), np.arange(height))
204
+ x_flat = x_coords.flatten()
205
+ y_flat = y_coords.flatten()
206
+ depth_flat = depth_map.flatten()
207
+ x_normalized = (x_flat - K[0, 2]) / K[0, 0]
208
+ y_normalized = (y_flat - K[1, 2]) / K[1, 1]
209
+ X_camera = depth_flat * x_normalized
210
+ Y_camera = depth_flat * y_normalized
211
+ Z_camera = depth_flat
212
+ points_camera = np.vstack((X_camera, Y_camera, Z_camera, np.ones_like(X_camera)))
213
+ points_world = predict_c2w_ini @ points_camera
214
+ X_world = points_world[0, :]
215
+ Y_world = points_world[1, :]
216
+ Z_world = points_world[2, :]
217
+ points_3D = np.vstack((X_world, Y_world, Z_world))
218
+ scene_coordinates_gs = points_3D.reshape(3, original_size[0], original_size[1])
219
+ points_3D_at_pixels = np.zeros((matches_im0.shape[0], 3))
220
+ for i, (x, y) in enumerate(matches_im0):
221
+ points_3D_at_pixels[i] = scene_coordinates_gs[:, y, x]
222
+
223
+ success, rvec, tvec, inliers = cv2.solvePnPRansac(points_3D_at_pixels.astype(np.float32), matches_im1.astype(np.float32), K, \
224
+ dist_eff,rvec=initial_rvec,tvec=initial_tvec, useExtrinsicGuess=True, reprojectionError=1.0,\
225
+ iterationsCount=2000,flags=cv2.SOLVEPNP_EPNP)
226
+ R = perform_rodrigues_transformation(rvec)
227
+ trans = -R.T @ np.matrix(tvec)
228
+ predict_c2w_refine = np.eye(4)
229
+ predict_c2w_refine[:3,:3] = R.T
230
+ predict_c2w_refine[:3,3] = trans.reshape(3)
231
+ target_extrinsic_final = torch.tensor(predict_c2w_refine).inverse().cuda()[None].float()
232
+ return target_extrinsic_final
233
+
234
+ def pointcloud_registration(rend_image_pil, target_image_pil, original_size,
235
+ fxy, target_extrinsic, rend_depth, target_pointmap,
236
+ down_pcd, pcd):
237
+ images_mast3r = load_images_new([rend_image_pil, target_image_pil], size=512, square_ok=True)
238
+ with torch.no_grad():
239
+ output = inference([tuple(images_mast3r)], mast3r_model, "cuda:0", batch_size=1, verbose=False)
240
+ view1, pred1 = output['view1'], output['pred1']
241
+ view2, pred2 = output['view2'], output['pred2']
242
+ del output
243
+ desc1, desc2 = pred1['desc'].squeeze(0).detach(), pred2['desc'].squeeze(0).detach()
244
+
245
+ # find 2D-2D matches between the two images
246
+ matches_im0, matches_im1 = fast_reciprocal_NNs(desc1, desc2, subsample_or_initxy1=8,
247
+ device="cuda:0", dist='dot', block_size=2**13)
248
+
249
+ # ignore small border around the edge
250
+ H0, W0 = view1['true_shape'][0]
251
+
252
+ valid_matches_im0 = (matches_im0[:, 0] >= 3) & (matches_im0[:, 0] < int(W0) - 3) & (
253
+ matches_im0[:, 1] >= 3) & (matches_im0[:, 1] < int(H0) - 3)
254
+
255
+ H1, W1 = view2['true_shape'][0]
256
+ valid_matches_im1 = (matches_im1[:, 0] >= 3) & (matches_im1[:, 0] < int(W1) - 3) & (
257
+ matches_im1[:, 1] >= 3) & (matches_im1[:, 1] < int(H1) - 3)
258
+
259
+ valid_matches = valid_matches_im0 & valid_matches_im1
260
+ matches_im0, matches_im1 = matches_im0[valid_matches], matches_im1[valid_matches]
261
+ scale_x = original_size[1] / W0.item()
262
+ scale_y = original_size[0] / H0.item()
263
+ for pixel in matches_im1:
264
+ pixel[0] *= scale_x
265
+ pixel[1] *= scale_y
266
+ for pixel in matches_im0:
267
+ pixel[0] *= scale_x
268
+ pixel[1] *= scale_y
269
+ depth_map = rend_depth[0]
270
+ fx, fy, cx, cy = fxy.item(), fxy.item(), original_size[1]/2, original_size[0]/2 # Example values for focal lengths and principal point
271
+ K = np.array([
272
+ [fx, 0, cx],
273
+ [0, fy, cy],
274
+ [0, 0, 1]
275
+ ])
276
+ dist_eff = np.array([0,0,0,0], dtype=np.float32)
277
+ predict_c2w_ini = np.linalg.inv(target_extrinsic[0].cpu().numpy())
278
+ predict_w2c_ini = target_extrinsic[0].cpu().numpy()
279
+ initial_rvec, _ = cv2.Rodrigues(predict_c2w_ini[:3,:3].astype(np.float32))
280
+ initial_tvec = predict_c2w_ini[:3,3].astype(np.float32)
281
+ K_inv = np.linalg.inv(K)
282
+ height, width = depth_map.shape
283
+ x_coords, y_coords = np.meshgrid(np.arange(width), np.arange(height))
284
+ x_flat = x_coords.flatten()
285
+ y_flat = y_coords.flatten()
286
+ depth_flat = depth_map.flatten()
287
+ x_normalized = (x_flat - K[0, 2]) / K[0, 0]
288
+ y_normalized = (y_flat - K[1, 2]) / K[1, 1]
289
+ X_camera = depth_flat * x_normalized
290
+ Y_camera = depth_flat * y_normalized
291
+ Z_camera = depth_flat
292
+ points_camera = np.vstack((X_camera, Y_camera, Z_camera, np.ones_like(X_camera)))
293
+ points_world = predict_c2w_ini @ points_camera
294
+ X_world = points_world[0, :]
295
+ Y_world = points_world[1, :]
296
+ Z_world = points_world[2, :]
297
+ points_3D = np.vstack((X_world, Y_world, Z_world))
298
+ scene_coordinates_gs = points_3D.reshape(3, original_size[0], original_size[1])
299
+ points_3D_at_pixels = np.zeros((matches_im0.shape[0], 3))
300
+ for i, (x, y) in enumerate(matches_im0):
301
+ points_3D_at_pixels[i] = scene_coordinates_gs[:, y, x]
302
+
303
+ points_3D_at_pixels_2 = np.zeros((matches_im1.shape[0], 3))
304
+ for i, (x, y) in enumerate(matches_im1):
305
+ points_3D_at_pixels_2[i] = target_pointmap[:, y, x]
306
+
307
+ dist_1 = np.linalg.norm(points_3D_at_pixels - points_3D_at_pixels.mean(axis=0), axis=1)
308
+ scale_1 = dist_1[dist_1 < np.percentile(dist_1, 99)].mean()
309
+ dist_2 = np.linalg.norm(points_3D_at_pixels_2 - points_3D_at_pixels_2.mean(axis=0), axis=1)
310
+ scale_2 = dist_2[dist_2 < np.percentile(dist_2, 99)].mean()
311
+ # scale_1 = np.linalg.norm(points_3D_at_pixels - points_3D_at_pixels.mean(axis=0), axis=1).mean()
312
+ # scale_2 = np.linalg.norm(points_3D_at_pixels_2 - points_3D_at_pixels_2.mean(axis=0), axis=1).mean()
313
+ points_3D_at_pixels_2 = points_3D_at_pixels_2 * (scale_1 / scale_2)
314
+ pcd_1 = o3d.geometry.PointCloud()
315
+ pcd_1.points = o3d.utility.Vector3dVector(points_3D_at_pixels)
316
+ pcd_2 = o3d.geometry.PointCloud()
317
+ pcd_2.points = o3d.utility.Vector3dVector(points_3D_at_pixels_2)
318
+ indices = np.arange(points_3D_at_pixels.shape[0])
319
+ correspondences = np.stack([indices, indices], axis=1)
320
+ correspondences = o3d.utility.Vector2iVector(correspondences)
321
+ result = o3d.pipelines.registration.registration_ransac_based_on_correspondence(
322
+ pcd_2,
323
+ pcd_1,
324
+ correspondences,
325
+ 0.03,
326
+ estimation_method=o3d.pipelines.registration.TransformationEstimationPointToPoint(False),
327
+ ransac_n=5,
328
+ criteria=o3d.pipelines.registration.RANSACConvergenceCriteria(10000, 10000),
329
+ )
330
+ transformation_matrix = result.transformation.copy()
331
+ transformation_matrix[:3,:3] = transformation_matrix[:3,:3] * (scale_1 / scale_2)
332
+ evaluation = o3d.pipelines.registration.evaluate_registration(
333
+ down_pcd, pcd, 0.02, transformation_matrix
334
+ )
335
+ return transformation_matrix, evaluation.fitness
336
+
337
+ # [수정] @spaces.GPU 제거
338
  def generate_and_extract_glb(
339
  multiimages: List[Tuple[Image.Image, str]],
340
  seed: int,
 
345
  multiimage_algo: Literal["multidiffusion", "stochastic"],
346
  mesh_simplify: float,
347
  texture_size: int,
348
+ refine: Literal["Yes", "No"],
349
+ ss_refine: Literal["noise", "deltav", "No"],
350
+ registration_num_frames: int,
351
+ trellis_stage1_lr: float,
352
+ trellis_stage1_start_t: float,
353
+ trellis_stage2_lr: float,
354
+ trellis_stage2_start_t: float,
355
  req: gr.Request,
356
  ) -> Tuple[dict, str, str, str]:
357
+ """
358
+ Convert an image to a 3D model and extract GLB file.
359
+ """
360
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
361
  image_files = [image[0] for image in multiimages]
362
 
363
+ # Generate 3D model
364
+ outputs, coords, ss_noise = pipeline.run(
365
  image=image_files,
366
  seed=seed,
367
  formats=["gaussian", "mesh"],
 
376
  },
377
  mode=multiimage_algo,
378
  )
379
+ if refine == "Yes":
380
+ try:
381
+ images, alphas = load_and_preprocess_images(multiimages)
382
+ images, alphas = images.to(device), alphas.to(device)
383
+ with torch.no_grad():
384
+ with torch.cuda.amp.autocast(dtype=pipeline.VGGT_dtype):
385
+ images = images[None]
386
+ # [수정] 분산 배치된 VGGT_model 접근
387
+ vggt = pipeline.VGGT_model if not hasattr(pipeline.VGGT_model, 'module') else pipeline.VGGT_model.module
388
+ aggregated_tokens_list, ps_idx = vggt.aggregator(images)
389
+ # Predict Cameras
390
+ pose_enc = vggt.camera_head(aggregated_tokens_list)[-1]
391
+ # Extrinsic and intrinsic matrices, following OpenCV convention (camera from world)
392
+ extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, images.shape[-2:])
393
+ # Predict Point Cloud
394
+ point_map, point_conf = vggt.point_head(aggregated_tokens_list, images, ps_idx)
395
+ del aggregated_tokens_list
396
+ mask = (alphas[:,0,...][...,None] > 0.8)
397
+ conf_threshold = np.percentile(point_conf.cpu().numpy(), 50)
398
+ confidence_mask = (point_conf[0] > conf_threshold) & (point_conf[0] > 1e-5)
399
+ mask = mask & confidence_mask[...,None]
400
+ point_map_by_unprojection = point_map[0]
401
+ point_map_clean = point_map_by_unprojection[mask[...,0]]
402
+ center_point = point_map_clean.mean(0)
403
+ scale = np.percentile((point_map_clean - center_point[None]).norm(dim=-1).cpu().numpy(), 98)
404
+ outlier_mask = (point_map_by_unprojection - center_point[None]).norm(dim=-1) <= scale
405
+ final_mask = mask & outlier_mask[...,None]
406
+ point_map_perframe = (point_map_by_unprojection - center_point[None, None, None]) / (2 * scale)
407
+ point_map_perframe[~final_mask[...,0]] = 127/255
408
+ point_map_perframe = point_map_perframe.permute(0,3,1,2)
409
+ images = images[0].permute(0,2,3,1)
410
+ images[~(alphas[:,0,...][...,None] > 0.8)[...,0]] = 0.
411
+ input_images = images.permute(0,3,1,2).clone()
412
+ vggt_extrinsic = extrinsic[0]
413
+ vggt_extrinsic = torch.cat([vggt_extrinsic, torch.tensor([[[0,0,0,1]]]).repeat(vggt_extrinsic.shape[0], 1, 1).to(vggt_extrinsic)], dim=1)
414
+ vggt_intrinsic = intrinsic[0]
415
+ vggt_intrinsic[:,:2] = vggt_intrinsic[:,:2] / 518
416
+ vggt_extrinsic[:,:3,3] = (torch.matmul(vggt_extrinsic[:,:3,:3], center_point[None,:,None].float())[...,0] + vggt_extrinsic[:,:3,3]) / (2 * scale)
417
+ pointcloud = point_map_perframe.permute(0,2,3,1)[final_mask[...,0]]
418
+ idxs = torch.randperm(pointcloud.shape[0])[:min(50000, pointcloud.shape[0])]
419
+ pcd = o3d.geometry.PointCloud()
420
+ pcd.points = o3d.utility.Vector3dVector(pointcloud[idxs].cpu().numpy())
421
+ cl, ind = pcd.remove_statistical_outlier(nb_neighbors=30, std_ratio=3.0)
422
+ inlier_cloud = pcd.select_by_index(ind)
423
+ outlier_cloud = pcd.select_by_index(ind, invert=True)
424
+ distance = np.array(inlier_cloud.points) - np.array(inlier_cloud.points).mean(axis=0)[None]
425
+ scale = np.percentile(np.linalg.norm(distance, axis=1), 97)
426
+ voxel_size = 1/64*scale*2
427
+ down_pcd = inlier_cloud.voxel_down_sample(voxel_size)
428
+ torch.cuda.empty_cache()
429
 
430
+ video, rend_extrinsics, rend_intrinsics = render_utils.render_multiview(outputs['gaussian'][0], num_frames=registration_num_frames)
431
+ rend_extrinsics = torch.stack(rend_extrinsics, dim=0)
432
+ rend_intrinsics = torch.stack(rend_intrinsics, dim=0)
433
+ target_extrinsics = []
434
+ target_intrinsics = []
435
+ target_transforms = []
436
+ target_fitnesses = []
437
+ pcd = o3d.geometry.PointCloud()
438
+ mesh = outputs['mesh'][0]
439
+ idxs = torch.randperm(mesh.vertices.shape[0])[:min(50000, mesh.vertices.shape[0])]
440
+ pcd.points = o3d.utility.Vector3dVector(mesh.vertices[idxs].cpu().numpy())
441
+ distance = np.array(pcd.points) - np.array(pcd.points).mean(axis=0)[None]
442
+ scale = np.linalg.norm(distance, axis=1).max()
443
+ voxel_size = 1/64*scale*2
444
+ pcd = pcd.voxel_down_sample(voxel_size)
445
+ # pcd.points = o3d.utility.Vector3dVector((coords[:,1:].cpu().numpy() + 0.5) / 64 - 0.5)
446
+
447
+ for k in range(len(image_files)):
448
+ images = torch.stack([TF.ToTensor()(render_image) for render_image in video['color']] + [TF.ToTensor()(image_files[k].convert("RGB"))], dim=0)
449
+ # if len(images) == 0:
450
+ with torch.no_grad():
451
+ with torch.cuda.amp.autocast(dtype=pipeline.VGGT_dtype):
452
+ # predictions = vggt_model(images.cuda())
453
+ vggt = pipeline.VGGT_model if not hasattr(pipeline.VGGT_model, 'module') else pipeline.VGGT_model.module
454
+ aggregated_tokens_list, ps_idx = vggt.aggregator(images[None].cuda())
455
+ pose_enc = vggt.camera_head(aggregated_tokens_list)[-1]
456
+ extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, images.shape[-2:])
457
+ extrinsic, intrinsic = extrinsic[0], intrinsic[0]
458
+ extrinsic = torch.cat([extrinsic, torch.tensor([0,0,0,1])[None,None].repeat(extrinsic.shape[0], 1, 1).to(extrinsic.device)], dim=1)
459
+ del aggregated_tokens_list, ps_idx
460
+
461
+ target_extrinsic, target_intrinsic = align_camera(registration_num_frames, extrinsic, intrinsic, rend_extrinsics, rend_intrinsics)
462
+ fxy = target_intrinsic[:,0,0]
463
+ target_intrinsic_tmp = target_intrinsic.clone()
464
+ target_intrinsic_tmp[:,:2] = target_intrinsic_tmp[:,:2] / 518
465
+
466
+ target_extrinsic_list = [target_extrinsic]
467
+ iou_list = []
468
+ iterations = 3
469
+ for i in range(iterations + 1):
470
+ j = 0
471
+ rend = render_utils.render_frames(outputs['gaussian'][0], target_extrinsic, target_intrinsic_tmp, {'resolution': 518, 'bg_color': (0, 0, 0)}, need_depth=True)
472
+ rend_image = rend['color'][j] # (518, 518, 3)
473
+ rend_depth = rend['depth'][j] # (3, 518, 518)
474
+
475
+ depth_single = rend_depth[0].astype(np.float32) # (H, W)
476
+ mask = (depth_single != 0).astype(np.uint8) #
477
+ kernel = np.ones((3, 3), np.uint8)
478
+ mask_eroded = cv2.erode(mask, kernel, iterations=3)
479
+ depth_eroded = depth_single * mask_eroded
480
+ rend_depth_eroded = np.stack([depth_eroded]*3, axis=0)
481
+
482
+ rend_image = torch.tensor(rend_image).permute(2,0,1) / 255
483
+ target_image = images[registration_num_frames:].to(target_extrinsic.device)[j]
484
+ original_size = (rend_image.shape[1], rend_image.shape[2])
485
+
486
+ import torchvision
487
+ torchvision.utils.save_image(rend_image, 'rend_image_{}.png'.format(k))
488
+ torchvision.utils.save_image(target_image, 'target_image_{}.png'.format(k))
489
+
490
+ mask_rend = (rend_image.detach().cpu() > 0).any(dim=0)
491
+ mask_target = (target_image.detach().cpu() > 0).any(dim=0)
492
+ intersection = (mask_rend & mask_target).sum().item()
493
+ union = (mask_rend | mask_target).sum().item()
494
+ iou = intersection / union if union > 0 else 0.0
495
+ iou_list.append(iou)
496
+
497
+ if i == iterations:
498
+ break
499
+
500
+ rend_image = rend_image * torch.from_numpy(mask_eroded[None]).to(rend_image.device)
501
+ rend_image_pil = Image.fromarray((rend_image.permute(1,2,0).cpu().numpy() * 255).astype(np.uint8))
502
+ target_image_pil = Image.fromarray((target_image.permute(1,2,0).cpu().numpy() * 255).astype(np.uint8))
503
+ target_extrinsic[j:j+1] = refine_pose_mast3r(rend_image_pil, target_image_pil, original_size, fxy[j:j+1], target_extrinsic[j:j+1], rend_depth_eroded)
504
+ target_extrinsic_list.append(target_extrinsic[j:j+1])
505
+
506
+ idx = iou_list.index(max(iou_list))
507
+ target_extrinsic[j:j+1] = target_extrinsic_list[idx]
508
+ target_transform, fitness = pointcloud_registration(rend_image_pil, target_image_pil, original_size, fxy[j:j+1], target_extrinsic[j:j+1], \
509
+ rend_depth_eroded, point_map_perframe[k].cpu().numpy(), down_pcd, pcd)
510
+ target_transforms.append(target_transform)
511
+ target_fitnesses.append(fitness)
512
+
513
+ target_extrinsics.append(target_extrinsic[j:j+1])
514
+ target_intrinsics.append(target_intrinsic_tmp[j:j+1])
515
+ target_extrinsics = torch.cat(target_extrinsics, dim=0)
516
+ target_intrinsics = torch.cat(target_intrinsics, dim=0)
517
+
518
+ target_fitnesses_filtered = [x for x in target_fitnesses if x <= 1]
519
+ idx = target_fitnesses.index(max(target_fitnesses_filtered))
520
+ target_transform = target_transforms[idx]
521
+ down_pcd_align = copy.deepcopy(down_pcd).transform(target_transform)
522
+ # pcd = o3d.geometry.PointCloud()
523
+ # pcd.points = o3d.utility.Vector3dVector(coords[:,1:].cpu().numpy() / 64 - 0.5)
524
+ reg_p2p = o3d.pipelines.registration.registration_icp(
525
+ down_pcd_align, pcd, 0.02, np.eye(4),
526
+ o3d.pipelines.registration.TransformationEstimationPointToPoint(with_scaling=True),
527
+ o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration = 10000))
528
+ down_pcd_align_2 = copy.deepcopy(down_pcd_align).transform(reg_p2p.transformation)
529
+ input_points = torch.tensor(np.asarray(down_pcd_align_2.points)).to(extrinsic.device).float()
530
+ input_points = ((input_points + 0.5).clip(0, 1) * 64 - 0.5).to(torch.int32)
531
+
532
+ outputs = pipeline.run_refine(
533
+ image=image_files,
534
+ ss_learning_rate=trellis_stage1_lr,
535
+ ss_start_t=trellis_stage1_start_t,
536
+ apperance_learning_rate=trellis_stage2_lr,
537
+ apperance_start_t=trellis_stage2_start_t,
538
+ extrinsics=target_extrinsics,
539
+ intrinsics=target_intrinsics,
540
+ ss_noise=ss_noise,
541
+ input_points=input_points,
542
+ ss_refine_type = ss_refine,
543
+ coords=coords if ss_refine == "No" else None,
544
+ seed=seed,
545
+ formats=["mesh", "gaussian"],
546
+ sparse_structure_sampler_params={
547
+ "steps": ss_sampling_steps,
548
+ "cfg_strength": ss_guidance_strength,
549
+ },
550
+ slat_sampler_params={
551
+ "steps": slat_sampling_steps,
552
+ "cfg_strength": slat_guidance_strength,
553
+ },
554
+ mode=multiimage_algo,
555
+ )
556
+ except Exception as e:
557
+ print(f"Error during refinement: {e}")
558
+ # Render video
559
+ # import uuid
560
+ # output_id = str(uuid.uuid4())
561
+ # os.makedirs(f"{TMP_DIR}/{output_id}", exist_ok=True)
562
+ # video_path = f"{TMP_DIR}/{output_id}/preview.mp4"
563
+ # glb_path = f"{TMP_DIR}/{output_id}/mesh.glb"
564
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
565
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
566
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
567
  video_path = os.path.join(user_dir, 'sample.mp4')
568
  imageio.mimsave(video_path, video, fps=15)
569
 
570
+ # Extract GLB
571
  gs = outputs['gaussian'][0]
572
  mesh = outputs['mesh'][0]
573
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
574
  glb_path = os.path.join(user_dir, 'sample.glb')
575
  glb.export(glb_path)
576
 
577
+ # Pack state for optional Gaussian extraction
578
  state = pack_state(gs, mesh)
579
+
580
  torch.cuda.empty_cache()
581
  return state, video_path, glb_path, glb_path
582
 
583
+
584
+ # [Edit] Removed the @spaces.GPU decorator (ZeroGPU is not used here).
585
def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
    """
    Extract a Gaussian-splatting PLY file from the generated 3D model.

    Args:
        state: Packed model state produced by pack_state() during generation.
        req: Gradio request; its session hash locates the per-session tmp dir.

    Returns:
        (path for the 3D viewer, path for the download button) — both point
        at the same .ply file inside the user's session directory.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    gs, _ = unpack_state(state)
    gaussian_path = os.path.join(user_dir, 'sample.ply')
    # Bug fix: the PLY must actually be written before its path is handed to
    # the viewer / DownloadButton, otherwise they reference a missing file.
    gs.save_ply(gaussian_path)
    torch.cuda.empty_cache()
    return gaussian_path, gaussian_path
595
 
596
+
597
def prepare_multi_example() -> List[Image.Image]:
    """
    Build horizontally-stitched preview strips from the bundled multi-view
    example images in assets/example_multi_image.

    Files are named "<case>_<i>.png" with i in 1..8. Each existing view is
    resized to height 512 (aspect ratio preserved) and the views of one case
    are concatenated side by side into a single preview image.

    Returns:
        One stitched PIL image per example case (cases with no matching
        files are skipped).
    """
    example_dir = "assets/example_multi_image"
    # sorted() makes the example ordering deterministic; a bare set is not.
    multi_case = sorted({name.split('_')[0] for name in os.listdir(example_dir)})
    images = []
    for case in multi_case:
        _images = []
        for i in range(1, 9):
            path = f'{example_dir}/{case}_{i}.png'
            if os.path.exists(path):
                img = Image.open(path)
                W, H = img.size
                img = img.resize((int(W / H * 512), 512))
                _images.append(np.array(img))
        # Guard: np.concatenate([]) raises ValueError, which is reachable
        # here because every per-case file is optional.
        if _images:
            images.append(Image.fromarray(np.concatenate(_images, axis=1)))
    return images
611
 
612
+
613
def split_image(image: Image.Image) -> List[Image.Image]:
    """
    Split a multi-view image into separate view images.

    Views are detected as horizontal runs of columns containing any
    non-transparent pixel (alpha > 0). Images without an alpha channel
    cannot be split and are preprocessed as a single view.

    Args:
        image: RGBA image whose views are separated by fully transparent
            column gaps.

    Returns:
        List of preprocessed per-view images.
    """
    image_np = np.array(image)
    # Stability guard: grayscale arrays are 2-D (shape[-1] would be the
    # width, not a channel count) and RGB has no alpha — treat both as a
    # single, unsplittable view.
    if image_np.ndim < 3 or image_np.shape[-1] < 4:
        return [preprocess_image(image)]

    occupied = np.any(image_np[..., 3] > 0, axis=0)
    # Bug fix: pad with False sentinels so runs that touch the left or
    # right border still produce a start/end transition (the unpadded
    # `~a[:-1] & a[1:]` form silently drops edge-touching views).
    padded = np.concatenate([[False], occupied, [False]])
    starts = np.where(~padded[:-1] & padded[1:])[0]  # first occupied column of each run
    ends = np.where(padded[:-1] & ~padded[1:])[0]    # one past the last occupied column
    images = [Image.fromarray(image_np[:, s:e]) for s, e in zip(starts, ends)]
    return [preprocess_image(img) for img in images]
630
 
631
# Build the Gradio interface. (The source rendering stripped indentation;
# nesting below is reconstructed from the `with` context managers.)
demo = gr.Blocks(
    title="ReconViaGen",
    css="""
    .slider .inner { width: 5px; background: #FFF; }
    .viewport { aspect-ratio: 4/3; }
    .tabs button.selected { font-size: 20px !important; color: crimson !important; }
    h1, h2, h3 { text-align: center; display: block; }
    .md_feedback li { margin-bottom: 0px !important; }
    """
)

with demo:
    gr.Markdown("""
    # 💻 ReconViaGen
    <p align="center">
    <a title="Github" href="https://github.com/GAP-LAB-CUHK-SZ/ReconViaGen" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
    <img src="https://img.shields.io/github/stars/GAP-LAB-CUHK-SZ/ReconViaGen?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
    </a>
    <a title="Website" href="https://jiahao620.github.io/reconviagen/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
    <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
    </a>
    <a title="arXiv" href="https://jiahao620.github.io/reconviagen/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
    <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
    </a>
    </p>

    ✨This demo is partial. We will release the whole model later. Stay tuned!✨
    """)

    with gr.Row():
        # Left column: inputs and generation/extraction settings.
        with gr.Column():
            with gr.Tabs() as input_tabs:
                with gr.Tab(label="Input Video or Images", id=0) as multiimage_input_tab:
                    input_video = gr.Video(label="Upload Video", interactive=True, height=300)
                    # Hidden single-image slot; only used as the Examples input.
                    image_prompt = gr.Image(label="Image Prompt", format="png", visible=False, image_mode="RGBA", type="pil", height=300)
                    multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
                    gr.Markdown("""
                    Input different views of the object in separate images.
                    """)

            with gr.Accordion(label="Generation Settings", open=False):
                seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)
                gr.Markdown("Stage 1: Sparse Structure Generation")
                with gr.Row():
                    ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                    ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=30, step=1)
                gr.Markdown("Stage 2: Structured Latent Generation")
                with gr.Row():
                    slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
                    slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="multidiffusion")
                refine = gr.Radio(["Yes", "No"], label="Refinement of Not", value="Yes")
                ss_refine = gr.Radio(["noise", "deltav", "No"], label="Sparse Structure refinement of not", value="No")
                registration_num_frames = gr.Slider(20, 50, label="Number of frames in registration", value=30, step=1)
                trellis_stage1_lr = gr.Slider(1e-4, 1., label="trellis_stage1_lr", value=1e-1, step=5e-4)
                trellis_stage1_start_t = gr.Slider(0., 1., label="trellis_stage1_start_t", value=0.5, step=0.01)
                trellis_stage2_lr = gr.Slider(1e-4, 1., label="trellis_stage2_lr", value=1e-1, step=5e-4)
                trellis_stage2_start_t = gr.Slider(0., 1., label="trellis_stage2_start_t", value=0.5, step=0.01)

            with gr.Accordion(label="GLB Extraction Settings", open=False):
                mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
                texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)

            generate_btn = gr.Button("Generate & Extract GLB", variant="primary")
            extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
            gr.Markdown("""
            *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
            """)

        # Right column: rendered preview and download targets.
        with gr.Column():
            video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
            model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)

            with gr.Row():
                download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
                download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)

    # Holds the packed (gaussian, mesh) state between generation and extraction.
    output_buf = gr.State()

    # Example images at the bottom of the page.
    with gr.Row() as multiimage_example:
        examples_multi = gr.Examples(
            examples=prepare_multi_example(),
            inputs=[image_prompt],
            fn=split_image,
            outputs=[multiimage_prompt],
            run_on_click=True,
            examples_per_page=8,
        )

    # ---- Event handlers ----
    demo.load(start_session)
    demo.unload(end_session)

    input_video.upload(
        preprocess_videos,
        inputs=[input_video],
        outputs=[multiimage_prompt],
    )
    input_video.clear(
        lambda: (None, None),
        outputs=[input_video, multiimage_prompt],
    )
    multiimage_prompt.upload(
        preprocess_images,
        inputs=[multiimage_prompt],
        outputs=[multiimage_prompt],
    )

    # Generation chain: resolve the seed, run generation + GLB extraction,
    # then enable the follow-up buttons.
    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
        generate_and_extract_glb,
        inputs=[multiimage_prompt, seed, ss_guidance_strength, ss_sampling_steps,
                slat_guidance_strength, slat_sampling_steps, multiimage_algo,
                mesh_simplify, texture_size, refine, ss_refine, registration_num_frames,
                trellis_stage1_lr, trellis_stage1_start_t, trellis_stage2_lr,
                trellis_stage2_start_t],
        outputs=[output_buf, video_output, model_output, download_glb],
    ).then(
        lambda: (gr.Button(interactive=True), gr.Button(interactive=True)),
        outputs=[extract_gs_btn, download_glb],
    )

    video_output.clear(
        lambda: (gr.Button(interactive=False), gr.Button(interactive=False), gr.Button(interactive=False)),
        outputs=[extract_gs_btn, download_glb, download_gs],
    )

    extract_gs_btn.click(
        extract_gaussian,
        inputs=[output_buf],
        outputs=[model_output, download_gs],
    ).then(
        lambda: gr.Button(interactive=True),
        outputs=[download_gs],
    )

    model_output.clear(
        lambda: (gr.Button(interactive=False), gr.Button(interactive=False)),
        outputs=[download_glb, download_gs],
    )
776
+
777
+
778
# Launch the Gradio app, sharding models across 4 GPUs (when available)
# to spread VRAM usage.
if __name__ == "__main__":
    pipeline = TrellisVGGTTo3DPipeline.from_pretrained("Stable-X/trellis-vggt-v0-2")

    num_gpus = torch.cuda.device_count()
    if num_gpus >= 4:
        # VRAM sharding: place sub-models on physically distinct GPUs to
        # avoid OOM on any single card.
        # NOTE(review): this assumes the pipeline moves tensors across
        # devices internally — confirm cross-device calls are handled.
        pipeline.to("cuda:0")  # main pipeline
        if hasattr(pipeline, 'VGGT_model'):
            pipeline.VGGT_model.to("cuda:1")
        if hasattr(pipeline, 'birefnet_model'):
            pipeline.birefnet_model.to("cuda:2")
        # Isolate the heaviest decoders on GPU 3.
        if hasattr(pipeline, 'slat_decoder'):
            pipeline.slat_decoder.to("cuda:3")
        if hasattr(pipeline, 'sparse_structure_decoder'):
            pipeline.sparse_structure_decoder.to("cuda:3")

        # MASt3R model shares cuda:0 with the main pipeline.
        mast3r_model = AsymmetricMASt3R.from_pretrained(
            "naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"
        ).to("cuda:0").eval()
        print("Success: 4 GPU VRAM Sharding Activated.")
    else:
        # Fewer than 4 GPUs: keep everything on the default CUDA device.
        pipeline.cuda()
        mast3r_model = AsymmetricMASt3R.from_pretrained(
            "naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"
        ).cuda().eval()

    demo.launch()