notenoughram committed on
Commit
ec09383
·
verified ·
1 Parent(s): 8e3a585

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -276
app.py CHANGED
@@ -1,6 +1,4 @@
1
  import gradio as gr
2
- import spaces
3
- from gradio_litmodel3d import LitModel3D
4
  import os
5
  import shutil
6
 
@@ -14,6 +12,7 @@ from PIL import Image
14
  from trellis.pipelines import TrellisVGGTTo3DPipeline
15
  from trellis.representations import Gaussian, MeshExtractResult
16
  from trellis.utils import render_utils, postprocessing_utils
 
17
 
18
  MAX_SEED = np.iinfo(np.int32).max
19
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
@@ -25,43 +24,14 @@ def start_session(req: gr.Request):
25
 
26
  def end_session(req: gr.Request):
27
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
28
- # 폴더가 존재할 때만 삭제하도록 수정 (FileNotFoundError 방지)
29
  if os.path.exists(user_dir):
30
  shutil.rmtree(user_dir)
31
 
32
- @spaces.GPU
33
  def preprocess_image(image: Image.Image) -> Image.Image:
34
- """
35
- Preprocess the input image for 3D generation.
36
-
37
- This function is called when a user uploads an image or selects an example.
38
- It applies background removal and other preprocessing steps necessary for
39
- optimal 3D model generation.
40
-
41
- Args:
42
- image (Image.Image): The input image from the user
43
-
44
- Returns:
45
- Image.Image: The preprocessed image ready for 3D generation
46
- """
47
  processed_image = pipeline.preprocess_image(image)
48
  return processed_image
49
 
50
- @spaces.GPU
51
  def preprocess_videos(video: str) -> List[Tuple[Image.Image, str]]:
52
- """
53
- Preprocess the input video for multi-image 3D generation.
54
-
55
- This function is called when a user uploads a video.
56
- It extracts frames from the video and processes each frame to prepare them
57
- for the multi-image 3D generation pipeline.
58
-
59
- Args:
60
- video (str): The path to the input video file
61
-
62
- Returns:
63
- List[Tuple[Image.Image, str]]: The list of preprocessed images ready for 3D generation
64
- """
65
  vid = imageio.get_reader(video, 'ffmpeg')
66
  fps = vid.get_meta_data()['fps']
67
  images = []
@@ -75,25 +45,11 @@ def preprocess_videos(video: str) -> List[Tuple[Image.Image, str]]:
75
  processed_images = [pipeline.preprocess_image(image) for image in images]
76
  return processed_images
77
 
78
- @spaces.GPU
79
  def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
80
- """
81
- Preprocess a list of input images for multi-image 3D generation.
82
-
83
- This function is called when users upload multiple images in the gallery.
84
- It processes each image to prepare them for the multi-image 3D generation pipeline.
85
-
86
- Args:
87
- images (List[Tuple[Image.Image, str]]): The input images from the gallery
88
-
89
- Returns:
90
- List[Image.Image]: The preprocessed images ready for 3D generation
91
- """
92
  images = [image[0] for image in images]
93
  processed_images = [pipeline.preprocess_image(image) for image in images]
94
  return processed_images
95
 
96
-
97
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
98
  return {
99
  'gaussian': {
@@ -109,8 +65,7 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
109
  'faces': mesh.faces.cpu().numpy(),
110
  },
111
  }
112
-
113
-
114
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
115
  gs = Gaussian(
116
  aabb=state['gaussian']['aabb'],
@@ -120,38 +75,22 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
120
  opacity_bias=state['gaussian']['opacity_bias'],
121
  scaling_activation=state['gaussian']['scaling_activation'],
122
  )
123
- gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
124
- gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
125
- gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
126
- gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
127
- gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
 
128
 
129
  mesh = edict(
130
- vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
131
- faces=torch.tensor(state['mesh']['faces'], device='cuda'),
132
  )
