Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import spaces
|
| 3 |
from gradio_litmodel3d import LitModel3D
|
| 4 |
-
|
| 5 |
import os
|
| 6 |
import shutil
|
|
|
|
| 7 |
os.environ['SPCONV_ALGO'] = 'native'
|
| 8 |
from typing import *
|
| 9 |
import torch
|
|
@@ -15,22 +15,19 @@ from trellis.pipelines import TrellisVGGTTo3DPipeline
|
|
| 15 |
from trellis.representations import Gaussian, MeshExtractResult
|
| 16 |
from trellis.utils import render_utils, postprocessing_utils
|
| 17 |
|
| 18 |
-
|
| 19 |
-
|
| 20 |
MAX_SEED = np.iinfo(np.int32).max
|
| 21 |
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
|
| 22 |
-
# TMP_DIR = "tmp/Trellis-demo"
|
| 23 |
-
# os.environ['GRADIO_TEMP_DIR'] = 'tmp'
|
| 24 |
os.makedirs(TMP_DIR, exist_ok=True)
|
| 25 |
|
| 26 |
def start_session(req: gr.Request):
|
| 27 |
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
| 28 |
os.makedirs(user_dir, exist_ok=True)
|
| 29 |
-
|
| 30 |
-
|
| 31 |
def end_session(req: gr.Request):
|
| 32 |
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
| 33 |
-
|
|
|
|
|
|
|
| 34 |
|
| 35 |
@spaces.GPU
|
| 36 |
def preprocess_image(image: Image.Image) -> Image.Image:
|
|
@@ -436,29 +433,32 @@ with demo:
|
|
| 436 |
)
|
| 437 |
|
| 438 |
|
| 439 |
-
# ํ์ผ ์ตํ๋จ๋ถ ์์
|
| 440 |
if __name__ == "__main__":
|
| 441 |
-
# 1.
|
| 442 |
pipeline = TrellisVGGTTo3DPipeline.from_pretrained("Stable-X/trellis-vggt-v0-2")
|
| 443 |
|
| 444 |
-
# 2.
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 463 |
# 3. ์ฑ ์คํ
|
| 464 |
demo.launch()
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import spaces
|
| 3 |
from gradio_litmodel3d import LitModel3D
|
|
|
|
| 4 |
import os
|
| 5 |
import shutil
|
| 6 |
+
|
| 7 |
os.environ['SPCONV_ALGO'] = 'native'
|
| 8 |
from typing import *
|
| 9 |
import torch
|
|
|
|
| 15 |
from trellis.representations import Gaussian, MeshExtractResult
|
| 16 |
from trellis.utils import render_utils, postprocessing_utils
|
| 17 |
|
|
|
|
|
|
|
| 18 |
MAX_SEED = np.iinfo(np.int32).max
|
| 19 |
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
|
|
|
|
|
|
|
| 20 |
os.makedirs(TMP_DIR, exist_ok=True)
|
| 21 |
|
| 22 |
def start_session(req: gr.Request):
|
| 23 |
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
| 24 |
os.makedirs(user_dir, exist_ok=True)
|
| 25 |
+
|
|
|
|
| 26 |
def end_session(req: gr.Request):
|
| 27 |
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
| 28 |
+
# ํด๋๊ฐ ์กด์ฌํ ๋๋ง ์ญ์ ํ๋๋ก ์์ (FileNotFoundError ๋ฐฉ์ง)
|
| 29 |
+
if os.path.exists(user_dir):
|
| 30 |
+
shutil.rmtree(user_dir)
|
| 31 |
|
| 32 |
@spaces.GPU
|
| 33 |
def preprocess_image(image: Image.Image) -> Image.Image:
|
|
|
|
| 433 |
)
|
| 434 |
|
| 435 |
|
|
|
|
| 436 |
if __name__ == "__main__":
|
| 437 |
+
# 1. ๋ชจ๋ธ ๋ก๋
|
| 438 |
pipeline = TrellisVGGTTo3DPipeline.from_pretrained("Stable-X/trellis-vggt-v0-2")
|
| 439 |
|
| 440 |
+
# 2. ๋ฉํฐ GPU ๊ฐ์ ์ค์
|
| 441 |
+
if torch.cuda.is_available():
|
| 442 |
+
device = torch.device("cuda:0")
|
| 443 |
+
# ๋ชจ๋ ํ์ ๋ชจ๋ธ๋ค์ ๋จผ์ 0๋ฒ GPU๋ก ํ์คํ ๋ณด๋
๋๋ค
|
| 444 |
+
pipeline.to(device)
|
| 445 |
+
|
| 446 |
+
num_gpus = torch.cuda.device_count()
|
| 447 |
+
if num_gpus > 1:
|
| 448 |
+
print(f"--- ๋ฉํฐ GPU ํ์ฑํ: {num_gpus}๊ฐ์ GPU๋ฅผ ์ฌ์ฉํฉ๋๋ค ---")
|
| 449 |
+
# ์๋ฌ๊ฐ ๋ฌ๋ birefnet๊ณผ ์ฃผ์ ๋ชจ๋ธ๋ค์ DataParallel๋ก ๋ํ
|
| 450 |
+
try:
|
| 451 |
+
if hasattr(pipeline, 'VGGT_model'):
|
| 452 |
+
pipeline.VGGT_model = torch.nn.DataParallel(pipeline.VGGT_model).cuda()
|
| 453 |
+
if hasattr(pipeline, 'birefnet_model'):
|
| 454 |
+
pipeline.birefnet_model = torch.nn.DataParallel(pipeline.birefnet_model).cuda()
|
| 455 |
+
if hasattr(pipeline, 'sparse_structure_decoder'):
|
| 456 |
+
pipeline.sparse_structure_decoder = torch.nn.DataParallel(pipeline.sparse_structure_decoder).cuda()
|
| 457 |
+
if hasattr(pipeline, 'slat_decoder'):
|
| 458 |
+
pipeline.slat_decoder = torch.nn.DataParallel(pipeline.slat_decoder).cuda()
|
| 459 |
+
except Exception as e:
|
| 460 |
+
print(f"๋ฉํฐ GPU ์ค์ ์ค ๊ฒฝ๊ณ ๋ฐ์(๋จ์ผ GPU๋ก ์ ํ): {e}")
|
| 461 |
+
pipeline.to(device)
|
| 462 |
+
|
| 463 |
# 3. ์ฑ ์คํ
|
| 464 |
demo.launch()
|