notenoughram committed on
Commit
7ae84cb
·
verified ·
1 Parent(s): e25705e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -44
app.py CHANGED
@@ -1,28 +1,29 @@
 
 
 
 
 
1
  import gradio as gr
2
  from gradio_litmodel3d import LitModel3D
3
 
4
- import os
5
  import shutil
6
- os.environ['SPCONV_ALGO'] = 'native'
7
  from typing import *
8
  import torch
9
- import torch.nn as nn # nn 모듈 추가
10
  import numpy as np
11
  import imageio
 
12
  from easydict import EasyDict as edict
13
  from PIL import Image
14
  from trellis.pipelines import TrellisVGGTTo3DPipeline
15
  from trellis.representations import Gaussian, MeshExtractResult
16
  from trellis.utils import render_utils, postprocessing_utils
17
 
18
- # --- [Fix] DataParallel 속성 접근 오류 해결을 위한 래퍼 클래스 ---
19
- class CustomDataParallel(nn.DataParallel):
20
- def __getattr__(self, name):
21
- try:
22
- return super().__getattr__(name)
23
- except AttributeError:
24
- return getattr(self.module, name)
25
- # -----------------------------------------------------------
26
 
27
  MAX_SEED = np.iinfo(np.int32).max
28
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
@@ -37,6 +38,9 @@ def end_session(req: gr.Request):
37
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
38
  if os.path.exists(user_dir):
39
  shutil.rmtree(user_dir)
 
 
 
40
 
41
  def preprocess_image(image: Image.Image) -> Image.Image:
42
  """
@@ -118,6 +122,9 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
118
 
119
 
120
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
 
 
 
121
  gs = Gaussian(
122
  aabb=state['gaussian']['aabb'],
123
  sh_degree=state['gaussian']['sh_degree'],
@@ -126,15 +133,15 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
126
  opacity_bias=state['gaussian']['opacity_bias'],
127
  scaling_activation=state['gaussian']['scaling_activation'],
128
  )
129
- gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
130
- gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
131
- gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
132
- gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
133
- gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
134
 
135
  mesh = edict(
136
- vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
137
- faces=torch.tensor(state['mesh']['faces'], device='cuda'),
138
  )
139
 
140
  return gs, mesh
@@ -191,25 +198,31 @@ def generate_and_extract_glb(
191
  str: The path to the extracted GLB file.
192
  str: The path to the extracted GLB file (for download).
193
  """
 
 
 
 
194
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
195
  image_files = [image[0] for image in multiimages]
196
 
197
  # Generate 3D model
198
- outputs, _, _ = pipeline.run(
199
- image=image_files,
200
- seed=seed,
201
- formats=["gaussian", "mesh"],
202
- preprocess_image=False,
203
- sparse_structure_sampler_params={
204
- "steps": ss_sampling_steps,
205
- "cfg_strength": ss_guidance_strength,
206
- },
207
- slat_sampler_params={
208
- "steps": slat_sampling_steps,
209
- "cfg_strength": slat_guidance_strength,
210
- },
211
- mode=multiimage_algo,
212
- )
 
 
213
 
214
  # Render video
215
  # import uuid