133
-
134
  return gs, mesh
135
 
136
-
137
  def get_seed(randomize_seed: bool, seed: int) -> int:
138
- """
139
- Get the random seed for generation.
140
-
141
- This function is called by the generate button to determine whether to use
142
- a random seed or the user-specified seed value.
143
-
144
- Args:
145
- randomize_seed (bool): Whether to generate a random seed
146
- seed (int): The user-specified seed value
147
-
148
- Returns:
149
- int: The seed to use for generation
150
- """
151
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
152
 
153
-
154
- @spaces.GPU(duration=120)
155
  def generate_and_extract_glb(
156
  multiimages: List[Tuple[Image.Image, str]],
157
  seed: int,
@@ -164,32 +103,9 @@ def generate_and_extract_glb(
164
  texture_size: int,
165
  req: gr.Request,
166
  ) -> Tuple[dict, str, str, str]:
167
- """
168
- Convert an image to a 3D model and extract GLB file.
169
-
170
- Args:
171
- image (Image.Image): The input image.
172
- multiimages (List[Tuple[Image.Image, str]]): The input images in multi-image mode.
173
- is_multiimage (bool): Whether is in multi-image mode.
174
- seed (int): The random seed.
175
- ss_guidance_strength (float): The guidance strength for sparse structure generation.
176
- ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
177
- slat_guidance_strength (float): The guidance strength for structured latent generation.
178
- slat_sampling_steps (int): The number of sampling steps for structured latent generation.
179
- multiimage_algo (Literal["multidiffusion", "stochastic"]): The algorithm for multi-image generation.
180
- mesh_simplify (float): The mesh simplification factor.
181
- texture_size (int): The texture resolution.
182
-
183
- Returns:
184
- dict: The information of the generated 3D model.
185
- str: The path to the video of the 3D model.
186
- str: The path to the extracted GLB file.
187
- str: The path to the extracted GLB file (for download).
188
- """
189
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
190
  image_files = [image[0] for image in multiimages]
191
 
192
- # Generate 3D model
193
  outputs, _, _ = pipeline.run(
194
  image=image_files,
195
  seed=seed,
@@ -206,49 +122,23 @@ def generate_and_extract_glb(
206
  mode=multiimage_algo,
207
  )
208
 
209
- # Render video
210
- # import uuid
211
- # output_id = str(uuid.uuid4())
212
- # os.makedirs(f"{TMP_DIR}/{output_id}", exist_ok=True)
213
- # video_path = f"{TMP_DIR}/{output_id}/preview.mp4"
214
- # glb_path = f"{TMP_DIR}/{output_id}/mesh.glb"
215
-
216
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
217
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
218
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
219
  video_path = os.path.join(user_dir, 'sample.mp4')
220
  imageio.mimsave(video_path, video, fps=15)
221
 
222
- # Extract GLB
223
  gs = outputs['gaussian'][0]
224
  mesh = outputs['mesh'][0]
225
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
226
  glb_path = os.path.join(user_dir, 'sample.glb')
227
  glb.export(glb_path)
228
 
229
- # Pack state for optional Gaussian extraction
230
  state = pack_state(gs, mesh)
231
-
232
  torch.cuda.empty_cache()
233
  return state, video_path, glb_path, glb_path
234
 
235
-
236
- @spaces.GPU
237
  def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
238
- """
239
- Extract a Gaussian splatting file from the generated 3D model.
240
-
241
- This function is called when the user clicks "Extract Gaussian" button.
242
- It converts the 3D model state into a .ply file format containing
243
- Gaussian splatting data for advanced 3D applications.
244
-
245
- Args:
246
- state (dict): The state of the generated 3D model containing Gaussian data
247
- req (gr.Request): Gradio request object for session management
248
-
249
- Returns:
250
- Tuple[str, str]: Paths to the extracted Gaussian file (for display and download)
251
- """
252
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
253
  gs, _ = unpack_state(state)
254
  gaussian_path = os.path.join(user_dir, 'sample.ply')
@@ -256,15 +146,16 @@ def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
256
  torch.cuda.empty_cache()
257
  return gaussian_path, gaussian_path
258
 
259
-
260
  def prepare_multi_example() -> List[Image.Image]:
 
261
  multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
262
  images = []
263
  for case in multi_case:
264
  _images = []
265
  for i in range(1, 9):
266
- if os.path.exists(f'assets/example_multi_image/{case}_{i}.png'):
267
- img = Image.open(f'assets/example_multi_image/{case}_{i}.png')
 
268
  W, H = img.size
269
  img = img.resize((int(W / H * 512), 512))
270
  _images.append(np.array(img))
@@ -272,21 +163,7 @@ def prepare_multi_example() -> List[Image.Image]:
272
  images.append(Image.fromarray(np.concatenate(_images, axis=1)))
273
  return images
274
 
275
-
276
  def split_image(image: Image.Image) -> List[Image.Image]:
277
- """
278
- Split a multi-view image into separate view images.
279
-
280
- This function is called when users select multi-image examples that contain
281
- multiple views in a single concatenated image. It automatically splits them
282
- based on alpha channel boundaries and preprocesses each view.
283
-
284
- Args:
285
- image (Image.Image): A concatenated image containing multiple views
286
-
287
- Returns:
288
- List[Image.Image]: List of individual preprocessed view images
289
- """
290
  image = np.array(image)
291
  alpha = image[..., 3]
292
  alpha = np.any(alpha>0, axis=0)
@@ -297,168 +174,72 @@ def split_image(image: Image.Image) -> List[Image.Image]:
297
  images.append(Image.fromarray(image[:, s:e+1]))
298
  return [preprocess_image(image) for image in images]
299
 
300
- # Create interface
301
  demo = gr.Blocks(
302
  title="ReconViaGen",
303
- css="""
304
- .slider .inner { width: 5px; background: #FFF; }
305
- .viewport { aspect-ratio: 4/3; }
306
- .tabs button.selected { font-size: 20px !important; color: crimson !important; }
307
- h1, h2, h3 { text-align: center; display: block; }
308
- .md_feedback li { margin-bottom: 0px !important; }
309
- """
310
  )
 
311
  with demo:
312
- gr.Markdown("""
313
- # 💻 ReconViaGen
314
- <p align="center">
315
- <a title="Github" href="https://github.com/GAP-LAB-CUHK-SZ/ReconViaGen" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
316
- <img src="https://img.shields.io/github/stars/GAP-LAB-CUHK-SZ/ReconViaGen?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
317
- </a>
318
- <a title="Website" href="https://jiahao620.github.io/reconviagen/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
319
- <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
320
- </a>
321
- <a title="arXiv" href="https://jiahao620.github.io/reconviagen/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
322
- <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
323
- </a>
324
- </p>
325
-
326
- ✨This demo is partial. We will release the whole model later. Stay tuned!✨
327
- """)
328
-
329
  with gr.Row():
330
  with gr.Column():
331
  with gr.Tabs() as input_tabs:
332
- with gr.Tab(label="Input Video or Images", id=0) as multiimage_input_tab:
333
  input_video = gr.Video(label="Upload Video", interactive=True, height=300)
334
- image_prompt = gr.Image(label="Image Prompt", format="png", visible=False, image_mode="RGBA", type="pil", height=300)
335
- multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
336
- gr.Markdown("""
337
- Input different views of the object in separate images.
338
- """)
339
-
340
- with gr.Accordion(label="Generation Settings", open=False):
341
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
342
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)
343
- gr.Markdown("Stage 1: Sparse Structure Generation")
344
- with gr.Row():
345
- ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
346
- ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=30, step=1)
347
- gr.Markdown("Stage 2: Structured Latent Generation")
348
- with gr.Row():
349
- slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
350
- slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
351
- multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="multidiffusion")
352
-
353
- with gr.Accordion(label="GLB Extraction Settings", open=False):
354
- mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
355
- texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
356
-
357
  generate_btn = gr.Button("Generate & Extract GLB", variant="primary")
