Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,28 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
from gradio_litmodel3d import LitModel3D
|
| 3 |
|
| 4 |
-
import os
|
| 5 |
import shutil
|
| 6 |
-
os.environ['SPCONV_ALGO'] = 'native'
|
| 7 |
from typing import *
|
| 8 |
import torch
|
| 9 |
-
import torch.nn as nn # nn 모듈 추가
|
| 10 |
import numpy as np
|
| 11 |
import imageio
|
|
|
|
| 12 |
from easydict import EasyDict as edict
|
| 13 |
from PIL import Image
|
| 14 |
from trellis.pipelines import TrellisVGGTTo3DPipeline
|
| 15 |
from trellis.representations import Gaussian, MeshExtractResult
|
| 16 |
from trellis.utils import render_utils, postprocessing_utils
|
| 17 |
|
| 18 |
-
#
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
return getattr(self.module, name)
|
| 25 |
-
# -----------------------------------------------------------
|
| 26 |
|
| 27 |
MAX_SEED = np.iinfo(np.int32).max
|
| 28 |
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
|
|
@@ -37,6 +38,9 @@ def end_session(req: gr.Request):
|
|
| 37 |
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
| 38 |
if os.path.exists(user_dir):
|
| 39 |
shutil.rmtree(user_dir)
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
def preprocess_image(image: Image.Image) -> Image.Image:
|
| 42 |
"""
|
|
@@ -118,6 +122,9 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
|
|
| 118 |
|
| 119 |
|
| 120 |
def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
|
|
|
|
|
|
|
|
|
|
| 121 |
gs = Gaussian(
|
| 122 |
aabb=state['gaussian']['aabb'],
|
| 123 |
sh_degree=state['gaussian']['sh_degree'],
|
|
@@ -126,15 +133,15 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
|
|
| 126 |
opacity_bias=state['gaussian']['opacity_bias'],
|
| 127 |
scaling_activation=state['gaussian']['scaling_activation'],
|
| 128 |
)
|
| 129 |
-
gs._xyz = torch.tensor(state['gaussian']['_xyz'], device=
|
| 130 |
-
gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device=
|
| 131 |
-
gs._scaling = torch.tensor(state['gaussian']['_scaling'], device=
|
| 132 |
-
gs._rotation = torch.tensor(state['gaussian']['_rotation'], device=
|
| 133 |
-
gs._opacity = torch.tensor(state['gaussian']['_opacity'], device=
|
| 134 |
|
| 135 |
mesh = edict(
|
| 136 |
-
vertices=torch.tensor(state['mesh']['vertices'], device=
|
| 137 |
-
faces=torch.tensor(state['mesh']['faces'], device=
|
| 138 |
)
|
| 139 |
|
| 140 |
return gs, mesh
|
|
@@ -191,25 +198,31 @@ def generate_and_extract_glb(
|
|
| 191 |
str: The path to the extracted GLB file.
|
| 192 |
str: The path to the extracted GLB file (for download).
|
| 193 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
| 195 |
image_files = [image[0] for image in multiimages]
|
| 196 |
|
| 197 |
# Generate 3D model
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
|
|
|
|
|
|
| 213 |
|
| 214 |
# Render video
|
| 215 |
# import uuid
|
|
@@ -234,7 +247,11 @@ def generate_and_extract_glb(
|
|
| 234 |
# Pack state for optional Gaussian extraction
|
| 235 |
state = pack_state(gs, mesh)
|
| 236 |
|
|
|
|
|
|
|
|
|
|
| 237 |
torch.cuda.empty_cache()
|
|
|
|
| 238 |
return state, video_path, glb_path, glb_path
|
| 239 |
|
| 240 |
|
|
@@ -257,12 +274,16 @@ def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
|
|
| 257 |
gs, _ = unpack_state(state)
|
| 258 |
gaussian_path = os.path.join(user_dir, 'sample.ply')
|
| 259 |
gs.save_ply(gaussian_path)
|
|
|
|
|
|
|
|
|
|
| 260 |
torch.cuda.empty_cache()
|
|
|
|
| 261 |
return gaussian_path, gaussian_path
|
| 262 |
|
| 263 |
|
| 264 |
def prepare_multi_example() -> List[Image.Image]:
|
| 265 |
-
#
|
| 266 |
if not os.path.exists("assets/example_multi_image"):
|
| 267 |
return []
|
| 268 |
multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
|
|
@@ -445,19 +466,26 @@ if __name__ == "__main__":
|
|
| 445 |
print("Initializing Pipeline...")
|
| 446 |
pipeline = TrellisVGGTTo3DPipeline.from_pretrained("esther11/trellis-vggt-v0-2")
|
| 447 |
pipeline.cuda()
|
| 448 |
-
pipeline.VGGT_model.cuda()
|
| 449 |
-
pipeline.birefnet_model.cuda()
|
| 450 |
|
| 451 |
-
#
|
|
|
|
|
|
|
| 452 |
if torch.cuda.device_count() > 1:
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
|
|
|
|
|
|
|
|
|
| 459 |
else:
|
| 460 |
print(f"Running on Single GPU: {torch.cuda.get_device_name(0)}")
|
| 461 |
-
# -----------------------------
|
| 462 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 463 |
demo.launch()
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
# [중요] OOM 방지를 위한 메모리 파편화 설정 (토치 로드 전에 설정해야 함)
|
| 3 |
+
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
| 4 |
+
os.environ['SPCONV_ALGO'] = 'native'
|
| 5 |
+
|
| 6 |
import gradio as gr
|
| 7 |
from gradio_litmodel3d import LitModel3D
|
| 8 |
|
|
|
|
| 9 |
import shutil
|
|
|
|
| 10 |
from typing import *
|
| 11 |
import torch
|
|
|
|
| 12 |
import numpy as np
|
| 13 |
import imageio
|
| 14 |
+
import gc # 가비지 컬렉션 추가
|
| 15 |
from easydict import EasyDict as edict
|
| 16 |
from PIL import Image
|
| 17 |
from trellis.pipelines import TrellisVGGTTo3DPipeline
|
| 18 |
from trellis.representations import Gaussian, MeshExtractResult
|
| 19 |
from trellis.utils import render_utils, postprocessing_utils
|
| 20 |
|
| 21 |
+
# [중요] 모델 분산을 위한 accelerate 라이브러리 체크
|
| 22 |
+
try:
|
| 23 |
+
from accelerate import dispatch_model
|
| 24 |
+
ACCELERATE_AVAILABLE = True
|
| 25 |
+
except ImportError:
|
| 26 |
+
ACCELERATE_AVAILABLE = False
|
|
|
|
|
|
|
| 27 |
|
| 28 |
MAX_SEED = np.iinfo(np.int32).max
|
| 29 |
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
|
|
|
|
| 38 |
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
| 39 |
if os.path.exists(user_dir):
|
| 40 |
shutil.rmtree(user_dir)
|
| 41 |
+
# 세션 종료 시 메모리 정리
|
| 42 |
+
gc.collect()
|
| 43 |
+
torch.cuda.empty_cache()
|
| 44 |
|
| 45 |
def preprocess_image(image: Image.Image) -> Image.Image:
|
| 46 |
"""
|
|
|
|
| 122 |
|
| 123 |
|
| 124 |
def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
|
| 125 |
+
# 언팩 시 바로 CUDA로 올리면 메모리 튈 수 있으므로 상황에 맞게 device 설정
|
| 126 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 127 |
+
|
| 128 |
gs = Gaussian(
|
| 129 |
aabb=state['gaussian']['aabb'],
|
| 130 |
sh_degree=state['gaussian']['sh_degree'],
|
|
|
|
| 133 |
opacity_bias=state['gaussian']['opacity_bias'],
|
| 134 |
scaling_activation=state['gaussian']['scaling_activation'],
|
| 135 |
)
|
| 136 |
+
gs._xyz = torch.tensor(state['gaussian']['_xyz'], device=device)
|
| 137 |
+
gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device=device)
|
| 138 |
+
gs._scaling = torch.tensor(state['gaussian']['_scaling'], device=device)
|
| 139 |
+
gs._rotation = torch.tensor(state['gaussian']['_rotation'], device=device)
|
| 140 |
+
gs._opacity = torch.tensor(state['gaussian']['_opacity'], device=device)
|
| 141 |
|
| 142 |
mesh = edict(
|
| 143 |
+
vertices=torch.tensor(state['mesh']['vertices'], device=device),
|
| 144 |
+
faces=torch.tensor(state['mesh']['faces'], device=device),
|
| 145 |
)
|
| 146 |
|
| 147 |
return gs, mesh
|
|
|
|
| 198 |
str: The path to the extracted GLB file.
|
| 199 |
str: The path to the extracted GLB file (for download).
|
| 200 |
"""
|
| 201 |
+
# [수정] 추론 시작 전 가비지 컬렉션 수행
|
| 202 |
+
gc.collect()
|
| 203 |
+
torch.cuda.empty_cache()
|
| 204 |
+
|
| 205 |
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
| 206 |
image_files = [image[0] for image in multiimages]
|
| 207 |
|
| 208 |
# Generate 3D model
|
| 209 |
+
# [수정] torch.no_grad()를 사용하여 불필요한 그라디언트 메모리 사용 방지
|
| 210 |
+
with torch.no_grad():
|
| 211 |
+
outputs, _, _ = pipeline.run(
|
| 212 |
+
image=image_files,
|
| 213 |
+
seed=seed,
|
| 214 |
+
formats=["gaussian", "mesh"],
|
| 215 |
+
preprocess_image=False,
|
| 216 |
+
sparse_structure_sampler_params={
|
| 217 |
+
"steps": ss_sampling_steps,
|
| 218 |
+
"cfg_strength": ss_guidance_strength,
|
| 219 |
+
},
|
| 220 |
+
slat_sampler_params={
|
| 221 |
+
"steps": slat_sampling_steps,
|
| 222 |
+
"cfg_strength": slat_guidance_strength,
|
| 223 |
+
},
|
| 224 |
+
mode=multiimage_algo,
|
| 225 |
+
)
|
| 226 |
|
| 227 |
# Render video
|
| 228 |
# import uuid
|
|
|
|
| 247 |
# Pack state for optional Gaussian extraction
|
| 248 |
state = pack_state(gs, mesh)
|
| 249 |
|
| 250 |
+
# [수정] 사용 끝난 변수 명시적 삭제 및 메모리 정리
|
| 251 |
+
del outputs, gs, mesh, glb
|
| 252 |
+
gc.collect()
|
| 253 |
torch.cuda.empty_cache()
|
| 254 |
+
|
| 255 |
return state, video_path, glb_path, glb_path
|
| 256 |
|
| 257 |
|
|
|
|
| 274 |
gs, _ = unpack_state(state)
|
| 275 |
gaussian_path = os.path.join(user_dir, 'sample.ply')
|
| 276 |
gs.save_ply(gaussian_path)
|
| 277 |
+
|
| 278 |
+
# [수정] 메모리 정리
|
| 279 |
+
del gs
|
| 280 |
torch.cuda.empty_cache()
|
| 281 |
+
|
| 282 |
return gaussian_path, gaussian_path
|
| 283 |
|
| 284 |
|
| 285 |
def prepare_multi_example() -> List[Image.Image]:
|
| 286 |
+
# 에러 방지용 경로 체크
|
| 287 |
if not os.path.exists("assets/example_multi_image"):
|
| 288 |
return []
|
| 289 |
multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
|
|
|
|
| 466 |
print("Initializing Pipeline...")
|
| 467 |
pipeline = TrellisVGGTTo3DPipeline.from_pretrained("esther11/trellis-vggt-v0-2")
|
| 468 |
pipeline.cuda()
|
|
|
|
|
|
|
| 469 |
|
| 470 |
+
# [수정] 멀티 GPU 처리 로직 (Model Parallelism)
|
| 471 |
+
# 기존 DataParallel은 모델을 복제하여 VRAM을 2배로 쓰므로 OOM 발생함.
|
| 472 |
+
# 대신 accelerate의 dispatch_model을 사용하여 모델을 여러 GPU에 분할 적재해야 함.
|
| 473 |
if torch.cuda.device_count() > 1:
|
| 474 |
+
if ACCELERATE_AVAILABLE:
|
| 475 |
+
print(f"⚡ Accelerate detected: {torch.cuda.device_count()} GPUs found.")
|
| 476 |
+
print("Applying 'device_map=auto' to VGGT_model to split layers across GPUs (Memory Efficient).")
|
| 477 |
+
# VGGT_model이 가장 무거우므로 이를 여러 GPU에 쪼개서 올림
|
| 478 |
+
pipeline.VGGT_model = dispatch_model(pipeline.VGGT_model, device_map="auto")
|
| 479 |
+
else:
|
| 480 |
+
print("⚠️ 'accelerate' library not found. Cannot split model across GPUs.")
|
| 481 |
+
print("Installing 'accelerate' (`pip install accelerate`) is highly recommended for multi-GPU inference.")
|
| 482 |
+
# 라이브러리가 없으면 기본 동작(단일 GPU or 기존 상태) 유지
|
| 483 |
else:
|
| 484 |
print(f"Running on Single GPU: {torch.cuda.get_device_name(0)}")
|
|
|
|
| 485 |
|
| 486 |
+
# 나머지 모델들도 CUDA로 이동 (accelerate 적용 안된 경우)
|
| 487 |
+
if not ACCELERATE_AVAILABLE or torch.cuda.device_count() <= 1:
|
| 488 |
+
pipeline.VGGT_model.cuda()
|
| 489 |
+
pipeline.birefnet_model.cuda()
|
| 490 |
+
|
| 491 |
demo.launch()
|