thisaraniana commited on
Commit
19b13e0
·
1 Parent(s): 433d974

Fix: implement lazy loading for CUDA pipelines to support Hugging Face ZeroGPU

Browse files
Files changed (1) hide show
  1. common.py +67 -54
common.py CHANGED
@@ -88,65 +88,75 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
88
  os.environ.setdefault("OPENAI_API_KEY", "sk-placeholder")
89
  MAX_SEED = 100000
90
 
91
- # DELIGHT = DelightingModel()
92
- # IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
93
- # IMAGESR_MODEL = ImageStableSR()
94
- if os.getenv("GRADIO_APP").startswith("imageto3d"):
95
- RBG_REMOVER = RembgRemover()
96
- RBG14_REMOVER = BMGG14Remover()
97
- SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
98
- if "sam3d" in os.getenv("GRADIO_APP"):
99
- PIPELINE = Sam3dInference(device="cuda")
100
- else:
101
- PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
102
- "microsoft/TRELLIS-image-large"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  )
104
- # PIPELINE.cuda()
105
- SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
106
- GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
107
- AESTHETIC_CHECKER = ImageAestheticChecker()
108
- CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
109
- TMP_DIR = os.path.join(
110
- os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
111
- )
112
- os.makedirs(TMP_DIR, exist_ok=True)
113
- elif os.getenv("GRADIO_APP").startswith("textto3d"):
114
  RBG_REMOVER = RembgRemover()
115
  RBG14_REMOVER = BMGG14Remover()
116
- if "sam3d" in os.getenv("GRADIO_APP"):
117
- PIPELINE = Sam3dInference(device="cuda")
118
- else:
119
- PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
120
- "microsoft/TRELLIS-image-large"
121
- )
122
- # PIPELINE.cuda()
123
- text_model_dir = "weights/Kolors"
124
- PIPELINE_IMG_IP = build_text2img_ip_pipeline(text_model_dir, ref_scale=0.3)
125
- PIPELINE_IMG = build_text2img_pipeline(text_model_dir)
126
  SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
127
  GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
128
  AESTHETIC_CHECKER = ImageAestheticChecker()
129
  CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
130
- TMP_DIR = os.path.join(
131
- os.path.dirname(os.path.abspath(__file__)), "sessions/textto3d"
132
- )
133
  os.makedirs(TMP_DIR, exist_ok=True)
 
134
  elif os.getenv("GRADIO_APP") == "texture_edit":
135
  DELIGHT = DelightingModel()
136
  IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
137
- PIPELINE_IP = build_texture_gen_pipe(
138
- base_ckpt_dir="./weights",
139
- ip_adapt_scale=0.7,
140
- device="cuda",
141
- )
142
- PIPELINE = build_texture_gen_pipe(
143
- base_ckpt_dir="./weights",
144
- ip_adapt_scale=0,
145
- device="cuda",
146
- )
147
- TMP_DIR = os.path.join(
148
- os.path.dirname(os.path.abspath(__file__)), "sessions/texture_edit"
149
- )
150
  os.makedirs(TMP_DIR, exist_ok=True)
151
 
152
 
