aadarsh99 committed on
Commit
b6001f7
·
1 Parent(s): eb03911

update app

Browse files
Files changed (1) hide show
  1. app.py +52 -44
app.py CHANGED
@@ -36,8 +36,11 @@ SQUARE_DIM = 1024
36
 
37
  logging.basicConfig(level=logging.INFO)
38
 
39
- # ----------------- Overlay Style Helpers -----------------
 
 
40
 
 
41
  EDGE_COLORS_HEX = ["#3A86FF", "#FF006E", "#43AA8B", "#F3722C", "#8338EC", "#90BE6D"]
42
 
43
  def _hex_to_rgb(h: str):
@@ -176,8 +179,6 @@ def _unpad_and_resize_pred_to_gt(logit_sq: torch.Tensor, meta: dict, out_hw: tup
176
  up = F.interpolate(crop, size=out_hw, mode="bilinear", align_corners=False)
177
  return up[0, 0]
178
 
179
- # ----------------- Model Logic -----------------
180
-
181
  def get_text_to_image_attention(decoder: MaskDecoder):
182
  two_way = decoder.transformer
183
  attn_blocks = []
@@ -198,13 +199,29 @@ def get_text_to_image_attention(decoder: MaskDecoder):
198
  text_attn = attn[..., n_output_tokens:, :]
199
  return text_attn
200
 
201
- def load_models_cpu():
202
- print("Loading models on CPU...")
 
 
 
 
 
 
 
 
 
 
 
203
 
204
- # 1. Base SAM2 Model (Raw Model, not Predictor)
205
  if not os.path.exists(BASE_CKPT_NAME):
206
  raise FileNotFoundError(f"{BASE_CKPT_NAME} not found")
207
 
 
 
 
 
 
208
  model = build_sam2(SAM2_CONFIG, BASE_CKPT_NAME, device="cpu")
209
 
210
  # 2. Fine-tuned Weights
@@ -212,9 +229,11 @@ def load_models_cpu():
212
  raise FileNotFoundError(f"{FINAL_CKPT_NAME} not found")
213
 
214
  sd = torch.load(FINAL_CKPT_NAME, map_location="cpu")
215
- # Load into the model directly
216
  model.load_state_dict(sd.get("model", sd), strict=True)
217
- model.eval()
 
 
 
218
 
219
  # 3. PLM Adapter
220
  C = model.sam_mask_decoder.transformer_dim
@@ -228,10 +247,9 @@ def load_models_cpu():
228
  lora_alpha=32,
229
  lora_dropout=0.05,
230
  dtype=torch.bfloat16,
231
- device="cpu",
232
- ).to("cpu")
233
- plm.eval()
234
-
235
  if not os.path.exists(PLM_CKPT_NAME):
236
  raise FileNotFoundError(f"{PLM_CKPT_NAME} not found")
237
 
@@ -240,38 +258,34 @@ def load_models_cpu():
240
 
241
  if LORA_CKPT_NAME and os.path.exists(LORA_CKPT_NAME):
242
  plm.load_lora(LORA_CKPT_NAME)
 
 
 
 
 
243
 
244
- print("Models loaded successfully (CPU).")
245
- return model, plm
246
 
247
- # Initialize global models on CPU
248
- try:
249
- # NOTE: We hold the raw MODEL_SAM here, not the predictor
250
- MODEL_SAM, PLM = load_models_cpu()
251
- except Exception as e:
252
- print(f"Error loading models: {e}")
253
- traceback.print_exc()
254
- MODEL_SAM, PLM = None, None
255
 
256
- @spaces.GPU(duration=60) # Ensure we have enough time (default is often 60s)
257
  def run_prediction(image_pil, text_prompt):
258
- if MODEL_SAM is None or PLM is None:
259
- return None, None, None
260
-
261
  if image_pil is None or not text_prompt:
262
  return None, None, None
263
 
264
  predictor = None
265
-
266
  try:
267
- # 1. Move models to GPU
268
- print("Moving models to CUDA...")
269
- MODEL_SAM.to("cuda")
270
- PLM.to("cuda")
 
 
 
 
271
 
272
- # 2. Instantiate Predictor ON GPU (Crucial Fix)
273
- # This ensures the predictor knows it's on CUDA
274
- predictor = SAM2ImagePredictor(MODEL_SAM)
275
 
276
  # 3. Preprocess Image
277
  rgb_orig = np.array(image_pil.convert("RGB"))
@@ -280,7 +294,6 @@ def run_prediction(image_pil, text_prompt):
280
  rgb_sq = _resize_pad_square(rgb_orig, SQUARE_DIM, is_mask=False)
281
 
282
  # 4. SAM2 Image Encoding
283
- # set_image puts features on the model's device
284
  predictor.set_image(rgb_sq)
285
  image_emb = predictor._features["image_embed"][-1].unsqueeze(0)
286
  hi = [lvl[-1].unsqueeze(0) for lvl in predictor._features["high_res_feats"]]
@@ -290,11 +303,9 @@ def run_prediction(image_pil, text_prompt):
290
  temp_path = "temp_input.jpg"
291
  image_pil.save(temp_path)
292
 
293
- # PLM inference usually handles device mapping internally if written well,
294
- # but we ensure inputs are passed cleanly.
295
- sp, dp = PLM([text_prompt], H_feat, W_feat, [temp_path])
296
 
297
- # 6. Prepare SAM2 Decoder inputs (ensure they are on CUDA)
298
  dec = predictor.model.sam_mask_decoder
299
  dev = next(dec.parameters()).device
300
  dtype = next(dec.parameters()).dtype
@@ -354,13 +365,10 @@ def run_prediction(image_pil, text_prompt):
354
  except Exception as e:
355
  print("An error occurred during inference:")