358
  extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
359
- gr.Markdown("""
360
- *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
361
- """)
362
-
363
  with gr.Column():
364
- video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
365
- model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
366
-
367
  with gr.Row():
368
- download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
369
- download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
370
-
371
  output_buf = gr.State()
 
 
 
372
 
373
- # Example images at the bottom of the page
374
- with gr.Row() as multiimage_example:
375
- examples_multi = gr.Examples(
376
- examples=prepare_multi_example(),
377
- inputs=[image_prompt],
378
- fn=split_image,
379
- outputs=[multiimage_prompt],
380
- run_on_click=True,
381
- examples_per_page=8,
382
- )
383
-
384
- # Handlers
385
  demo.load(start_session)
386
  demo.unload(end_session)
387
-
388
- input_video.upload(
389
- preprocess_videos,
390
- inputs=[input_video],
391
- outputs=[multiimage_prompt],
392
- )
393
- input_video.clear(
394
- lambda: tuple([None, None]),
395
- outputs=[input_video, multiimage_prompt],
396
- )
397
- multiimage_prompt.upload(
398
- preprocess_images,
399
- inputs=[multiimage_prompt],
400
- outputs=[multiimage_prompt],
401
- )
402
-
403
- generate_btn.click(
404
- get_seed,
405
- inputs=[randomize_seed, seed],
406
- outputs=[seed],
407
- ).then(
408
  generate_and_extract_glb,
409
  inputs=[multiimage_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo, mesh_simplify, texture_size],
410
  outputs=[output_buf, video_output, model_output, download_glb],
411
- ).then(
412
- lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
413
- outputs=[extract_gs_btn, download_glb],
414
- )
415
-
416
- video_output.clear(
417
- lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False), gr.Button(interactive=False)]),
418
- outputs=[extract_gs_btn, download_glb, download_gs],
419
- )
420
-
421
- extract_gs_btn.click(
422
- extract_gaussian,
423
- inputs=[output_buf],
424
- outputs=[model_output, download_gs],
425
- ).then(
426
- lambda: gr.Button(interactive=True),
427
- outputs=[download_gs],
428
- )
429
-
430
- model_output.clear(
431
- lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
432
- outputs=[download_glb, download_gs],
433
- )
434
-
435
 
 
436
  if __name__ == "__main__":