@@ -290,17 +300,18 @@ def image_to_3d(
290
  seg_image = Image.fromarray(seg_image)
291
 
292
  logger.info("Start generating 3D representation from image...")
293
- if isinstance(PIPELINE, Sam3dInference):
294
- outputs = PIPELINE.run(
 
295
  seg_image,
296
  seed=seed,
297
  stage1_inference_steps=ss_sampling_steps,
298
  stage2_inference_steps=slat_sampling_steps,
299
  )
300
  else:
301
- PIPELINE.cuda()
302
  seg_image = trellis_preprocess(seg_image)
303
- outputs = PIPELINE.run(
304
  seg_image,
305
  seed=seed,
306
  formats=["gaussian", "mesh"],
@@ -589,6 +600,7 @@ def text2image_fn(
589
  output_root = os.path.join(output_root, str(req.session_hash))
590
  os.makedirs(output_root, exist_ok=True)
591
 
 
592
  pipeline = PIPELINE_IMG if ip_image is None else PIPELINE_IMG_IP
593
  if ip_image is not None:
594
  pipeline.set_ip_adapter_scale([ip_adapt_scale])
@@ -656,7 +668,8 @@ def generate_texture_mvimages(
656
  ) -> list[str]:
657
  output_root = os.path.join(TMP_DIR, str(req.session_hash))
658
  use_ip_adapter = True if ip_img_path and ip_adapt_scale > 0 else False
659
- PIPELINE_IP.set_ip_adapter_scale([ip_adapt_scale])
 
660
  img_save_paths = infer_pipe(
661
  index_file=f"{output_root}/condition/index.json",
662
  controlnet_cond_scale=controlnet_cond_scale,
@@ -669,7 +682,7 @@ def generate_texture_mvimages(
669
  prompt=prompt,
670
  save_dir=f"{output_root}/multi_view",
671
  sub_idxs=sub_idxs,
672
- pipeline=PIPELINE_IP if use_ip_adapter else PIPELINE,
673
  seed=seed,
674
  )
675
 
 
88
  os.environ.setdefault("OPENAI_API_KEY", "sk-placeholder")
89
  MAX_SEED = 100000
90
 
91
+ # Global pipeline placeholders
92
+ PIPELINE = None
93
+ PIPELINE_IMG_IP = None
94
+ PIPELINE_IMG = None
95
+ PIPELINE_IP = None
96
+ PIPELINE_EDIT = None
97
+ RBG_REMOVER = None
98
+ RBG14_REMOVER = None
99
+ SAM_PREDICTOR = None
100
+ SEG_CHECKER = None
101
+ GEO_CHECKER = None
102
+ AESTHETIC_CHECKER = None
103
+ CHECKERS = []
104
+
105
+ def get_pipline():
106
+ global PIPELINE, PIPELINE_IMG_IP, PIPELINE_IMG
107
+ if PIPELINE is None:
108
+ if os.getenv("GRADIO_APP").startswith("imageto3d"):
109
+ if "sam3d" in os.getenv("GRADIO_APP"):
110
+ PIPELINE = Sam3dInference(device="cuda")
111
+ else:
112
+ PIPELINE = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large")
113
+ elif os.getenv("GRADIO_APP").startswith("textto3d"):
114
+ if "sam3d" in os.getenv("GRADIO_APP"):
115
+ PIPELINE = Sam3dInference(device="cuda")
116
+ else:
117
+ PIPELINE = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large")
118
+
119
+ text_model_dir = "weights/Kolors"
120
+ if PIPELINE_IMG_IP is None:
121
+ PIPELINE_IMG_IP = build_text2img_ip_pipeline(text_model_dir, ref_scale=0.3)
122
+ if PIPELINE_IMG is None:
123
+ PIPELINE_IMG = build_text2img_pipeline(text_model_dir)
124
+ return PIPELINE
125
+
126
+ def get_texture_edit_pipeline():
127
+ global PIPELINE_IP, PIPELINE_EDIT
128
+ if PIPELINE_IP is None:
129
+ PIPELINE_IP = build_texture_gen_pipe(
130
+ base_ckpt_dir="./weights",
131
+ ip_adapt_scale=0.7,
132
+ device="cuda",
133
  )
134
+ if PIPELINE_EDIT is None:
135
+ PIPELINE_EDIT = build_texture_gen_pipe(
136
+ base_ckpt_dir="./weights",
137
+ ip_adapt_scale=0,
138
+ device="cuda",
139
+ )
140
+ return PIPELINE_IP, PIPELINE_EDIT
141
+
142
+ # Initialize non-CUDA models
143
+ if os.getenv("GRADIO_APP").startswith("imageto3d") or os.getenv("GRADIO_APP").startswith("textto3d"):
144
  RBG_REMOVER = RembgRemover()
145
  RBG14_REMOVER = BMGG14Remover()
146
+ SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
 
 
 
 
 
 
 
 
 
147
  SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
148
  GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
149
  AESTHETIC_CHECKER = ImageAestheticChecker()
150
  CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
151
+
152
+ app_type = "imageto3d" if os.getenv("GRADIO_APP").startswith("imageto3d") else "textto3d"
153
+ TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), f"sessions/{app_type}")
154
  os.makedirs(TMP_DIR, exist_ok=True)
155
+
156
  elif os.getenv("GRADIO_APP") == "texture_edit":
157
  DELIGHT = DelightingModel()
158
  IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
159
+ TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "sessions/texture_edit")
 
 
 
 
 
 
 
 
 
 
 
 
160
  os.makedirs(TMP_DIR, exist_ok=True)
161
 
162
 
 
300
  seg_image = Image.fromarray(seg_image)
301
 
302
  logger.info("Start generating 3D representation from image...")
303
+ pipeline = get_pipline()
304
+ if isinstance(pipeline, Sam3dInference):
305
+ outputs = pipeline.run(
306
  seg_image,
307
  seed=seed,
308
  stage1_inference_steps=ss_sampling_steps,
309
  stage2_inference_steps=slat_sampling_steps,
310
  )
311
  else:
312
+ pipeline.cuda()
313
  seg_image = trellis_preprocess(seg_image)
314
+ outputs = pipeline.run(
315
  seg_image,
316
  seed=seed,
317
  formats=["gaussian", "mesh"],
 
600
  output_root = os.path.join(output_root, str(req.session_hash))
601
  os.makedirs(output_root, exist_ok=True)
602
 
603
+ _ = get_pipline()
604
  pipeline = PIPELINE_IMG if ip_image is None else PIPELINE_IMG_IP
605
  if ip_image is not None:
606
  pipeline.set_ip_adapter_scale([ip_adapt_scale])
 
668
  ) -> list[str]:
669
  output_root = os.path.join(TMP_DIR, str(req.session_hash))
670
  use_ip_adapter = True if ip_img_path and ip_adapt_scale > 0 else False
671
+ pipeline_ip, pipeline_edit = get_texture_edit_pipeline()
672
+ pipeline_ip.set_ip_adapter_scale([ip_adapt_scale])
673
  img_save_paths = infer_pipe(
674
  index_file=f"{output_root}/condition/index.json",
675
  controlnet_cond_scale=controlnet_cond_scale,
 
682
  prompt=prompt,
683
  save_dir=f"{output_root}/multi_view",
684
  sub_idxs=sub_idxs,
685
+ pipeline=pipeline_ip if use_ip_adapter else pipeline_edit,
686
  seed=seed,
687
  )
688