JasonYinnnn commited on
Commit
e95f3aa
·
1 Parent(s): 454fd88

lazy init pipeline

Browse files
Files changed (1) hide show
  1. app.py +13 -5
app.py CHANGED
@@ -7,7 +7,6 @@ import gradio as gr
7
  import spaces
8
 
9
  import os
10
- os.environ['SPCONV_ALGO'] = 'native'
11
  import uuid
12
  from typing import Any, List, Optional, Union
13
  import cv2
@@ -18,6 +17,7 @@ import trimesh
18
  import random
19
  import imageio
20
  from einops import repeat
 
21
  from threeDFixer.moge.model.v2 import MoGeModel
22
  from threeDFixer.pipelines import ThreeDFixerPipeline
23
  from threeDFixer.datasets.utils import (
@@ -362,6 +362,12 @@ def run_generation(
362
  generated_object_map = {}
363
  run_id = str(uuid.uuid4())
364
 
 
 
 
 
 
 
365
  if not isinstance(rgb_image, Image.Image) and "image" in rgb_image:
366
  rgb_image = rgb_image["image"]
367
 
@@ -858,10 +864,12 @@ if __name__ == '__main__':
858
 
859
  ############## 3D-Fixer model
860
  model_dir = 'HorizonRobotics/3D-Fixer'
861
- pipeline = ThreeDFixerPipeline.from_pretrained(
862
- model_dir, compile=False
863
- )
864
- pipeline.to(device=DEVICE)
 
 
865
  ############## 3D-Fixer model
866
 
867
  rot = np.array([
 
7
  import spaces
8
 
9
  import os
 
10
  import uuid
11
  from typing import Any, List, Optional, Union
12
  import cv2
 
17
  import random
18
  import imageio
19
  from einops import repeat
20
+ from huggingface_hub import snapshot_download
21
  from threeDFixer.moge.model.v2 import MoGeModel
22
  from threeDFixer.pipelines import ThreeDFixerPipeline
23
  from threeDFixer.datasets.utils import (
 
362
  generated_object_map = {}
363
  run_id = str(uuid.uuid4())
364
 
365
+ DEVICE = 'cuda'
366
+ pipeline = ThreeDFixerPipeline.from_pretrained(
367
+ local_dir, compile=False
368
+ )
369
+ pipeline.to(device=DEVICE)
370
+
371
  if not isinstance(rgb_image, Image.Image) and "image" in rgb_image:
372
  rgb_image = rgb_image["image"]
373
 
 
864
 
865
  ############## 3D-Fixer model
866
  model_dir = 'HorizonRobotics/3D-Fixer'
867
+ local_dir = 'checkpoints/3D-Fixer'
868
+ snapshot_download(repo_id=model_dir, local_dir=local_dir)
869
+ # pipeline = ThreeDFixerPipeline.from_pretrained(
870
+ # model_dir, compile=False
871
+ # )
872
+ # pipeline.to(device=DEVICE)
873
  ############## 3D-Fixer model
874
 
875
  rot = np.array([