437
- # 1. 모델 로드
438
  pipeline = TrellisVGGTTo3DPipeline.from_pretrained("Stable-X/trellis-vggt-v0-2")
439
 
440
- # 2. 멀티 GPU 강제 설정
441
- if torch.cuda.is_available():
442
- device = torch.device("cuda:0")
443
- # 모든 하위 모델들을 먼저 0번 GPU 확실히 보냅니다
444
- pipeline.to(device)
445
-
446
- num_gpus = torch.cuda.device_count()
447
- if num_gpus > 1:
448
- print(f"--- 멀티 GPU 활성화: {num_gpus}개의 GPU를 사용합니다 ---")
449
- # 에러가 났던 birefnet과 주요 모델들을 DataParallel로 래핑
450
- try:
451
- if hasattr(pipeline, 'VGGT_model'):
452
- pipeline.VGGT_model = torch.nn.DataParallel(pipeline.VGGT_model).cuda()
453
- if hasattr(pipeline, 'birefnet_model'):
454
- pipeline.birefnet_model = torch.nn.DataParallel(pipeline.birefnet_model).cuda()
455
- if hasattr(pipeline, 'sparse_structure_decoder'):
456
- pipeline.sparse_structure_decoder = torch.nn.DataParallel(pipeline.sparse_structure_decoder).cuda()
457
- if hasattr(pipeline, 'slat_decoder'):
458
- pipeline.slat_decoder = torch.nn.DataParallel(pipeline.slat_decoder).cuda()
459
- except Exception as e:
460
- print(f"멀티 GPU 설정 중 경고 발생(단일 GPU로 전환): {e}")
461
- pipeline.to(device)
462
-
463
- # 3. 앱 실행
464
  demo.launch()
 
