notenoughram commited on
Commit
8bcd01f
ยท
verified ยท
1 Parent(s): 7ae84cb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -226
app.py CHANGED
@@ -1,29 +1,30 @@
1
  import os
2
- # [์ค‘์š”] OOM ๋ฐฉ์ง€๋ฅผ ์œ„ํ•œ ๋ฉ”๋ชจ๋ฆฌ ํŒŒํŽธํ™” ์„ค์ • (ํ† ์น˜ ๋กœ๋“œ ์ „์— ์„ค์ •ํ•ด์•ผ ํ•จ)
3
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
4
- os.environ['SPCONV_ALGO'] = 'native'
5
-
6
- import gradio as gr
7
- from gradio_litmodel3d import LitModel3D
8
-
9
  import shutil
10
- from typing import *
11
  import torch
 
12
  import numpy as np
13
  import imageio
14
- import gc # ๊ฐ€๋น„์ง€ ์ปฌ๋ ‰์…˜ ์ถ”๊ฐ€
15
  from easydict import EasyDict as edict
16
  from PIL import Image
 
 
 
 
 
 
 
17
  from trellis.pipelines import TrellisVGGTTo3DPipeline
18
  from trellis.representations import Gaussian, MeshExtractResult
19
  from trellis.utils import render_utils, postprocessing_utils
20
 
21
- # [์ค‘์š”] ๋ชจ๋ธ ๋ถ„์‚ฐ์„ ์œ„ํ•œ accelerate ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ ์ฒดํฌ
22
  try:
23
- from accelerate import dispatch_model
24
- ACCELERATE_AVAILABLE = True
25
  except ImportError:
26
- ACCELERATE_AVAILABLE = False
27
 
28
  MAX_SEED = np.iinfo(np.int32).max
29
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
@@ -33,46 +34,19 @@ def start_session(req: gr.Request):
33
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
34
  os.makedirs(user_dir, exist_ok=True)
35
 
36
-
37
  def end_session(req: gr.Request):
38
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
39
  if os.path.exists(user_dir):
40
  shutil.rmtree(user_dir)
41
- # ์„ธ์…˜ ์ข…๋ฃŒ ์‹œ ๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ
42
  gc.collect()
43
  torch.cuda.empty_cache()
44
 
45
  def preprocess_image(image: Image.Image) -> Image.Image:
46
- """
47
- Preprocess the input image for 3D generation.
48
-
49
- This function is called when a user uploads an image or selects an example.
50
- It applies background removal and other preprocessing steps necessary for
51
- optimal 3D model generation.
52
-
53
- Args:
54
- image (Image.Image): The input image from the user
55
-
56
- Returns:
57
- Image.Image: The preprocessed image ready for 3D generation
58
- """
59
  processed_image = pipeline.preprocess_image(image)
60
  return processed_image
61
 
62
  def preprocess_videos(video: str) -> List[Tuple[Image.Image, str]]:
63
- """
64
- Preprocess the input video for multi-image 3D generation.
65
-
66
- This function is called when a user uploads a video.
67
- It extracts frames from the video and processes each frame to prepare them
68
- for the multi-image 3D generation pipeline.
69
-
70
- Args:
71
- video (str): The path to the input video file
72
-
73
- Returns:
74
- List[Tuple[Image.Image, str]]: The list of preprocessed images ready for 3D generation
75
- """
76
  vid = imageio.get_reader(video, 'ffmpeg')
77
  fps = vid.get_meta_data()['fps']
78
  images = []
@@ -87,23 +61,10 @@ def preprocess_videos(video: str) -> List[Tuple[Image.Image, str]]:
87
  return processed_images
88
 
89
  def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
90
- """
91
- Preprocess a list of input images for multi-image 3D generation.
92
-
93
- This function is called when users upload multiple images in the gallery.
94
- It processes each image to prepare them for the multi-image 3D generation pipeline.
95
-
96
- Args:
97
- images (List[Tuple[Image.Image, str]]): The input images from the gallery
98
-
99
- Returns:
100
- List[Image.Image]: The preprocessed images ready for 3D generation
101
- """
102
  images = [image[0] for image in images]
