aadarsh99 committed on
Commit
461d3a3
·
1 Parent(s): 39dd4a1

update app

Browse files
Files changed (1) hide show
  1. app.py +48 -33
app.py CHANGED
@@ -197,27 +197,26 @@ def get_text_to_image_attention(decoder: MaskDecoder):
197
  text_attn = attn[..., n_output_tokens:, :]
198
  return text_attn
199
 
200
- def load_models():
201
  print("Loading models on CPU...")
202
 
203
- # 1. Base SAM2 Model
204
- # We assume files are present locally (uploaded via CLI or LFS)
205
  if not os.path.exists(BASE_CKPT_NAME):
206
  raise FileNotFoundError(f"{BASE_CKPT_NAME} not found")
207
 
208
  model = build_sam2(SAM2_CONFIG, BASE_CKPT_NAME, device="cpu")
209
- predictor = SAM2ImagePredictor(model)
210
- predictor.model.eval()
211
-
212
  # 2. Fine-tuned Weights
213
  if not os.path.exists(FINAL_CKPT_NAME):
214
  raise FileNotFoundError(f"{FINAL_CKPT_NAME} not found")
215
 
216
  sd = torch.load(FINAL_CKPT_NAME, map_location="cpu")
217
- predictor.model.load_state_dict(sd.get("model", sd), strict=True)
 
 
218
 
219
  # 3. PLM Adapter
