aadarsh99 commited on
Commit
aeaa431
·
1 Parent(s): edf7653

update app

Browse files
Files changed (1) hide show
  1. app.py +107 -107
app.py CHANGED
@@ -45,7 +45,6 @@ def _hex_to_rgb(h: str):
45
  EDGE_COLORS = [_hex_to_rgb(h) for h in EDGE_COLORS_HEX]
46
 
47
  def stable_color(key: str):
48
- # Use a fixed key if simple color is desired
49
  h = int(hashlib.sha256(str(key).encode("utf-8")).hexdigest(), 16)
50
  return EDGE_COLORS[h % len(EDGE_COLORS)]
51
 
@@ -197,34 +196,23 @@ def get_text_to_image_attention(decoder: MaskDecoder):
197
  text_attn = attn[..., n_output_tokens:, :]
198
  return text_attn
199
 
200
- def download_model_if_needed(filename):
201
- """Checks local disk, else downloads from HF Hub."""
202
- if os.path.exists(filename):
203
- return filename
204
- try:
205
- print(f"Downloading {filename} from {HF_REPO_ID}...")
206
- path = hf_hub_download(repo_id=HF_REPO_ID, filename=filename)
207
- return path
208
- except Exception as e:
209
- print(f"Could not download {filename}. Ensure it exists locally or in the HF repo.")
210
- # Fallback for Space: if files are uploaded directly to the Files tab,
211
- # they are in the current working directory.
212
- if os.path.exists(filename):
213
- return filename
214
- raise e
215
-
216
  def load_models():
217
- print("Loading models...")
218
 
219
  # 1. Base SAM2 Model
220
- base_ckpt_path = download_model_if_needed(BASE_CKPT_NAME)
221
- model = build_sam2(SAM2_CONFIG, base_ckpt_path, device=DEVICE)
 
 
 
222
  predictor = SAM2ImagePredictor(model)
223
  predictor.model.eval()
224
 
225
  # 2. Fine-tuned Weights
226
- final_ckpt_path = download_model_if_needed(FINAL_CKPT_NAME)
227
- sd = torch.load(final_ckpt_path, map_location=DEVICE)
 
 
228
  predictor.model.load_state_dict(sd.get("model", sd), strict=True)
229
 
230
  # 3. PLM Adapter
@@ -239,30 +227,30 @@ def load_models():
239
  lora_alpha=32,
240
  lora_dropout=0.05,
241
  dtype=torch.bfloat16,
242
- device=DEVICE,
243
- ).to(DEVICE)
244
  plm.eval()
245
 
246
- plm_ckpt_path = download_model_if_needed(PLM_CKPT_NAME)
247
- plm_sd = torch.load(plm_ckpt_path, map_location=DEVICE)
 
 
248
  plm.load_state_dict(plm_sd["plm"], strict=True)
249
 
250
- if LORA_CKPT_NAME:
251
- lora_path = download_model_if_needed(LORA_CKPT_NAME)
252
- plm.load_lora(lora_path)
253
 
254
- print("Models loaded successfully.")
255
  return predictor, plm
256
 
257
- # Initialize global models
258
  try:
259
  PREDICTOR, PLM = load_models()
260
  except Exception as e:
261
  print(f"Error loading models: {e}")
262
- print("Please check your checkpoint filenames and HF_REPO_ID in the script.")
263
  PREDICTOR, PLM = None, None
264
 
265
- @torch.no_grad()
266
  def run_prediction(image_pil, text_prompt):
267
  if PREDICTOR is None or PLM is None:
268
  return None, None, None
@@ -270,83 +258,95 @@ def run_prediction(image_pil, text_prompt):
270
  if image_pil is None or not text_prompt:
271
  return None, None, None
272
 