103
  processed_images = [pipeline.preprocess_image(image) for image in images]
104
  return processed_images
105
 
106
-
107
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
108
  return {
109
  'gaussian': {
@@ -119,11 +80,10 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
119
  'faces': mesh.faces.cpu().numpy(),
120
  },
121
  }
122
-
123
-
124
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
125
- # ์–ธํŒฉ ์‹œ ๋ฐ”๋กœ CUDA๋กœ ์˜ฌ๋ฆฌ๋ฉด ๋ฉ”๋ชจ๋ฆฌ ํŠˆ ์ˆ˜ ์žˆ์œผ๋ฏ€๋กœ ์ƒํ™ฉ์— ๋งž๊ฒŒ device ์„ค์ •
126
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
127
 
128
  gs = Gaussian(
129
  aabb=state['gaussian']['aabb'],
@@ -143,27 +103,11 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
143
  vertices=torch.tensor(state['mesh']['vertices'], device=device),
144
  faces=torch.tensor(state['mesh']['faces'], device=device),
145
  )
146
-
147
  return gs, mesh
148
 
149
-
150
  def get_seed(randomize_seed: bool, seed: int) -> int:
151
- """
152
- Get the random seed for generation.
153
-
154
- This function is called by the generate button to determine whether to use
155
- a random seed or the user-specified seed value.
156
-
157
- Args:
158
- randomize_seed (bool): Whether to generate a random seed
159
- seed (int): The user-specified seed value
160
-
161
- Returns:
162
- int: The seed to use for generation
163
- """
164
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
165
 