220
- C = predictor.model.sam_mask_decoder.transformer_dim
221
  plm = PLMLanguageAdapter(
222
  model_name="Qwen/Qwen2.5-VL-3B-Instruct",
223
  transformer_dim=C,
@@ -242,59 +241,69 @@ def load_models():
242
  plm.load_lora(LORA_CKPT_NAME)
243
 
244
  print("Models loaded successfully (CPU).")
245
- return predictor, plm
246
 
247
  # Initialize global models on CPU
248
  try:
249
- PREDICTOR, PLM = load_models()
 
250
  except Exception as e:
251
  print(f"Error loading models: {e}")
252
- PREDICTOR, PLM = None, None
 
253
 
254
- @spaces.GPU # <--- REQUIRED FOR ZEROGPU
255
  def run_prediction(image_pil, text_prompt):
256
- if PREDICTOR is None or PLM is None:
257
  return None, None, None
258
 
259
  if image_pil is None or not text_prompt:
260
  return None, None, None
261
 
 
 
262
  try:
263
- # 1. Move models to GPU for this inference session
264
  print("Moving models to CUDA...")
265
- PREDICTOR.model.to("cuda")
266
  PLM.to("cuda")
267
 
268
- # 2. Preprocess
 
 
 
 
269
  rgb_orig = np.array(image_pil.convert("RGB"))
270
  Hgt, Wgt = rgb_orig.shape[:2]
271
  meta = _resize_pad_square_meta(Hgt, Wgt, SQUARE_DIM)
272
  rgb_sq = _resize_pad_square(rgb_orig, SQUARE_DIM, is_mask=False)
273
 
274
- # 3. SAM2 Image Encoding
275
- # set_image puts features on the model's device (now cuda)
276
- PREDICTOR.set_image(rgb_sq)
277
- image_emb = PREDICTOR._features["image_embed"][-1].unsqueeze(0)
278
- hi = [lvl[-1].unsqueeze(0) for lvl in PREDICTOR._features["high_res_feats"]]
279
  _, _, H_feat, W_feat = image_emb.shape
280
 
281
- # 4. PLM Inference
282
  temp_path = "temp_input.jpg"
283
  image_pil.save(temp_path)
284
 
 
 
285
  sp, dp = PLM([text_prompt], H_feat, W_feat, [temp_path])
286
 
287
- # 5. Prepare SAM2 Decoder inputs (ensure they are on CUDA)
288
- dec = PREDICTOR.model.sam_mask_decoder
289
- dev = next(dec.parameters()).device # should be cuda now
290
  dtype = next(dec.parameters()).dtype
291
 
292
- image_pe = PREDICTOR.model.sam_prompt_encoder.get_dense_pe().to(dev, dtype)
293
  image_emb = image_emb.to(dev, dtype)
294
  hi = [h.to(dev, dtype) for h in hi]
295
  sp, dp = sp.to(dev, dtype), dp.to(dev, dtype)
296
 
297
- # 6. SAM2 Decoding
298
  low, scores, _, _ = dec(
299
  image_embeddings=image_emb,
300
  image_pe=image_pe,
@@ -305,7 +314,7 @@ def run_prediction(image_pil, text_prompt):
305
  high_res_features=hi,
306
  )
307
 
308
- logits_sq = PREDICTOR._transforms.postprocess_masks(low, (SQUARE_DIM, SQUARE_DIM))
309
  best = scores.argmax(dim=1).item()
310
  logit_sq = logits_sq[0, best]
311
  logit_gt = _unpad_and_resize_pred_to_gt(logit_sq, meta, (Hgt, Wgt))
@@ -313,7 +322,7 @@ def run_prediction(image_pil, text_prompt):
313
  prob = torch.sigmoid(logit_gt)
314
  mask = (prob > 0.5).cpu().numpy().astype(np.uint8) * 255
315
 
316
- # 7. Visualization
317
  overlay_img = make_overlay(rgb_orig, mask, key=text_prompt)
318
 
319
  # Attention
@@ -321,7 +330,6 @@ def run_prediction(image_pil, text_prompt):
321
  attn_overlay_img = None
322
 
323
  if text_attn is not None:
324
- # Move attn back to CPU for numpy processing
325
  text_attn = text_attn.cpu()
326
  attn_flat = text_attn.mean(dim=(0, 2, 3))
327
  global_flat = attn_flat[0]
@@ -341,13 +349,20 @@ def run_prediction(image_pil, text_prompt):
341
  mask_img = Image.fromarray(mask, mode="L")
342
 
343
  return overlay_img, mask_img, attn_overlay_img
 
 
 
 
 
344
 
345
  finally:
346
- # Cleanup: Move models back to CPU to free GPU memory for other users/sessions
347
- # This is courteous in ZeroGPU environment
348
  print("Moving models back to CPU...")
349
- PREDICTOR.model.to("cpu")
350
  PLM.to("cpu")
 
 
 
351
 
352
  # ----------------- Gradio UI -----------------
353
 
 
197
  text_attn = attn[..., n_output_tokens:, :]
198
  return text_attn
199
 
200
+ def load_models_cpu():
201
  print("Loading models on CPU...")
202
 
203
+ # 1. Base SAM2 Model (Raw Model, not Predictor)
 
204
  if not os.path.exists(BASE_CKPT_NAME):
205
  raise FileNotFoundError(f"{BASE_CKPT_NAME} not found")
206
 
207
  model = build_sam2(SAM2_CONFIG, BASE_CKPT_NAME, device="cpu")
208
+
 
 
209
  # 2. Fine-tuned Weights
210
  if not os.path.exists(FINAL_CKPT_NAME):
211
  raise FileNotFoundError(f"{FINAL_CKPT_NAME} not found")
212
 
213
  sd = torch.load(FINAL_CKPT_NAME, map_location="cpu")
214
+ # Load into the model directly
215
+ model.load_state_dict(sd.get("model", sd), strict=True)
216
+ model.eval()
217
 
218
  # 3. PLM Adapter
219
+ C = model.sam_mask_decoder.transformer_dim
220
  plm = PLMLanguageAdapter(
221
  model_name="Qwen/Qwen2.5-VL-3B-Instruct",
222
  transformer_dim=C,
 
241
  plm.load_lora(LORA_CKPT_NAME)
242
 
243
  print("Models loaded successfully (CPU).")
244
+ return model, plm
245
 
246
  # Initialize global models on CPU
247
  try:
248
+ # NOTE: We hold the raw MODEL_SAM here, not the predictor
249
+ MODEL_SAM, PLM = load_models_cpu()
250
  except Exception as e:
251
  print(f"Error loading models: {e}")
252
+ traceback.print_exc()
253
+ MODEL_SAM, PLM = None, None
254
 
255
+ @spaces.GPU(duration=60) # Ensure we have enough time (default is often 60s)
256
  def run_prediction(image_pil, text_prompt):
257
+ if MODEL_SAM is None or PLM is None:
258
  return None, None, None
259
 
260
  if image_pil is None or not text_prompt:
261
  return None, None, None
262
 
263
+ predictor = None
264
+
265
  try:
266
+ # 1. Move models to GPU
267
  print("Moving models to CUDA...")
268
+ MODEL_SAM.to("cuda")
269
  PLM.to("cuda")
270
 
271
+ # 2. Instantiate Predictor ON GPU (Crucial Fix)
272
+ # This ensures the predictor knows it's on CUDA
273
+ predictor = SAM2ImagePredictor(MODEL_SAM)
274
+
275
+ # 3. Preprocess Image
276
  rgb_orig = np.array(image_pil.convert("RGB"))
277
  Hgt, Wgt = rgb_orig.shape[:2]
278
  meta = _resize_pad_square_meta(Hgt, Wgt, SQUARE_DIM)
279
  rgb_sq = _resize_pad_square(rgb_orig, SQUARE_DIM, is_mask=False)
280
 
281
+ # 4. SAM2 Image Encoding
282
+ # set_image puts features on the model's device
283
+ predictor.set_image(rgb_sq)
284
+ image_emb = predictor._features["image_embed"][-1].unsqueeze(0)
285
+ hi = [lvl[-1].unsqueeze(0) for lvl in predictor._features["high_res_feats"]]
286
  _, _, H_feat, W_feat = image_emb.shape
287
 
288
+ # 5. PLM Inference
289
  temp_path = "temp_input.jpg"
290
  image_pil.save(temp_path)
291
 
292
+ # PLM inference usually handles device mapping internally if written well,
293
+ # but we ensure inputs are passed cleanly.
294
  sp, dp = PLM([text_prompt], H_feat, W_feat, [temp_path])
295
 
296
+ # 6. Prepare SAM2 Decoder inputs (ensure they are on CUDA)
297
+ dec = predictor.model.sam_mask_decoder
298
+ dev = next(dec.parameters()).device
299
  dtype = next(dec.parameters()).dtype
300
 
301
+ image_pe = predictor.model.sam_prompt_encoder.get_dense_pe().to(dev, dtype)
302
  image_emb = image_emb.to(dev, dtype)
303
  hi = [h.to(dev, dtype) for h in hi]
304
  sp, dp = sp.to(dev, dtype), dp.to(dev, dtype)
305
 
306
+ # 7. SAM2 Decoding
307
  low, scores, _, _ = dec(
308
  image_embeddings=image_emb,
309
  image_pe=image_pe,
 
314
  high_res_features=hi,
315
  )
316
 
317
+ logits_sq = predictor._transforms.postprocess_masks(low, (SQUARE_DIM, SQUARE_DIM))
318
  best = scores.argmax(dim=1).item()
319
  logit_sq = logits_sq[0, best]
320
  logit_gt = _unpad_and_resize_pred_to_gt(logit_sq, meta, (Hgt, Wgt))
 
322
  prob = torch.sigmoid(logit_gt)
323
  mask = (prob > 0.5).cpu().numpy().astype(np.uint8) * 255
324
 
325
+ # 8. Visualization
326
  overlay_img = make_overlay(rgb_orig, mask, key=text_prompt)
327
 
328
  # Attention
 
330
  attn_overlay_img = None
331
 
332
  if text_attn is not None:
 
333
  text_attn = text_attn.cpu()
334
  attn_flat = text_attn.mean(dim=(0, 2, 3))
335
  global_flat = attn_flat[0]
 
349
  mask_img = Image.fromarray(mask, mode="L")
350
 
351
  return overlay_img, mask_img, attn_overlay_img
352
+
353
+ except Exception as e:
354
+ print("An error occurred during inference:")
355
+ traceback.print_exc()
356
+ raise e # Let Gradio show the error
357
 
358
  finally:
359
+ # Cleanup: Move models back to CPU
 
360
  print("Moving models back to CPU...")
361
+ MODEL_SAM.to("cpu")
362
  PLM.to("cpu")
363
+ if predictor:
364
+ del predictor
365
+ torch.cuda.empty_cache()
366
 
367
  # ----------------- Gradio UI -----------------
368