notenoughram committed on
Commit
8e3a585
·
verified ·
1 Parent(s): b088897

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -29
app.py CHANGED
@@ -1,9 +1,9 @@
1
  import gradio as gr
2
  import spaces
3
  from gradio_litmodel3d import LitModel3D
4
-
5
  import os
6
  import shutil
 
7
  os.environ['SPCONV_ALGO'] = 'native'
8
  from typing import *
9
  import torch
@@ -15,22 +15,19 @@ from trellis.pipelines import TrellisVGGTTo3DPipeline
15
  from trellis.representations import Gaussian, MeshExtractResult
16
  from trellis.utils import render_utils, postprocessing_utils
17
 
18
-
19
-
20
  MAX_SEED = np.iinfo(np.int32).max
21
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
22
- # TMP_DIR = "tmp/Trellis-demo"
23
- # os.environ['GRADIO_TEMP_DIR'] = 'tmp'
24
  os.makedirs(TMP_DIR, exist_ok=True)
25
 
26
  def start_session(req: gr.Request):
27
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
28
  os.makedirs(user_dir, exist_ok=True)
29
-
30
-
31
  def end_session(req: gr.Request):
32
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
33
- shutil.rmtree(user_dir)
 
 
34
 
35
  @spaces.GPU
36
  def preprocess_image(image: Image.Image) -> Image.Image:
@@ -436,29 +433,32 @@ with demo:
436
  )
437
 
438
 
439
- # ํŒŒ์ผ ์ตœํ•˜๋‹จ๋ถ€ ์ˆ˜์ •
440
  if __name__ == "__main__":
441
- # 1. ํŒŒ์ดํ”„๋ผ์ธ ๋กœ๋“œ
442
  pipeline = TrellisVGGTTo3DPipeline.from_pretrained("Stable-X/trellis-vggt-v0-2")
443
 
444
- # 2. ์‚ฌ์šฉ ๊ฐ€๋Šฅํ•œ GPU ๊ฐœ์ˆ˜ ํ™•์ธ
445
- num_gpus = torch.cuda.device_count()
446
- print(f"์‹œ์Šคํ…œ์—์„œ ๊ฐ์ง€๋œ GPU ๊ฐœ์ˆ˜: {num_gpus}")
447
-
448
- if num_gpus > 1:
449
- print("๋ฉ€ํ‹ฐ GPU ๋ชจ๋“œ๋กœ ์‹คํ–‰ํ•ฉ๋‹ˆ๋‹ค.")
450
- pipeline.cuda()
451
- # ์ฃผ์š” ๋ชจ๋ธ๋“ค์„ DataParallel๋กœ ๊ฐ์‹ธ์„œ ๋ชจ๋“  GPU์— ๋ถ„์‚ฐ
452
- pipeline.VGGT_model = torch.nn.DataParallel(pipeline.VGGT_model)
453
- pipeline.birefnet_model = torch.nn.DataParallel(pipeline.birefnet_model)
454
- # pipeline.dreamsim_model์ด ์žˆ๋Š” ๊ฒฝ์šฐ ์•„๋ž˜ ์ฃผ์„ ํ•ด์ œ
455
- # if hasattr(pipeline, 'dreamsim_model'):
456
- # pipeline.dreamsim_model = torch.nn.DataParallel(pipeline.dreamsim_model)
457
- else:
458
- print("๋‹จ์ผ GPU ๋ชจ๋“œ๋กœ ์‹คํ–‰ํ•ฉ๋‹ˆ๋‹ค.")
459
- pipeline.cuda()
460
- pipeline.VGGT_model.cuda()
461
- pipeline.birefnet_model.cuda()
462
-
 
 
 
 
463
  # 3. ์•ฑ ์‹คํ–‰
464
  demo.launch()
 
1
  import gradio as gr
2
  import spaces
3
  from gradio_litmodel3d import LitModel3D
 
4
  import os
5
  import shutil
6
+
7
  os.environ['SPCONV_ALGO'] = 'native'
8
  from typing import *
9
  import torch
 
15
  from trellis.representations import Gaussian, MeshExtractResult
16
  from trellis.utils import render_utils, postprocessing_utils
17
 
 
 
18
  MAX_SEED = np.iinfo(np.int32).max
19
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
 
 
20
  os.makedirs(TMP_DIR, exist_ok=True)
21
 
22
  def start_session(req: gr.Request):
23
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
24
  os.makedirs(user_dir, exist_ok=True)
25
+
 
26
def end_session(req: gr.Request):
    """Delete the per-session scratch directory when a Gradio session ends.

    Uses ignore_errors=True instead of an explicit os.path.exists() check:
    the check/remove pair is racy (another cleanup could delete the
    directory between the two calls), and this also covers sessions that
    never created a directory, preventing FileNotFoundError either way.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    shutil.rmtree(user_dir, ignore_errors=True)
31
 
32
  @spaces.GPU
33
  def preprocess_image(image: Image.Image) -> Image.Image:
 
433
  )
434
 
435
 
 
436
if __name__ == "__main__":
    # 1. Load the pretrained image-to-3D pipeline.
    pipeline = TrellisVGGTTo3DPipeline.from_pretrained("Stable-X/trellis-vggt-v0-2")

    # 2. GPU setup: move everything to GPU 0 first, then, when several GPUs
    #    are visible, wrap the heavy sub-models in DataParallel so their
    #    forward passes are split across devices.
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        # All sub-models must hold CUDA parameters before DataParallel wraps them.
        pipeline.to(device)

        num_gpus = torch.cuda.device_count()
        if num_gpus > 1:
            print(f"--- ๋ฉ€ํ‹ฐ GPU ํ™œ์„ฑํ™”: {num_gpus}๊ฐœ์˜ GPU๋ฅผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค ---")
            # Snapshot the unwrapped sub-models so a partial failure can be
            # rolled back. The previous code left already-wrapped models in
            # DataParallel after an exception, producing a mixed (broken)
            # state that pipeline.to(device) alone cannot undo.
            candidates = ('VGGT_model', 'birefnet_model',
                          'sparse_structure_decoder', 'slat_decoder')
            originals = {name: getattr(pipeline, name)
                         for name in candidates if hasattr(pipeline, name)}
            try:
                for name, model in originals.items():
                    setattr(pipeline, name, torch.nn.DataParallel(model).cuda())
            except Exception as e:
                print(f"๋ฉ€ํ‹ฐ GPU ์„ค์ • ์ค‘ ๊ฒฝ๊ณ  ๋ฐœ์ƒ(๋‹จ์ผ GPU๋กœ ์ „ํ™˜): {e}")
                # Restore every original sub-model and fall back to single-GPU mode.
                for name, model in originals.items():
                    setattr(pipeline, name, model)
                pipeline.to(device)

    # 3. Launch the Gradio app.
    demo.launch()