xinjie.wang commited on
Commit
6cb207b
·
1 Parent(s): b635124
Files changed (1) hide show
  1. common.py +18 -5
common.py CHANGED
@@ -88,12 +88,14 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
88
  os.environ.setdefault("OPENAI_API_KEY", "sk-placeholder")
89
  MAX_SEED = 100000
90
 
 
 
 
 
91
  # DELIGHT = DelightingModel()
92
  # IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
93
  # IMAGESR_MODEL = ImageStableSR()
94
  if os.getenv("GRADIO_APP").startswith("imageto3d"):
95
- RBG_REMOVER = RembgRemover()
96
- RBG14_REMOVER = BMGG14Remover()
97
  SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
98
  if "sam3d" in os.getenv("GRADIO_APP"):
99
  PIPELINE = Sam3dInference(device="cuda")
@@ -111,8 +113,6 @@ if os.getenv("GRADIO_APP").startswith("imageto3d"):
111
  )
112
  os.makedirs(TMP_DIR, exist_ok=True)
113
  elif os.getenv("GRADIO_APP").startswith("textto3d"):
114
- RBG_REMOVER = RembgRemover()
115
- RBG14_REMOVER = BMGG14Remover()
116
  if "sam3d" in os.getenv("GRADIO_APP"):
117
  PIPELINE = Sam3dInference(device="cuda")
118
  else:
@@ -161,11 +161,15 @@ def end_session(req: gr.Request) -> None:
161
  shutil.rmtree(user_dir)
162
 
163
 
 
164
  def preprocess_image_fn(
165
  image: str | np.ndarray | Image.Image,
166
  rmbg_tag: str = "rembg",
167
  preprocess: bool = True,
168
  ) -> tuple[Image.Image, Image.Image]:
 
 
 
169
  if isinstance(image, str):
170
  image = Image.open(image)
171
  elif isinstance(image, np.ndarray):
@@ -173,7 +177,16 @@ def preprocess_image_fn(
173
 
174
  image_cache = image.copy() # resize_pil(image.copy(), 1024)
175
 
176
- bg_remover = RBG_REMOVER if rmbg_tag == "rembg" else RBG14_REMOVER
 
 
 
 
 
 
 
 
 
177
  image = bg_remover(image)
178
  image = keep_largest_connected_component(image)
179
 
 
88
  os.environ.setdefault("OPENAI_API_KEY", "sk-placeholder")
89
  MAX_SEED = 100000
90
 
91
+ # Global variables for lazy initialization
92
+ _RBG_REMOVER = None
93
+ _RBG14_REMOVER = None
94
+
95
  # DELIGHT = DelightingModel()
96
  # IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
97
  # IMAGESR_MODEL = ImageStableSR()
98
  if os.getenv("GRADIO_APP").startswith("imageto3d"):
 
 
99
  SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
100
  if "sam3d" in os.getenv("GRADIO_APP"):
101
  PIPELINE = Sam3dInference(device="cuda")
 
113
  )
114
  os.makedirs(TMP_DIR, exist_ok=True)
115
  elif os.getenv("GRADIO_APP").startswith("textto3d"):
 
 
116
  if "sam3d" in os.getenv("GRADIO_APP"):
117
  PIPELINE = Sam3dInference(device="cuda")
118
  else:
 
161
  shutil.rmtree(user_dir)
162
 
163
 
164
+ @spaces.GPU
165
  def preprocess_image_fn(
166
  image: str | np.ndarray | Image.Image,
167
  rmbg_tag: str = "rembg",
168
  preprocess: bool = True,
169
  ) -> tuple[Image.Image, Image.Image]:
170
+ """Preprocess image with lazy model initialization to avoid CUDA init at import time."""
171
+ global _RBG_REMOVER, _RBG14_REMOVER
172
+
173
  if isinstance(image, str):
174
  image = Image.open(image)
175
  elif isinstance(image, np.ndarray):
 
177
 
178
  image_cache = image.copy() # resize_pil(image.copy(), 1024)
179
 
180
+ # Lazy initialization - models are created on first call within @spaces.GPU context
181
+ if rmbg_tag == "rembg":
182
+ if _RBG_REMOVER is None:
183
+ _RBG_REMOVER = RembgRemover()
184
+ bg_remover = _RBG_REMOVER
185
+ else:
186
+ if _RBG14_REMOVER is None:
187
+ _RBG14_REMOVER = BMGG14Remover()
188
+ bg_remover = _RBG14_REMOVER
189
+
190
  image = bg_remover(image)
191
  image = keep_largest_connected_component(image)
192