273
- # Preprocess
274
- rgb_orig = np.array(image_pil.convert("RGB"))
275
- Hgt, Wgt = rgb_orig.shape[:2]
276
- meta = _resize_pad_square_meta(Hgt, Wgt, SQUARE_DIM)
277
- rgb_sq = _resize_pad_square(rgb_orig, SQUARE_DIM, is_mask=False)
278
-
279
- PREDICTOR.set_image(rgb_sq)
280
- image_emb = PREDICTOR._features["image_embed"][-1].unsqueeze(0)
281
- hi = [lvl[-1].unsqueeze(0) for lvl in PREDICTOR._features["high_res_feats"]]
282
- _, _, H_feat, W_feat = image_emb.shape
283
-
284
- # PLM Inference
285
- # Note: PLM expects a path list for 'images', but the Qwen adapter likely handles
286
- # the internal logic. If your PLM adapter strictly requires disk paths,
287
- # save 'image_pil' to a temp file here.
288
- # Assuming PLM adapter needs a placeholder path or we save temp:
289
- temp_path = "temp_input.jpg"
290
- image_pil.save(temp_path)
291
-
292
- sp, dp = PLM([text_prompt], H_feat, W_feat, [temp_path])
293
-
294
- dec = PREDICTOR.model.sam_mask_decoder
295
- dev, dtype = next(dec.parameters()).device, next(dec.parameters()).dtype
296
- image_pe = PREDICTOR.model.sam_prompt_encoder.get_dense_pe().to(dev, dtype)
297
- image_emb = image_emb.to(dev, dtype)
298
- hi = [h.to(dev, dtype) for h in hi]
299
- sp, dp = sp.to(dev, dtype), dp.to(dev, dtype)
300
-
301
- # SAM2 Decoding
302
- low, scores, _, _ = dec(
303
- image_embeddings=image_emb,
304
- image_pe=image_pe,
305
- sparse_prompt_embeddings=sp,
306
- dense_prompt_embeddings=dp,
307
- multimask_output=True,
308
- repeat_image=False,
309
- high_res_features=hi,
310
- )
311
-
312
- logits_sq = PREDICTOR._transforms.postprocess_masks(low, (SQUARE_DIM, SQUARE_DIM))
313
- best = scores.argmax(dim=1).item()
314
- logit_sq = logits_sq[0, best]
315
- logit_gt = _unpad_and_resize_pred_to_gt(logit_sq, meta, (Hgt, Wgt))
316
-
317
- prob = torch.sigmoid(logit_gt)
318
- mask = (prob > 0.5).cpu().numpy().astype(np.uint8) * 255
319
-
320
- # Visualization: Overlay
321
- overlay_img = make_overlay(rgb_orig, mask, key=text_prompt)
322
-
323
- # Visualization: Attention
324
- text_attn = get_text_to_image_attention(dec)
325
- attn_overlay_img = None
326
-
327
- if text_attn is not None:
328
- L_layer, B, H_heads, N_text, N_img = text_attn.shape
329
- attn_flat = text_attn.mean(dim=(0, 2, 3)) # Mean over layers, heads, text
330
- global_flat = attn_flat[0]
331
- a = global_flat.view(H_feat, W_feat)
332
 
333
- # Upsample attention
334
- a_sq = F.interpolate(
335
- a.unsqueeze(0).unsqueeze(0),
336
- size=(SQUARE_DIM, SQUARE_DIM),
337
- mode="bilinear",
338
- align_corners=False,
339
- )[0, 0]
 
 
 
 
 
 
 
 
 
340
 
341
- a_gt = _unpad_and_resize_pred_to_gt(a_sq, meta, (Hgt, Wgt))
342
- global_attn_orig = a_gt.cpu().numpy()
343
- attn_overlay_img = make_attn_overlay(rgb_orig, global_attn_orig)
344
-
345
- # Return list of images for Gallery or individual blocks
346
- # Mask as an image
347
- mask_img = Image.fromarray(mask, mode="L")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
348
 
349
- return overlay_img, mask_img, attn_overlay_img
 
 
 
 
 
350
 
351
  # ----------------- Gradio UI -----------------
352
 
 
45
  EDGE_COLORS = [_hex_to_rgb(h) for h in EDGE_COLORS_HEX]
46
 
47
  def stable_color(key: str):
 
48
  h = int(hashlib.sha256(str(key).encode("utf-8")).hexdigest(), 16)
49
  return EDGE_COLORS[h % len(EDGE_COLORS)]
50
 
 
196
  text_attn = attn[..., n_output_tokens:, :]
197
  return text_attn
198
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  def load_models():
200
+ print("Loading models on CPU...")
201
 
202
  # 1. Base SAM2 Model
203
+ # We assume files are present locally (uploaded via CLI or LFS)
204
+ if not os.path.exists(BASE_CKPT_NAME):
205
+ raise FileNotFoundError(f"{BASE_CKPT_NAME} not found")
206
+
207
+ model = build_sam2(SAM2_CONFIG, BASE_CKPT_NAME, device="cpu")
208
  predictor = SAM2ImagePredictor(model)
209
  predictor.model.eval()
210
 
211
  # 2. Fine-tuned Weights
212
+ if not os.path.exists(FINAL_CKPT_NAME):
213
+ raise FileNotFoundError(f"{FINAL_CKPT_NAME} not found")
214
+
215
+ sd = torch.load(FINAL_CKPT_NAME, map_location="cpu")
216
  predictor.model.load_state_dict(sd.get("model", sd), strict=True)
217
 
218
  # 3. PLM Adapter
 
227
  lora_alpha=32,
228
  lora_dropout=0.05,
229
  dtype=torch.bfloat16,
230
+ device="cpu",
231
+ ).to("cpu")
232
  plm.eval()
233
 
234
+ if not os.path.exists(PLM_CKPT_NAME):
235
+ raise FileNotFoundError(f"{PLM_CKPT_NAME} not found")
236
+
237
+ plm_sd = torch.load(PLM_CKPT_NAME, map_location="cpu")
238
  plm.load_state_dict(plm_sd["plm"], strict=True)
239
 
240
+ if LORA_CKPT_NAME and os.path.exists(LORA_CKPT_NAME):
241
+ plm.load_lora(LORA_CKPT_NAME)
 
242
 
243
+ print("Models loaded successfully (CPU).")
244
  return predictor, plm
245
 