1
  import gradio as gr
 
 
2
  import os
3
  import shutil
4
 
 
12
  from trellis.pipelines import TrellisVGGTTo3DPipeline
13
  from trellis.representations import Gaussian, MeshExtractResult
14
  from trellis.utils import render_utils, postprocessing_utils
15
+ from gradio_litmodel3d import LitModel3D
16
 
17
  MAX_SEED = np.iinfo(np.int32).max
18
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
 
24
 
25
def end_session(req: gr.Request):
    """Remove the per-session temp directory when a Gradio session ends.

    The existence guard prevents FileNotFoundError for sessions that
    never produced any output files.
    """
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    if not os.path.exists(session_dir):
        return
    shutil.rmtree(session_dir)
29
 
 
30
def preprocess_image(image: Image.Image) -> Image.Image:
    """Run the pipeline's preprocessing on a single input image.

    Delegates directly to the module-level ``pipeline`` (created in
    the ``__main__`` block) and returns its result unchanged.
    """
    return pipeline.preprocess_image(image)
33
 
 
34
  def preprocess_videos(video: str) -> List[Tuple[Image.Image, str]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  vid = imageio.get_reader(video, 'ffmpeg')
36
  fps = vid.get_meta_data()['fps']
37
  images = []
 
45
  processed_images = [pipeline.preprocess_image(image) for image in images]
46
  return processed_images
47
 
 
48
def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
    """Preprocess every gallery entry for multi-image 3D generation.

    Gallery items arrive as (image, caption) pairs; only the image part
    of each pair is forwarded to the pipeline's preprocessing step.
    """
    # Single pass: extract the image from each pair and preprocess it.
    return [pipeline.preprocess_image(pair[0]) for pair in images]
52
 
 
53
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
54
  return {
55
  'gaussian': {
 
65
  'faces': mesh.faces.cpu().numpy(),
66
  },
67
  }
68
+
 
69
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
70
  gs = Gaussian(
71
  aabb=state['gaussian']['aabb'],
 
75
  opacity_bias=state['gaussian']['opacity_bias'],
76
  scaling_activation=state['gaussian']['scaling_activation'],
77
  )
78
+ # 추론 메인 장치인 cuda:0으로 데이터를 보냅니다.
79
+ gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda:0')
80
+ gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda:0')
81
+ gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda:0')
82
+ gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda:0')
83
+ gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda:0')
84
 
85
  mesh = edict(
86
+ vertices=torch.tensor(state['mesh']['vertices'], device='cuda:0'),
87
+ faces=torch.tensor(state['mesh']['faces'], device='cuda:0'),
88
  )
 
89
  return gs, mesh
90
 
 
91
def get_seed(randomize_seed: bool, seed: int) -> int:
    """Return a fresh random seed when requested, else the user-supplied one.

    Args:
        randomize_seed: when True, draw a seed uniformly from [0, MAX_SEED).
        seed: the user-specified seed, returned unchanged otherwise.
    """
    if randomize_seed:
        return np.random.randint(0, MAX_SEED)
    return seed
93
 
 
 
94
  def generate_and_extract_glb(
95
  multiimages: List[Tuple[Image.Image, str]],
96
  seed: int,
 
103
  texture_size: int,
104
  req: gr.Request,
105
  ) -> Tuple[dict, str, str, str]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
107
  image_files = [image[0] for image in multiimages]
108
 
 
109
  outputs, _, _ = pipeline.run(
110
  image=image_files,
111
  seed=seed,
 
122
  mode=multiimage_algo,
123
  )
124
 
 
 
 
 
 
 
 
125
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
126
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
127
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
128
  video_path = os.path.join(user_dir, 'sample.mp4')
129
  imageio.mimsave(video_path, video, fps=15)
130
 
 
131
  gs = outputs['gaussian'][0]
132
  mesh = outputs['mesh'][0]
133
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
134
  glb_path = os.path.join(user_dir, 'sample.glb')
135
  glb.export(glb_path)
136
 
 
137
  state = pack_state(gs, mesh)
 
138
  torch.cuda.empty_cache()
139
  return state, video_path, glb_path, glb_path
140
 
 
 
141
  def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
143
  gs, _ = unpack_state(state)
144
  gaussian_path = os.path.join(user_dir, 'sample.ply')
 
146
  torch.cuda.empty_cache()
147
  return gaussian_path, gaussian_path
148
 
 
149
def prepare_multi_example() -> List[Image.Image]:
    """Build concatenated preview strips for the multi-image examples.

    Scans ``assets/example_multi_image`` for files named ``<case>_<i>.png``
    (i in 1..8), resizes each view to height 512 preserving aspect ratio,
    and horizontally concatenates the views of each case into one strip.

    Returns:
        List[Image.Image]: one concatenated preview per example case;
        empty when the assets directory is missing.
    """
    example_dir = "assets/example_multi_image"
    if not os.path.exists(example_dir):
        return []
    # Case name is the filename prefix before the first underscore.
    cases = list(set(name.split('_')[0] for name in os.listdir(example_dir)))
    previews = []
    for case in cases:
        views = []
        for i in range(1, 9):
            path = f'{example_dir}/{case}_{i}.png'
            if os.path.exists(path):
                img = Image.open(path)
                W, H = img.size
                img = img.resize((int(W / H * 512), 512))
                views.append(np.array(img))
        # Fix: np.concatenate([]) raises ValueError — skip cases that
        # matched no files instead of crashing the whole scan.
        if views:
            previews.append(Image.fromarray(np.concatenate(views, axis=1)))
    return previews
165
 
 
166
  def split_image(image: Image.Image) -> List[Image.Image]:
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  image = np.array(image)
168
  alpha = image[..., 3]
169
  alpha = np.any(alpha>0, axis=0)
 
174
  images.append(Image.fromarray(image[:, s:e+1]))
175
  return [preprocess_image(image) for image in images]
176
 
177
# --- Gradio UI ---
# NOTE(review): reconstructed from a diff rendering; the exact nesting of the
# Accordion and of output_buf relative to the columns is inferred — confirm
# against the deployed app.
demo = gr.Blocks(
    title="ReconViaGen",
    css=".slider .inner { width: 5px; background: #FFF; } .viewport { aspect-ratio: 4/3; }"
)

with demo:
    gr.Markdown("# 💻 ReconViaGen\n✨This demo is partial. Stay tuned!✨")
    with gr.Row():
        with gr.Column():
            with gr.Tabs() as input_tabs:
                with gr.Tab(label="Input Video or Images", id=0):
                    # Video upload; its frames feed the multi-image gallery below.
                    input_video = gr.Video(label="Upload Video", interactive=True, height=300)
                    # Hidden single-image input; only used as the Examples input.
                    image_prompt = gr.Image(label="Image Prompt", visible=False, type="pil", height=300)
                    multiimage_prompt = gr.Gallery(label="Image Prompt", columns=3)
            with gr.Accordion(label="Settings", open=False):
                seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)
                # Stage 1 (sparse structure) and stage 2 (structured latent) knobs.
                ss_guidance_strength = gr.Slider(0.0, 10.0, label="SS Guidance", value=7.5)
                ss_sampling_steps = gr.Slider(1, 50, label="SS Steps", value=30)
                slat_guidance_strength = gr.Slider(0.0, 10.0, label="Slat Guidance", value=3.0)
                slat_sampling_steps = gr.Slider(1, 50, label="Slat Steps", value=12)
                multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], value="multidiffusion")
                mesh_simplify = gr.Slider(0.9, 0.98, value=0.95)
                texture_size = gr.Slider(512, 2048, value=1024, step=512)
            generate_btn = gr.Button("Generate & Extract GLB", variant="primary")
            # Disabled until a generation completes (re-enabled in the .then below).
            extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
        with gr.Column():
            video_output = gr.Video(label="Generated 3D Asset")
            model_output = LitModel3D(label="Extracted GLB/Gaussian")
            with gr.Row():
                download_glb = gr.DownloadButton("Download GLB", interactive=False)
                download_gs = gr.DownloadButton("Download Gaussian", interactive=False)
    # Holds the packed gaussian/mesh state between generate and extract steps.
    output_buf = gr.State()

    with gr.Row():
        gr.Examples(examples=prepare_multi_example(), inputs=[image_prompt], fn=split_image, outputs=[multiimage_prompt], run_on_click=True)

    # Session lifecycle: create/remove the per-session temp directory.
    demo.load(start_session)
    demo.unload(end_session)
    # Uploading a video extracts and preprocesses frames into the gallery.
    input_video.upload(preprocess_videos, inputs=[input_video], outputs=[multiimage_prompt])
    input_video.clear(lambda: (None, None), outputs=[input_video, multiimage_prompt])
    multiimage_prompt.upload(preprocess_images, inputs=[multiimage_prompt], outputs=[multiimage_prompt])
    # Generate: resolve the seed first, then run the full generate+GLB pipeline,
    # then re-enable the follow-up buttons.
    generate_btn.click(get_seed, inputs=[randomize_seed, seed], outputs=[seed]).then(
        generate_and_extract_glb,
        inputs=[multiimage_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo, mesh_simplify, texture_size],
        outputs=[output_buf, video_output, model_output, download_glb],
    ).then(lambda: (gr.Button(interactive=True), gr.Button(interactive=True)), outputs=[extract_gs_btn, download_glb])
    # NOTE(review): download_gs receives the .ply path here but is never set
    # interactive=True (the enabling .then was dropped in this commit) —
    # users may be unable to click "Download Gaussian"; verify.
    extract_gs_btn.click(extract_gaussian, inputs=[output_buf], outputs=[model_output, download_gs])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