166
-
167
  def generate_and_extract_glb(
168
  multiimages: List[Tuple[Image.Image, str]],
169
  seed: int,
@@ -176,61 +120,41 @@ def generate_and_extract_glb(
176
  texture_size: int,
177
  req: gr.Request,
178
  ) -> Tuple[dict, str, str, str]:
179
- """
180
- Convert an image to a 3D model and extract GLB file.
181
-
182
- Args:
183
- image (Image.Image): The input image.
184
- multiimages (List[Tuple[Image.Image, str]]): The input images in multi-image mode.
185
- is_multiimage (bool): Whether is in multi-image mode.
186
- seed (int): The random seed.
187
- ss_guidance_strength (float): The guidance strength for sparse structure generation.
188
- ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
189
- slat_guidance_strength (float): The guidance strength for structured latent generation.
190
- slat_sampling_steps (int): The number of sampling steps for structured latent generation.
191
- multiimage_algo (Literal["multidiffusion", "stochastic"]): The algorithm for multi-image generation.
192
- mesh_simplify (float): The mesh simplification factor.
193
- texture_size (int): The texture resolution.
194
-
195
- Returns:
196
- dict: The information of the generated 3D model.
197
- str: The path to the video of the 3D model.
198
- str: The path to the extracted GLB file.
199
- str: The path to the extracted GLB file (for download).
200
- """
201
- # [์ˆ˜์ •] ์ถ”๋ก  ์‹œ์ž‘ ์ „ ๊ฐ€๋น„์ง€ ์ปฌ๋ ‰์…˜ ์ˆ˜ํ–‰
202
  gc.collect()
203
  torch.cuda.empty_cache()
204
 
205
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
206
  image_files = [image[0] for image in multiimages]
207
 
208
- # Generate 3D model
209
- # [์ˆ˜์ •] torch.no_grad()๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋ถˆํ•„์š”ํ•œ ๊ทธ๋ผ๋””์–ธํŠธ ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ ๋ฐฉ์ง€
210
- with torch.no_grad():
211
- outputs, _, _ = pipeline.run(
212
- image=image_files,
213
- seed=seed,
214
- formats=["gaussian", "mesh"],
215
- preprocess_image=False,
216
- sparse_structure_sampler_params={
217
- "steps": ss_sampling_steps,
218
- "cfg_strength": ss_guidance_strength,
219
- },
220
- slat_sampler_params={
221
- "steps": slat_sampling_steps,
222
- "cfg_strength": slat_guidance_strength,
223
- },
224
- mode=multiimage_algo,
225
- )
226
-
227
- # Render video
228
- # import uuid
229
- # output_id = str(uuid.uuid4())
230
- # os.makedirs(f"{TMP_DIR}/{output_id}", exist_ok=True)
231
- # video_path = f"{TMP_DIR}/{output_id}/preview.mp4"
232
- # glb_path = f"{TMP_DIR}/{output_id}/mesh.glb"
233
-
 
234
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
235
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
236
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
@@ -244,46 +168,27 @@ def generate_and_extract_glb(
244
  glb_path = os.path.join(user_dir, 'sample.glb')
245
  glb.export(glb_path)
246
 
247
- # Pack state for optional Gaussian extraction
248
  state = pack_state(gs, mesh)
249
 
250
- # [์ˆ˜์ •] ์‚ฌ์šฉ ๋๋‚œ ๋ณ€์ˆ˜ ๋ช…์‹œ์  ์‚ญ์ œ ๋ฐ ๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ
251
  del outputs, gs, mesh, glb
252
  gc.collect()
253
  torch.cuda.empty_cache()
254
 
255
  return state, video_path, glb_path, glb_path
256
 
257
-
258
  def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
259
- """
260
- Extract a Gaussian splatting file from the generated 3D model.
261
-
262
- This function is called when the user clicks "Extract Gaussian" button.
263
- It converts the 3D model state into a .ply file format containing
264
- Gaussian splatting data for advanced 3D applications.
265
-
266
- Args:
267
- state (dict): The state of the generated 3D model containing Gaussian data
268
- req (gr.Request): Gradio request object for session management
269
-
270
- Returns:
271
- Tuple[str, str]: Paths to the extracted Gaussian file (for display and download)
272
- """
273
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
274
  gs, _ = unpack_state(state)
275
  gaussian_path = os.path.join(user_dir, 'sample.ply')
276
  gs.save_ply(gaussian_path)
277
 
278
- # [์ˆ˜์ •] ๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ
279
  del gs
280
  torch.cuda.empty_cache()
281
-
282
  return gaussian_path, gaussian_path
283
 
284
-
285
  def prepare_multi_example() -> List[Image.Image]:
286
- # ์—๋Ÿฌ ๋ฐฉ์ง€์šฉ ๊ฒฝ๋กœ ์ฒดํฌ
287
  if not os.path.exists("assets/example_multi_image"):
288
  return []
289
  multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
@@ -300,21 +205,7 @@ def prepare_multi_example() -> List[Image.Image]:
300
  images.append(Image.fromarray(np.concatenate(_images, axis=1)))
301
  return images
302
 
303
-
304
  def split_image(image: Image.Image) -> List[Image.Image]:
305
- """
306
- Split a multi-view image into separate view images.
307
-
308
- This function is called when users select multi-image examples that contain
309
- multiple views in a single concatenated image. It automatically splits them
310
- based on alpha channel boundaries and preprocesses each view.
311
-
312
- Args:
313
- image (Image.Image): A concatenated image containing multiple views
314
-
315
- Returns:
316
- List[Image.Image]: List of individual preprocessed view images
317
- """
318
  image = np.array(image)
319
  alpha = image[..., 3]
320
  alpha = np.any(alpha>0, axis=0)
@@ -338,20 +229,7 @@ demo = gr.Blocks(
338
  )
339
  with demo:
340
  gr.Markdown("""
341
- # ๐Ÿ’ป ReconViaGen
342
- <p align="center">
343
- <a title="Github" href="https://github.com/GAP-LAB-CUHK-SZ/ReconViaGen" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
344
- <img src="https://img.shields.io/github/stars/GAP-LAB-CUHK-SZ/ReconViaGen?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
345
- </a>
346
- <a title="Website" href="https://jiahao620.github.io/reconviagen/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
347
- <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
348
- </a>
349
- <a title="arXiv" href="https://jiahao620.github.io/reconviagen/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
350
- <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
351
- </a>
352
- </p>
353
-
354
- โœจThis demo is partial. We will release the whole model later. Stay tuned!โœจ
355
  """)
356
 
357
  with gr.Row():
@@ -361,9 +239,6 @@ with demo:
361
  input_video = gr.Video(label="Upload Video", interactive=True, height=300)
362
  image_prompt = gr.Image(label="Image Prompt", format="png", visible=False, image_mode="RGBA", type="pil", height=300)
363
  multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
364
- gr.Markdown("""
365
- Input different views of the object in separate images.
366
- """)
367
 
368
  with gr.Accordion(label="Generation Settings", open=False):
369
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
@@ -384,9 +259,6 @@ with demo:
384
 
385
  generate_btn = gr.Button("Generate & Extract GLB", variant="primary")
386
  extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
387
- gr.Markdown("""
388
- *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
389
- """)
390
 
391
  with gr.Column():
392
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
@@ -398,7 +270,6 @@ with demo:
398
 
399
  output_buf = gr.State()
400
 
401
- # Example images at the bottom of the page
402
  with gr.Row() as multiimage_example:
403
  examples_multi = gr.Examples(
404
  examples=prepare_multi_example(),
@@ -413,25 +284,12 @@ with demo:
413
  demo.load(start_session)
414
  demo.unload(end_session)
415
 
416
- input_video.upload(
417
- preprocess_videos,
418
- inputs=[input_video],
419
- outputs=[multiimage_prompt],
420
- )
421
- input_video.clear(
422
- lambda: tuple([None, None]),
423
- outputs=[input_video, multiimage_prompt],
424
- )
425
- multiimage_prompt.upload(
426
- preprocess_images,
427
- inputs=[multiimage_prompt],
428
- outputs=[multiimage_prompt],
429
- )
430
 
431
  generate_btn.click(
432
- get_seed,
433
- inputs=[randomize_seed, seed],
434
- outputs=[seed],
435
  ).then(
436
  generate_and_extract_glb,
437
  inputs=[multiimage_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo, mesh_simplify, texture_size],
@@ -446,46 +304,50 @@ with demo:
446
  outputs=[extract_gs_btn, download_glb, download_gs],
447
  )
448
 
449
- extract_gs_btn.click(
450
- extract_gaussian,
451
- inputs=[output_buf],
452
- outputs=[model_output, download_gs],
453
- ).then(
454
- lambda: gr.Button(interactive=True),
455
- outputs=[download_gs],
456
  )
457
 
458
  model_output.clear(
459
  lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
460
  outputs=[download_glb, download_gs],
461
  )
462
-
463
 
464
- # Launch the Gradio app
465
  if __name__ == "__main__":
466
- print("Initializing Pipeline...")
 
 
467
  pipeline = TrellisVGGTTo3DPipeline.from_pretrained("esther11/trellis-vggt-v0-2")
468
- pipeline.cuda()
469
 
470
- # [์ˆ˜์ •] ๋ฉ€ํ‹ฐ GPU ์ฒ˜๋ฆฌ ๋กœ์ง (Model Parallelism)
471
- # ๊ธฐ์กด DataParallel์€ ๋ชจ๋ธ์„ ๋ณต์ œํ•˜์—ฌ VRAM์„ 2๋ฐฐ๋กœ ์“ฐ๋ฏ€๋กœ OOM ๋ฐœ์ƒํ•จ.
472
- # ๋Œ€์‹  accelerate์˜ dispatch_model์„ ์‚ฌ์šฉํ•˜์—ฌ ๋ชจ๋ธ์„ ์—ฌ๋Ÿฌ GPU์— ๋ถ„ํ•  ์ ์žฌํ•ด์•ผ ํ•จ.
473
- if torch.cuda.device_count() > 1:
474
- if ACCELERATE_AVAILABLE:
475
- print(f"โšก Accelerate detected: {torch.cuda.device_count()} GPUs found.")
476
- print("Applying 'device_map=auto' to VGGT_model to split layers across GPUs (Memory Efficient).")
477
- # VGGT_model์ด ๊ฐ€์žฅ ๋ฌด๊ฑฐ์šฐ๋ฏ€๋กœ ์ด๋ฅผ ์—ฌ๋Ÿฌ GPU์— ์ชผ๊ฐœ์„œ ์˜ฌ๋ฆผ
478
- pipeline.VGGT_model = dispatch_model(pipeline.VGGT_model, device_map="auto")
479
- else:
480
- print("โš ๏ธ 'accelerate' library not found. Cannot split model across GPUs.")
481
- print("Installing 'accelerate' (`pip install accelerate`) is highly recommended for multi-GPU inference.")
482
- # ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๊ฐ€ ์—†์œผ๋ฉด ๊ธฐ๋ณธ ๋™์ž‘(๋‹จ์ผ GPU or ๊ธฐ์กด ์ƒํƒœ) ์œ ์ง€
 
 
 
 
 
 
 
 
 
 
 
 
483
  else:
484
- print(f"Running on Single GPU: {torch.cuda.get_device_name(0)}")
 
 
485
 
486
- # ๋‚˜๋จธ์ง€ ๋ชจ๋ธ๋“ค๋„ CUDA๋กœ ์ด๋™ (accelerate ์ ์šฉ ์•ˆ๋œ ๊ฒฝ์šฐ)
487
- if not ACCELERATE_AVAILABLE or torch.cuda.device_count() <= 1:
488
- pipeline.VGGT_model.cuda()
489
- pipeline.birefnet_model.cuda()
490
-
491
  demo.launch()
 
1
  import os
2
+ import gc
 
 
 
 
 
 
3
  import shutil
 
4
  import torch
5
+ import torch.nn as nn
6
  import numpy as np
7
  import imageio
8
+ from typing import *
9
  from easydict import EasyDict as edict
10
  from PIL import Image
11
+
12
+ # [ํ•ต์‹ฌ 1] ๋ฉ”๋ชจ๋ฆฌ ํŒŒํŽธํ™” ๋ฐฉ์ง€ ๋ฐ 4090/L40S ๋“ฑ ์ตœ์‹  GPU ํ˜ธํ™˜์„ฑ ์„ค์ •
13
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
14
+ os.environ['SPCONV_ALGO'] = 'native'
15
+
16
+ import gradio as gr
17
+ from gradio_litmodel3d import LitModel3D
18
  from trellis.pipelines import TrellisVGGTTo3DPipeline
19
  from trellis.representations import Gaussian, MeshExtractResult
20
  from trellis.utils import render_utils, postprocessing_utils
21
 
22
+ # [ํ•ต์‹ฌ 2] Accelerate ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ ํ•„์ˆ˜ ๋กœ๋“œ (์—†์œผ๋ฉด ์—๋Ÿฌ ๋ฐœ์ƒ)
23
  try:
24
+ from accelerate import dispatch_model, infer_auto_device_map
25
+ from accelerate.utils import get_balanced_memory
26
  except ImportError:
27
+ raise ImportError("Accelerate library is missing. Please run `pip install accelerate`.")
28
 
29
  MAX_SEED = np.iinfo(np.int32).max
30
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
 
34
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
35
  os.makedirs(user_dir, exist_ok=True)
36
 
 
37
  def end_session(req: gr.Request):
38
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
39
  if os.path.exists(user_dir):
40
  shutil.rmtree(user_dir)
 
41
  gc.collect()
42
  torch.cuda.empty_cache()
43
 
44
  def preprocess_image(image: Image.Image) -> Image.Image:
45
+ # ํŒŒ์ดํ”„๋ผ์ธ์ด ๋ถ„์‚ฐ๋˜์–ด ์žˆ์–ด๋„ ์ „์ฒ˜๋ฆฌ๋Š” CPU/GPU0์—์„œ ์ˆ˜ํ–‰๋จ
 
 
 
 
 
 
 
 
 
 
 
 
46
  processed_image = pipeline.preprocess_image(image)
47
  return processed_image
48
 
49
  def preprocess_videos(video: str) -> List[Tuple[Image.Image, str]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  vid = imageio.get_reader(video, 'ffmpeg')
51
  fps = vid.get_meta_data()['fps']
52
  images = []
 
61
  return processed_images
62
 
63
  def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
 
 
 
 
 
 
 
 
 
 
 
 
64
  images = [image[0] for image in images]
65
  processed_images = [pipeline.preprocess_image(image) for image in images]
66
  return processed_images
67
 
 
68
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
69
  return {
70
  'gaussian': {
 
80
  'faces': mesh.faces.cpu().numpy(),
81
  },
82
  }
83
+
 
84
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
85
+ # ๊ฒฐ๊ณผ๋ฌผ ๋ Œ๋”๋ง ์‹œ์—๋Š” GPU ํ•˜๋‚˜(๋ณดํ†ต 0๋ฒˆ)๋‚˜ CPU๋กœ ๋ชจ์•„์•ผ ํ•จ
86
+ device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
87
 
88
  gs = Gaussian(
89
  aabb=state['gaussian']['aabb'],
 
103
  vertices=torch.tensor(state['mesh']['vertices'], device=device),
104
  faces=torch.tensor(state['mesh']['faces'], device=device),
105
  )
 
106
  return gs, mesh
107
 
 
108
  def get_seed(randomize_seed: bool, seed: int) -> int:
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
110
 
 
111
  def generate_and_extract_glb(
112
  multiimages: List[Tuple[Image.Image, str]],
113
  seed: int,
 
120
  texture_size: int,
121
  req: gr.Request,
122
  ) -> Tuple[dict, str, str, str]:
123
+
124
+ # [ํ•ต์‹ฌ 3] ์‹คํ–‰ ์ „ ๊ฐ€๋น„์ง€ ์ปฌ๋ ‰์…˜์œผ๋กœ VRAM ํ™•๋ณด
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  gc.collect()
126
  torch.cuda.empty_cache()
127
 
128
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
129
  image_files = [image[0] for image in multiimages]
130
 
131
+ try:
132
+ # [ํ•ต์‹ฌ 4] ์ถ”๋ก  ๋ชจ๋“œ ๊ฐ•์ œ. Gradient ์ €์žฅ ์•ˆ ํ•จ (๋ฉ”๋ชจ๋ฆฌ ์ ˆ์•ฝ)
133
+ with torch.no_grad():
134
+ outputs, _, _ = pipeline.run(
135
+ image=image_files,
136
+ seed=seed,
137
+ formats=["gaussian", "mesh"],
138
+ preprocess_image=False,
139
+ sparse_structure_sampler_params={
140
+ "steps": ss_sampling_steps,
141
+ "cfg_strength": ss_guidance_strength,
142
+ },
143
+ slat_sampler_params={
144
+ "steps": slat_sampling_steps,
145
+ "cfg_strength": slat_guidance_strength,
146
+ },
147
+ mode=multiimage_algo,
148
+ )
149
+ except Exception as e:
150
+ # ์—๋Ÿฌ ๋ฐœ์ƒ ์‹œ ๋ฉ”๋ชจ๋ฆฌ ๋น„์šฐ๊ณ  ์—๋Ÿฌ ๋‹ค์‹œ ๋˜์ง
151
+ torch.cuda.empty_cache()
152
+ raise e
153
+
154
+ # Render video (ํ›„์ฒ˜๋ฆฌ๋Š” ๋‹จ์ผ GPU ํ˜น์€ CPU์—์„œ ์ˆ˜ํ–‰)
155
+ # ํ…์„œ๋“ค์ด ์—ฌ๋Ÿฌ GPU์— ํฉ์–ด์ ธ ์žˆ์„ ์ˆ˜ ์žˆ์œผ๋ฏ€๋กœ .cpu() ํ˜น์€ ๋‹จ์ผ cuda๋กœ ๋ชจ์•„์•ผ ํ•จ
156
+ # render_utils ๋‚ด๋ถ€์—์„œ ์ฒ˜๋ฆฌํ•˜๊ฒ ์ง€๋งŒ ์•ˆ์ „ํ•˜๊ฒŒ ์ง„ํ–‰
157
+
158
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
159
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
160
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
 
168
  glb_path = os.path.join(user_dir, 'sample.glb')
169
  glb.export(glb_path)
170
 
171
+ # Pack state
172
  state = pack_state(gs, mesh)
173
 
174
+ # ๊ฒฐ๊ณผ๋ฌผ ๋ฐ˜ํ™˜ ํ›„ ๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ
175
  del outputs, gs, mesh, glb
176
  gc.collect()
177
  torch.cuda.empty_cache()
178
 
179
  return state, video_path, glb_path, glb_path
180
 
 
181
  def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
183
  gs, _ = unpack_state(state)
184
  gaussian_path = os.path.join(user_dir, 'sample.ply')
185
  gs.save_ply(gaussian_path)
186
 
 
187
  del gs
188
  torch.cuda.empty_cache()
 
189
  return gaussian_path, gaussian_path
190
 
 
191
  def prepare_multi_example() -> List[Image.Image]:
 
192
  if not os.path.exists("assets/example_multi_image"):
193
  return []
194
  multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
 
205
  images.append(Image.fromarray(np.concatenate(_images, axis=1)))
206
  return images
207
 
 
208
  def split_image(image: Image.Image) -> List[Image.Image]:
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  image = np.array(image)
210
  alpha = image[..., 3]
211
  alpha = np.any(alpha>0, axis=0)
 
229
  )
230
  with demo:
231
  gr.Markdown("""
232
+ # ๐Ÿ’ป ReconViaGen (Multi-GPU Force Enabled)
 
 
 
 
 
 
 
 
 
 
 
 
 
233
  """)
234
 
235
  with gr.Row():
 
239
  input_video = gr.Video(label="Upload Video", interactive=True, height=300)
240
  image_prompt = gr.Image(label="Image Prompt", format="png", visible=False, image_mode="RGBA", type="pil", height=300)
241
  multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
 
 
 
242
 
243
  with gr.Accordion(label="Generation Settings", open=False):
244
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
 
259
 
260
  generate_btn = gr.Button("Generate & Extract GLB", variant="primary")
261
  extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
 
 
 
262
 
263
  with gr.Column():
264
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
 
270
 
271
  output_buf = gr.State()
272
 
 
273
  with gr.Row() as multiimage_example:
274
  examples_multi = gr.Examples(
275
  examples=prepare_multi_example(),
 
284
  demo.load(start_session)
285
  demo.unload(end_session)
286
 
287
+ input_video.upload(preprocess_videos, inputs=[input_video], outputs=[multiimage_prompt])
288
+ input_video.clear(lambda: tuple([None, None]), outputs=[input_video, multiimage_prompt])
289
+ multiimage_prompt.upload(preprocess_images, inputs=[multiimage_prompt], outputs=[multiimage_prompt])
 
 
 
 
 
 
 
 
 
 
 
290
 
291
  generate_btn.click(
292
+ get_seed, inputs=[randomize_seed, seed], outputs=[seed]
 
 
293
  ).then(
294
  generate_and_extract_glb,
295
  inputs=[multiimage_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo, mesh_simplify, texture_size],
 
304
  outputs=[extract_gs_btn, download_glb, download_gs],
305
  )
306
 
307
+ extract_gs_btn.click(extract_gaussian, inputs=[output_buf], outputs=[model_output, download_gs]).then(
308
+ lambda: gr.Button(interactive=True), outputs=[download_gs]
 
 
 
 
 
309
  )
310
 
311
  model_output.clear(
312
  lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
313
  outputs=[download_glb, download_gs],
314
  )
 
315
 
316
+ # Launch Script
317
  if __name__ == "__main__":
318
+ print("๐Ÿš€ Initializing Pipeline...")
319
+ # [ํ•ต์‹ฌ 5] ๋ชจ๋ธ ๋กœ๋”ฉ ์‹œ์ ์— device๋ฅผ ์ง€์ •ํ•˜์ง€ ์•Š๊ณ , CPU์— ๋‘ก๋‹ˆ๋‹ค.
320
+ # ์ด๋ ‡๊ฒŒ ํ•ด์•ผ accelerate๊ฐ€ '๋นˆ ๋ชจ๋ธ' ์ƒํƒœ ํ˜น์€ CPU ์ƒํƒœ์—์„œ ๊ฐ€์ ธ๊ฐ€์„œ ์ชผ๊ฐค ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
321
  pipeline = TrellisVGGTTo3DPipeline.from_pretrained("esther11/trellis-vggt-v0-2")
 
322
 
323
+ gpu_count = torch.cuda.device_count()
324
+ print(f"โšก Detected {gpu_count} GPUs.")
325
+
326
+ if gpu_count > 1:
327
+ print("โšก Multi-GPU Mode Activated: Distributing model across all available GPUs.")
328
+
329
+ # [ํ•ต์‹ฌ 6] VGGT ๋ชจ๋ธ ๋ถ„์‚ฐ (๊ฐ€์žฅ ํฐ ๋ฉ”๋ชจ๋ฆฌ ์ฐจ์ง€)
330
+ # device_map="balanced"๋Š” ๊ฐ€๋Šฅํ•œ ํ•œ GPU ๋ฉ”๋ชจ๋ฆฌ๋ฅผ ๊ณ ๋ฅด๊ฒŒ ์“ฐ๋ ค๊ณ  ๋…ธ๋ ฅํ•ฉ๋‹ˆ๋‹ค.
331
+ # "auto"๋Š” ์ฒซ ๋ฒˆ์งธ GPU๋ถ€ํ„ฐ ์ฑ„์šฐ๋ฏ€๋กœ OOM ์œ„ํ—˜์ด ์žˆ์„ ์ˆ˜ ์žˆ์–ด balanced๊ฐ€ ์•ˆ์ „ํ•ฉ๋‹ˆ๋‹ค.
332
+ pipeline.VGGT_model = dispatch_model(
333
+ pipeline.VGGT_model,
334
+ device_map="balanced"
335
+ )
336
+
337
+ # [ํ•ต์‹ฌ 7] SLAT ๋ชจ๋ธ ๋ถ„์‚ฐ (๊ทธ ๋‹ค์Œ์œผ๋กœ ํผ)
338
+ pipeline.slat_model = dispatch_model(
339
+ pipeline.slat_model,
340
+ device_map="balanced"
341
+ )
342
+
343
+ # [ํ•ต์‹ฌ 8] ๋‚˜๋จธ์ง€ ๊ฐ€๋ฒผ์šด ๋ชจ๋ธ๋“ค์€ GPU 0์— ํ• ๋‹น (๋ถ„์‚ฐ ์˜ค๋ฒ„ํ—ค๋“œ ๋ฐฉ์ง€)
344
+ pipeline.birefnet_model.to("cuda:0")
345
+
346
+ print("โœ… Model dispatched successfully via Accelerate.")
347
+
348
  else:
349
+ # ๋‹จ์ผ GPU์ด๊ฑฐ๋‚˜ GPU๊ฐ€ ์—†๋Š” ๊ฒฝ์šฐ (์‚ฌ์šฉ์ž๊ฐ€ ์›์น˜ ์•Š์ง€๋งŒ ์•ˆ์ „๋ง์œผ๋กœ ๋‘ )
350
+ print("โš ๏ธ Warning: Only 1 GPU detected. Loading entire model to CUDA:0 (High OOM Risk).")
351
+ pipeline.cuda()
352
 
 
 
 
 
 
353
  demo.launch()