246
+ # Initialize global models on CPU
247
  try:
248
  PREDICTOR, PLM = load_models()
249
  except Exception as e:
250
  print(f"Error loading models: {e}")
 
251
  PREDICTOR, PLM = None, None
252
 
253
+ @spaces.GPU # <--- REQUIRED FOR ZEROGPU
254
  def run_prediction(image_pil, text_prompt):
255
  if PREDICTOR is None or PLM is None:
256
  return None, None, None
 
258
  if image_pil is None or not text_prompt:
259
  return None, None, None
260
 
261
+ try:
262
+ # 1. Move models to GPU for this inference session
263
+ print("Moving models to CUDA...")
264
+ PREDICTOR.model.to("cuda")
265
+ PLM.to("cuda")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
 
267
+ # 2. Preprocess
268
+ rgb_orig = np.array(image_pil.convert("RGB"))
269
+ Hgt, Wgt = rgb_orig.shape[:2]
270
+ meta = _resize_pad_square_meta(Hgt, Wgt, SQUARE_DIM)
271
+ rgb_sq = _resize_pad_square(rgb_orig, SQUARE_DIM, is_mask=False)
272
+
273
+ # 3. SAM2 Image Encoding
274
+ # set_image puts features on the model's device (now cuda)
275
+ PREDICTOR.set_image(rgb_sq)
276
+ image_emb = PREDICTOR._features["image_embed"][-1].unsqueeze(0)
277
+ hi = [lvl[-1].unsqueeze(0) for lvl in PREDICTOR._features["high_res_feats"]]
278
+ _, _, H_feat, W_feat = image_emb.shape
279
+
280
+ # 4. PLM Inference
281
+ temp_path = "temp_input.jpg"
282
+ image_pil.save(temp_path)
283
 
284
+ sp, dp = PLM([text_prompt], H_feat, W_feat, [temp_path])
285
+
286
+ # 5. Prepare SAM2 Decoder inputs (ensure they are on CUDA)
287
+ dec = PREDICTOR.model.sam_mask_decoder
288
+ dev = next(dec.parameters()).device # should be cuda now
289
+ dtype = next(dec.parameters()).dtype
290
+
291
+ image_pe = PREDICTOR.model.sam_prompt_encoder.get_dense_pe().to(dev, dtype)
292
+ image_emb = image_emb.to(dev, dtype)
293
+ hi = [h.to(dev, dtype) for h in hi]
294
+ sp, dp = sp.to(dev, dtype), dp.to(dev, dtype)
295
+
296
+ # 6. SAM2 Decoding
297
+ low, scores, _, _ = dec(
298
+ image_embeddings=image_emb,
299
+ image_pe=image_pe,
300
+ sparse_prompt_embeddings=sp,
301
+ dense_prompt_embeddings=dp,
302
+ multimask_output=True,
303
+ repeat_image=False,
304
+ high_res_features=hi,
305
+ )
306
+
307
+ logits_sq = PREDICTOR._transforms.postprocess_masks(low, (SQUARE_DIM, SQUARE_DIM))
308
+ best = scores.argmax(dim=1).item()
309
+ logit_sq = logits_sq[0, best]
310
+ logit_gt = _unpad_and_resize_pred_to_gt(logit_sq, meta, (Hgt, Wgt))
311
+
312
+ prob = torch.sigmoid(logit_gt)
313
+ mask = (prob > 0.5).cpu().numpy().astype(np.uint8) * 255
314
+
315
+ # 7. Visualization
316
+ overlay_img = make_overlay(rgb_orig, mask, key=text_prompt)
317
+
318
+ # Attention
319
+ text_attn = get_text_to_image_attention(dec)
320
+ attn_overlay_img = None
321
+
322
+ if text_attn is not None:
323
+ # Move attn back to CPU for numpy processing
324
+ text_attn = text_attn.cpu()
325
+ attn_flat = text_attn.mean(dim=(0, 2, 3))
326
+ global_flat = attn_flat[0]
327
+ a = global_flat.view(H_feat, W_feat)
328
+
329
+ a_sq = F.interpolate(
330
+ a.unsqueeze(0).unsqueeze(0),
331
+ size=(SQUARE_DIM, SQUARE_DIM),
332
+ mode="bilinear",
333
+ align_corners=False,
334
+ )[0, 0]
335
+
336
+ a_gt = _unpad_and_resize_pred_to_gt(a_sq, meta, (Hgt, Wgt))
337
+ global_attn_orig = a_gt.numpy()
338
+ attn_overlay_img = make_attn_overlay(rgb_orig, global_attn_orig)
339
+
340
+ mask_img = Image.fromarray(mask, mode="L")
341
+
342
+ return overlay_img, mask_img, attn_overlay_img
343
 
344
+ finally:
345
+ # Cleanup: Move models back to CPU to free GPU memory for other users/sessions
346
+ # This is courteous in ZeroGPU environment
347
+ print("Moving models back to CPU...")
348
+ PREDICTOR.model.to("cpu")
349
+ PLM.to("cpu")
350
 
351
  # ----------------- Gradio UI -----------------
352