# --- Main entry point (VRAM distribution across GPUs) ---
if __name__ == "__main__":
    # Load the pretrained pipeline; it is consumed as a module-level global
    # by the preprocessing/generation handlers above.
    pipeline = TrellisVGGTTo3DPipeline.from_pretrained("Stable-X/trellis-vggt-v0-2")

    num_gpus = torch.cuda.device_count()
    if num_gpus >= 4:
        print("--- 4 GPUs Detected: Splitting Models to prevent VRAM Error ---")
        # Spread sub-models across cards so no single GPU holds everything.
        pipeline.to("cuda:0")
        if hasattr(pipeline, 'VGGT_model'):
            pipeline.VGGT_model.to("cuda:1")
        if hasattr(pipeline, 'birefnet_model'):
            pipeline.birefnet_model.to("cuda:2")
        # Isolate the heaviest decoders on GPU 3.
        if hasattr(pipeline, 'slat_decoder'):
            pipeline.slat_decoder.to("cuda:3")
        if hasattr(pipeline, 'sparse_structure_decoder'):
            pipeline.sparse_structure_decoder.to("cuda:3")
        # NOTE(review): unpack_state pins tensors to cuda:0 — confirm the
        # split placement does not cause cross-device errors at extraction.
    elif num_gpus > 0:
        pipeline.cuda()
    # Fix: previously pipeline.cuda() also ran with zero CUDA devices and
    # crashed on CPU-only hosts; the pipeline now stays on CPU in that case.

    demo.launch()