@@ -234,7 +247,11 @@ def generate_and_extract_glb(
234
  # Pack state for optional Gaussian extraction
235
  state = pack_state(gs, mesh)
236
 
 
 
 
237
  torch.cuda.empty_cache()
 
238
  return state, video_path, glb_path, glb_path
239
 
240
 
@@ -257,12 +274,16 @@ def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
257
  gs, _ = unpack_state(state)
258
  gaussian_path = os.path.join(user_dir, 'sample.ply')
259
  gs.save_ply(gaussian_path)
 
 
 
260
  torch.cuda.empty_cache()
 
261
  return gaussian_path, gaussian_path
262
 
263
 
264
  def prepare_multi_example() -> List[Image.Image]:
265
- # assets 경로 체크 추가 (에러 방지용)
266
  if not os.path.exists("assets/example_multi_image"):
267
  return []
268
  multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
@@ -445,19 +466,26 @@ if __name__ == "__main__":
445
  print("Initializing Pipeline...")
446
  pipeline = TrellisVGGTTo3DPipeline.from_pretrained("esther11/trellis-vggt-v0-2")
447
  pipeline.cuda()
448
- pipeline.VGGT_model.cuda()
449
- pipeline.birefnet_model.cuda()
450
 
451
- # --- [Fix] Multi-GPU Logic ---
 
 
452
  if torch.cuda.device_count() > 1:
453
- print(f"⚡ Multi-GPU Detected: {torch.cuda.device_count()} GPUs found.")
454
- print("Wrapping VGGT_model with CustomDataParallel to handle attributes correctly.")
455
-
456
- # VGGT_model CustomDataParallel로 감싸서 분산 처리
457
- # CustomDataParallel은 'aggregator' 같은 속성을 module 내부에서 찾아줌
458
- pipeline.VGGT_model = CustomDataParallel(pipeline.VGGT_model)
 
 
 
459
  else:
460
  print(f"Running on Single GPU: {torch.cuda.get_device_name(0)}")
461
- # -----------------------------
462
 
 
 
 
 
 
463
  demo.launch()
 
1
+ import os
2
+ # [중요] OOM 방지를 위한 메모리 파편화 설정 (토치 로드 전에 설정해야 함)
3
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
4
+ os.environ['SPCONV_ALGO'] = 'native'
5
+
6
  import gradio as gr
7
  from gradio_litmodel3d import LitModel3D
8
 
 
9
  import shutil
 
10
  from typing import *
11
  import torch
 
12
  import numpy as np
13
  import imageio
14
+ import gc # 가비지 컬렉션 추가
15
  from easydict import EasyDict as edict
16
  from PIL import Image
17
  from trellis.pipelines import TrellisVGGTTo3DPipeline
18
  from trellis.representations import Gaussian, MeshExtractResult
19
  from trellis.utils import render_utils, postprocessing_utils
20
 
21
+ # [중요] 모델 분산을 위한 accelerate 라이브러리 체크
22
+ try:
23
+ from accelerate import dispatch_model
24
+ ACCELERATE_AVAILABLE = True
25
+ except ImportError:
26
+ ACCELERATE_AVAILABLE = False
 
 
27
 
28
  MAX_SEED = np.iinfo(np.int32).max
29
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
 
38
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
39
  if os.path.exists(user_dir):
40
  shutil.rmtree(user_dir)
41
+ # 세션 종료 시 메모리 정리
42
+ gc.collect()
43
+ torch.cuda.empty_cache()
44
 
45
  def preprocess_image(image: Image.Image) -> Image.Image:
46
  """
 
122
 
123
 
124
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
125
+ # 언팩 시 바로 CUDA로 올리면 메모리 튈 수 있으므로 상황에 맞게 device 설정
126
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
127
+
128
  gs = Gaussian(
129
  aabb=state['gaussian']['aabb'],
130
  sh_degree=state['gaussian']['sh_degree'],
 
133
  opacity_bias=state['gaussian']['opacity_bias'],
134
  scaling_activation=state['gaussian']['scaling_activation'],
135
  )
136
+ gs._xyz = torch.tensor(state['gaussian']['_xyz'], device=device)
137
+ gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device=device)
138
+ gs._scaling = torch.tensor(state['gaussian']['_scaling'], device=device)
139
+ gs._rotation = torch.tensor(state['gaussian']['_rotation'], device=device)
140
+ gs._opacity = torch.tensor(state['gaussian']['_opacity'], device=device)
141
 
142
  mesh = edict(
143
+ vertices=torch.tensor(state['mesh']['vertices'], device=device),
144
+ faces=torch.tensor(state['mesh']['faces'], device=device),
145
  )
146
 
147
  return gs, mesh
 
198
  str: The path to the extracted GLB file.
199
  str: The path to the extracted GLB file (for download).
200
  """
201
+ # [수정] 추론 시작 전 가비지 컬렉션 수행
202
+ gc.collect()
203
+ torch.cuda.empty_cache()
204
+
205
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
206
  image_files = [image[0] for image in multiimages]
207
 
208
  # Generate 3D model
209
+ # [수정] torch.no_grad()를 사용하여 불필요한 그라디언트 메모리 사용 방지
210
+ with torch.no_grad():
211
+ outputs, _, _ = pipeline.run(
212
+ image=image_files,
213
+ seed=seed,
214
+ formats=["gaussian", "mesh"],
215
+ preprocess_image=False,
216
+ sparse_structure_sampler_params={
217
+ "steps": ss_sampling_steps,
218
+ "cfg_strength": ss_guidance_strength,
219
+ },
220
+ slat_sampler_params={
221
+ "steps": slat_sampling_steps,
222
+ "cfg_strength": slat_guidance_strength,
223
+ },
224
+ mode=multiimage_algo,
225
+ )
226
 
227
  # Render video
228
  # import uuid
 
247
  # Pack state for optional Gaussian extraction
248
  state = pack_state(gs, mesh)
249
 
250
+ # [수정] 사용 끝난 변수 명시적 삭제 및 메모리 정리
251
+ del outputs, gs, mesh, glb
252
+ gc.collect()
253
  torch.cuda.empty_cache()
254
+
255
  return state, video_path, glb_path, glb_path
256
 
257
 
 
274
  gs, _ = unpack_state(state)
275
  gaussian_path = os.path.join(user_dir, 'sample.ply')
276
  gs.save_ply(gaussian_path)
277
+
278
+ # [수정] 메모리 정리
279
+ del gs
280
  torch.cuda.empty_cache()
281
+
282
  return gaussian_path, gaussian_path
283
 
284
 
285
  def prepare_multi_example() -> List[Image.Image]:
286
+ # 에러 방지용 경로 체크
287
  if not os.path.exists("assets/example_multi_image"):
288
  return []
289
  multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
 
466
  print("Initializing Pipeline...")
467
  pipeline = TrellisVGGTTo3DPipeline.from_pretrained("esther11/trellis-vggt-v0-2")
468
  pipeline.cuda()
 
 
469
 
470
+ # [수정] 멀티 GPU 처리 로직 (Model Parallelism)
471
+ # 기존 DataParallel은 모델을 복제하여 VRAM을 2배로 쓰므로 OOM 발생함.
472
+ # 대신 accelerate의 dispatch_model을 사용하여 모델을 여러 GPU에 분할 적재해야 함.
473
  if torch.cuda.device_count() > 1:
474
+ if ACCELERATE_AVAILABLE:
475
+ print(f" Accelerate detected: {torch.cuda.device_count()} GPUs found.")
476
+ print("Applying 'device_map=auto' to VGGT_model to split layers across GPUs (Memory Efficient).")
477
+ # VGGT_model 가장 무거우므로 이를 여러 GPU에 쪼개서 올림
478
+ pipeline.VGGT_model = dispatch_model(pipeline.VGGT_model, device_map="auto")
479
+ else:
480
+ print("⚠️ 'accelerate' library not found. Cannot split model across GPUs.")
481
+ print("Installing 'accelerate' (`pip install accelerate`) is highly recommended for multi-GPU inference.")
482
+ # 라이브러리가 없으면 기본 동작(단일 GPU or 기존 상태) 유지
483
  else:
484
  print(f"Running on Single GPU: {torch.cuda.get_device_name(0)}")
 
485
 
486
+ # 나머지 모델들도 CUDA로 이동 (accelerate 적용 안된 경우)
487
+ if not ACCELERATE_AVAILABLE or torch.cuda.device_count() <= 1:
488
+ pipeline.VGGT_model.cuda()
489
+ pipeline.birefnet_model.cuda()
490
+
491
  demo.launch()