356
  traceback.print_exc()
357
- raise e # Let Gradio show the error
358
 
359
  finally:
360
- # Cleanup: Move models back to CPU
361
- print("Moving models back to CPU...")
362
- MODEL_SAM.to("cpu")
363
- PLM.to("cpu")
364
  if predictor:
365
  del predictor
366
  torch.cuda.empty_cache()
 
36
 
37
  logging.basicConfig(level=logging.INFO)
38
 
39
+ # ----------------- Globals (Lazy Loading) -----------------
40
+ MODEL_SAM = None
41
+ PLM = None
42
 
43
+ # ----------------- Overlay Style Helpers -----------------
44
  EDGE_COLORS_HEX = ["#3A86FF", "#FF006E", "#43AA8B", "#F3722C", "#8338EC", "#90BE6D"]
45
 
46
  def _hex_to_rgb(h: str):
 
179
  up = F.interpolate(crop, size=out_hw, mode="bilinear", align_corners=False)
180
  return up[0, 0]
181
 
 
 
182
  def get_text_to_image_attention(decoder: MaskDecoder):
183
  two_way = decoder.transformer
184
  attn_blocks = []
 
199
  text_attn = attn[..., n_output_tokens:, :]
200
  return text_attn
201
 
202
+ # ----------------- Model Loading -----------------
203
+
204
+ def load_models_lazy():
205
+ """
206
+ Loads the models. This must be called INSIDE the @spaces.GPU context
207
+ so that devices match (everything on 'cuda' or 'zero').
208
+ """
209
+ global MODEL_SAM, PLM
210
+
211
+ if MODEL_SAM is not None and PLM is not None:
212
+ return MODEL_SAM, PLM
213
+
214
+ print("Lazy loading models inside GPU context...")
215
 
216
+ # 1. Base SAM2 Model
217
  if not os.path.exists(BASE_CKPT_NAME):
218
  raise FileNotFoundError(f"{BASE_CKPT_NAME} not found")
219
 
220
+ # On ZeroGPU, we can load to 'cuda' directly, or 'cpu' then move.
221
+ # To be safe against the deepcopy error, we load to cpu then move.
222
+ # If the deepcopy error persists, we might need to load directly to 'cuda'.
223
+ # Let's try CPU load -> move to cuda.
224
+
225
  model = build_sam2(SAM2_CONFIG, BASE_CKPT_NAME, device="cpu")
226
 
227
  # 2. Fine-tuned Weights
 
229
  raise FileNotFoundError(f"{FINAL_CKPT_NAME} not found")
230
 
231
  sd = torch.load(FINAL_CKPT_NAME, map_location="cpu")
 
232
  model.load_state_dict(sd.get("model", sd), strict=True)
233
+
234
+ # Move SAM to CUDA now
235
+ model.to("cuda")
236
+ MODEL_SAM = model
237
 
238
  # 3. PLM Adapter
239
  C = model.sam_mask_decoder.transformer_dim
 
247
  lora_alpha=32,
248
  lora_dropout=0.05,
249
  dtype=torch.bfloat16,
250
+ device="cpu", # Init on CPU
251
+ )
252
+
 
253
  if not os.path.exists(PLM_CKPT_NAME):
254
  raise FileNotFoundError(f"{PLM_CKPT_NAME} not found")
255
 
 
258
 
259
  if LORA_CKPT_NAME and os.path.exists(LORA_CKPT_NAME):
260
  plm.load_lora(LORA_CKPT_NAME)
261
+
262
+ # Move PLM to CUDA
263
+ plm.to("cuda")
264
+ plm.eval()
265
+ PLM = plm
266
 
267
+ print("Models loaded successfully.")
268
+ return MODEL_SAM, PLM
269
 
 
 
 
 
 
 
 
 
270
 
271
+ @spaces.GPU(duration=120) # Increased duration for first-time load
272
  def run_prediction(image_pil, text_prompt):
 
 
 
273
  if image_pil is None or not text_prompt:
274
  return None, None, None
275
 
276
  predictor = None
277
+
278
  try:
279
+ # 1. Ensure models are loaded (Lazy Load)
280
+ model_sam, plm = load_models_lazy()
281
+
282
+ # 2. Instantiate Predictor
283
+ # We assume models are already on CUDA from load_models_lazy
284
+ # Just to be sure, we can call .to("cuda") again (cheap if already there)
285
+ model_sam.to("cuda")
286
+ plm.to("cuda")
287
 
288
+ predictor = SAM2ImagePredictor(model_sam)
 
 
289
 
290
  # 3. Preprocess Image
291
  rgb_orig = np.array(image_pil.convert("RGB"))
 
294
  rgb_sq = _resize_pad_square(rgb_orig, SQUARE_DIM, is_mask=False)
295
 
296
  # 4. SAM2 Image Encoding
 
297
  predictor.set_image(rgb_sq)
298
  image_emb = predictor._features["image_embed"][-1].unsqueeze(0)
299
  hi = [lvl[-1].unsqueeze(0) for lvl in predictor._features["high_res_feats"]]
 
303
  temp_path = "temp_input.jpg"
304
  image_pil.save(temp_path)
305
 
306
+ sp, dp = plm([text_prompt], H_feat, W_feat, [temp_path])
 
 
307
 
308
+ # 6. Prepare SAM2 Decoder inputs
309
  dec = predictor.model.sam_mask_decoder
310
  dev = next(dec.parameters()).device
311
  dtype = next(dec.parameters()).dtype
 
365
  except Exception as e:
366
  print("An error occurred during inference:")
367
  traceback.print_exc()
368
+ raise e
369
 
370
  finally:
371
+ # Cleanup
 
 
 
372
  if predictor:
373
  del predictor
374
  torch.cuda.empty_cache()