update app
Browse files
app.py
CHANGED
|
@@ -197,27 +197,26 @@ def get_text_to_image_attention(decoder: MaskDecoder):
|
|
| 197 |
text_attn = attn[..., n_output_tokens:, :]
|
| 198 |
return text_attn
|
| 199 |
|
| 200 |
-
def
|
| 201 |
print("Loading models on CPU...")
|
| 202 |
|
| 203 |
-
# 1. Base SAM2 Model
|
| 204 |
-
# We assume files are present locally (uploaded via CLI or LFS)
|
| 205 |
if not os.path.exists(BASE_CKPT_NAME):
|
| 206 |
raise FileNotFoundError(f"{BASE_CKPT_NAME} not found")
|
| 207 |
|
| 208 |
model = build_sam2(SAM2_CONFIG, BASE_CKPT_NAME, device="cpu")
|
| 209 |
-
|
| 210 |
-
predictor.model.eval()
|
| 211 |
-
|
| 212 |
# 2. Fine-tuned Weights
|
| 213 |
if not os.path.exists(FINAL_CKPT_NAME):
|
| 214 |
raise FileNotFoundError(f"{FINAL_CKPT_NAME} not found")
|
| 215 |
|
| 216 |
sd = torch.load(FINAL_CKPT_NAME, map_location="cpu")
|
| 217 |
-
|
|
|
|
|
|
|
| 218 |
|
| 219 |
# 3. PLM Adapter
|
| 220 |
-
C =
|
| 221 |
plm = PLMLanguageAdapter(
|
| 222 |
model_name="Qwen/Qwen2.5-VL-3B-Instruct",
|
| 223 |
transformer_dim=C,
|
|
@@ -242,59 +241,69 @@ def load_models():
|
|
| 242 |
plm.load_lora(LORA_CKPT_NAME)
|
| 243 |
|
| 244 |
print("Models loaded successfully (CPU).")
|
| 245 |
-
return
|
| 246 |
|
| 247 |
# Initialize global models on CPU
|
| 248 |
try:
|
| 249 |
-
|
|
|
|
| 250 |
except Exception as e:
|
| 251 |
print(f"Error loading models: {e}")
|
| 252 |
-
|
|
|
|
| 253 |
|
| 254 |
-
@spaces.GPU #
|
| 255 |
def run_prediction(image_pil, text_prompt):
|
| 256 |
-
if
|
| 257 |
return None, None, None
|
| 258 |
|
| 259 |
if image_pil is None or not text_prompt:
|
| 260 |
return None, None, None
|
| 261 |
|
|
|
|
|
|
|
| 262 |
try:
|
| 263 |
-
# 1. Move models to GPU
|
| 264 |
print("Moving models to CUDA...")
|
| 265 |
-
|
| 266 |
PLM.to("cuda")
|
| 267 |
|
| 268 |
-
# 2.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
rgb_orig = np.array(image_pil.convert("RGB"))
|
| 270 |
Hgt, Wgt = rgb_orig.shape[:2]
|
| 271 |
meta = _resize_pad_square_meta(Hgt, Wgt, SQUARE_DIM)
|
| 272 |
rgb_sq = _resize_pad_square(rgb_orig, SQUARE_DIM, is_mask=False)
|
| 273 |
|
| 274 |
-
#
|
| 275 |
-
# set_image puts features on the model's device
|
| 276 |
-
|
| 277 |
-
image_emb =
|
| 278 |
-
hi = [lvl[-1].unsqueeze(0) for lvl in
|
| 279 |
_, _, H_feat, W_feat = image_emb.shape
|
| 280 |
|
| 281 |
-
#
|
| 282 |
temp_path = "temp_input.jpg"
|
| 283 |
image_pil.save(temp_path)
|
| 284 |
|
|
|
|
|
|
|
| 285 |
sp, dp = PLM([text_prompt], H_feat, W_feat, [temp_path])
|
| 286 |
|
| 287 |
-
#
|
| 288 |
-
dec =
|
| 289 |
-
dev = next(dec.parameters()).device
|
| 290 |
dtype = next(dec.parameters()).dtype
|
| 291 |
|
| 292 |
-
image_pe =
|
| 293 |
image_emb = image_emb.to(dev, dtype)
|
| 294 |
hi = [h.to(dev, dtype) for h in hi]
|
| 295 |
sp, dp = sp.to(dev, dtype), dp.to(dev, dtype)
|
| 296 |
|
| 297 |
-
#
|
| 298 |
low, scores, _, _ = dec(
|
| 299 |
image_embeddings=image_emb,
|
| 300 |
image_pe=image_pe,
|
|
@@ -305,7 +314,7 @@ def run_prediction(image_pil, text_prompt):
|
|
| 305 |
high_res_features=hi,
|
| 306 |
)
|
| 307 |
|
| 308 |
-
logits_sq =
|
| 309 |
best = scores.argmax(dim=1).item()
|
| 310 |
logit_sq = logits_sq[0, best]
|
| 311 |
logit_gt = _unpad_and_resize_pred_to_gt(logit_sq, meta, (Hgt, Wgt))
|
|
@@ -313,7 +322,7 @@ def run_prediction(image_pil, text_prompt):
|
|
| 313 |
prob = torch.sigmoid(logit_gt)
|
| 314 |
mask = (prob > 0.5).cpu().numpy().astype(np.uint8) * 255
|
| 315 |
|
| 316 |
-
#
|
| 317 |
overlay_img = make_overlay(rgb_orig, mask, key=text_prompt)
|
| 318 |
|
| 319 |
# Attention
|
|
@@ -321,7 +330,6 @@ def run_prediction(image_pil, text_prompt):
|
|
| 321 |
attn_overlay_img = None
|
| 322 |
|
| 323 |
if text_attn is not None:
|
| 324 |
-
# Move attn back to CPU for numpy processing
|
| 325 |
text_attn = text_attn.cpu()
|
| 326 |
attn_flat = text_attn.mean(dim=(0, 2, 3))
|
| 327 |
global_flat = attn_flat[0]
|
|
@@ -341,13 +349,20 @@ def run_prediction(image_pil, text_prompt):
|
|
| 341 |
mask_img = Image.fromarray(mask, mode="L")
|
| 342 |
|
| 343 |
return overlay_img, mask_img, attn_overlay_img
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 344 |
|
| 345 |
finally:
|
| 346 |
-
# Cleanup: Move models back to CPU
|
| 347 |
-
# This is courteous in ZeroGPU environment
|
| 348 |
print("Moving models back to CPU...")
|
| 349 |
-
|
| 350 |
PLM.to("cpu")
|
|
|
|
|
|
|
|
|
|
| 351 |
|
| 352 |
# ----------------- Gradio UI -----------------
|
| 353 |
|
|
|
|
| 197 |
text_attn = attn[..., n_output_tokens:, :]
|
| 198 |
return text_attn
|
| 199 |
|
| 200 |
+
def load_models_cpu():
|
| 201 |
print("Loading models on CPU...")
|
| 202 |
|
| 203 |
+
# 1. Base SAM2 Model (Raw Model, not Predictor)
|
|
|
|
| 204 |
if not os.path.exists(BASE_CKPT_NAME):
|
| 205 |
raise FileNotFoundError(f"{BASE_CKPT_NAME} not found")
|
| 206 |
|
| 207 |
model = build_sam2(SAM2_CONFIG, BASE_CKPT_NAME, device="cpu")
|
| 208 |
+
|
|
|
|
|
|
|
| 209 |
# 2. Fine-tuned Weights
|
| 210 |
if not os.path.exists(FINAL_CKPT_NAME):
|
| 211 |
raise FileNotFoundError(f"{FINAL_CKPT_NAME} not found")
|
| 212 |
|
| 213 |
sd = torch.load(FINAL_CKPT_NAME, map_location="cpu")
|
| 214 |
+
# Load into the model directly
|
| 215 |
+
model.load_state_dict(sd.get("model", sd), strict=True)
|
| 216 |
+
model.eval()
|
| 217 |
|
| 218 |
# 3. PLM Adapter
|
| 219 |
+
C = model.sam_mask_decoder.transformer_dim
|
| 220 |
plm = PLMLanguageAdapter(
|
| 221 |
model_name="Qwen/Qwen2.5-VL-3B-Instruct",
|
| 222 |
transformer_dim=C,
|
|
|
|
| 241 |
plm.load_lora(LORA_CKPT_NAME)
|
| 242 |
|
| 243 |
print("Models loaded successfully (CPU).")
|
| 244 |
+
return model, plm
|
| 245 |
|
| 246 |
# Initialize global models on CPU
|
| 247 |
try:
|
| 248 |
+
# NOTE: We hold the raw MODEL_SAM here, not the predictor
|
| 249 |
+
MODEL_SAM, PLM = load_models_cpu()
|
| 250 |
except Exception as e:
|
| 251 |
print(f"Error loading models: {e}")
|
| 252 |
+
traceback.print_exc()
|
| 253 |
+
MODEL_SAM, PLM = None, None
|
| 254 |
|
| 255 |
+
@spaces.GPU(duration=60) # GPU time budget in seconds; 60 matches the documented ZeroGPU default, so raise it if inference needs longer
|
| 256 |
def run_prediction(image_pil, text_prompt):
|
| 257 |
+
if MODEL_SAM is None or PLM is None:
|
| 258 |
return None, None, None
|
| 259 |
|
| 260 |
if image_pil is None or not text_prompt:
|
| 261 |
return None, None, None
|
| 262 |
|
| 263 |
+
predictor = None
|
| 264 |
+
|
| 265 |
try:
|
| 266 |
+
# 1. Move models to GPU
|
| 267 |
print("Moving models to CUDA...")
|
| 268 |
+
MODEL_SAM.to("cuda")
|
| 269 |
PLM.to("cuda")
|
| 270 |
|
| 271 |
+
# 2. Instantiate Predictor ON GPU (Crucial Fix)
|
| 272 |
+
# This ensures the predictor knows it's on CUDA
|
| 273 |
+
predictor = SAM2ImagePredictor(MODEL_SAM)
|
| 274 |
+
|
| 275 |
+
# 3. Preprocess Image
|
| 276 |
rgb_orig = np.array(image_pil.convert("RGB"))
|
| 277 |
Hgt, Wgt = rgb_orig.shape[:2]
|
| 278 |
meta = _resize_pad_square_meta(Hgt, Wgt, SQUARE_DIM)
|
| 279 |
rgb_sq = _resize_pad_square(rgb_orig, SQUARE_DIM, is_mask=False)
|
| 280 |
|
| 281 |
+
# 4. SAM2 Image Encoding
|
| 282 |
+
# set_image puts features on the model's device
|
| 283 |
+
predictor.set_image(rgb_sq)
|
| 284 |
+
image_emb = predictor._features["image_embed"][-1].unsqueeze(0)
|
| 285 |
+
hi = [lvl[-1].unsqueeze(0) for lvl in predictor._features["high_res_feats"]]
|
| 286 |
_, _, H_feat, W_feat = image_emb.shape
|
| 287 |
|
| 288 |
+
# 5. PLM Inference
|
| 289 |
temp_path = "temp_input.jpg"
|
| 290 |
image_pil.save(temp_path)
|
| 291 |
|
| 292 |
+
# NOTE(review): assumes PLM places its own inputs on its current device - confirm.
|
| 293 |
+
# Its outputs (sp, dp) are moved explicitly to the decoder's device and dtype below.
|
| 294 |
sp, dp = PLM([text_prompt], H_feat, W_feat, [temp_path])
|
| 295 |
|
| 296 |
+
# 6. Prepare SAM2 Decoder inputs (ensure they are on CUDA)
|
| 297 |
+
dec = predictor.model.sam_mask_decoder
|
| 298 |
+
dev = next(dec.parameters()).device
|
| 299 |
dtype = next(dec.parameters()).dtype
|
| 300 |
|
| 301 |
+
image_pe = predictor.model.sam_prompt_encoder.get_dense_pe().to(dev, dtype)
|
| 302 |
image_emb = image_emb.to(dev, dtype)
|
| 303 |
hi = [h.to(dev, dtype) for h in hi]
|
| 304 |
sp, dp = sp.to(dev, dtype), dp.to(dev, dtype)
|
| 305 |
|
| 306 |
+
# 7. SAM2 Decoding
|
| 307 |
low, scores, _, _ = dec(
|
| 308 |
image_embeddings=image_emb,
|
| 309 |
image_pe=image_pe,
|
|
|
|
| 314 |
high_res_features=hi,
|
| 315 |
)
|
| 316 |
|
| 317 |
+
logits_sq = predictor._transforms.postprocess_masks(low, (SQUARE_DIM, SQUARE_DIM))
|
| 318 |
best = scores.argmax(dim=1).item()
|
| 319 |
logit_sq = logits_sq[0, best]
|
| 320 |
logit_gt = _unpad_and_resize_pred_to_gt(logit_sq, meta, (Hgt, Wgt))
|
|
|
|
| 322 |
prob = torch.sigmoid(logit_gt)
|
| 323 |
mask = (prob > 0.5).cpu().numpy().astype(np.uint8) * 255
|
| 324 |
|
| 325 |
+
# 8. Visualization
|
| 326 |
overlay_img = make_overlay(rgb_orig, mask, key=text_prompt)
|
| 327 |
|
| 328 |
# Attention
|
|
|
|
| 330 |
attn_overlay_img = None
|
| 331 |
|
| 332 |
if text_attn is not None:
|
|
|
|
| 333 |
text_attn = text_attn.cpu()
|
| 334 |
attn_flat = text_attn.mean(dim=(0, 2, 3))
|
| 335 |
global_flat = attn_flat[0]
|
|
|
|
| 349 |
mask_img = Image.fromarray(mask, mode="L")
|
| 350 |
|
| 351 |
return overlay_img, mask_img, attn_overlay_img
|
| 352 |
+
|
| 353 |
+
except Exception as e:
|
| 354 |
+
print("An error occurred during inference:")
|
| 355 |
+
traceback.print_exc()
|
| 356 |
+
raise e # Let Gradio show the error
|
| 357 |
|
| 358 |
finally:
|
| 359 |
+
# Cleanup: Move models back to CPU
|
|
|
|
| 360 |
print("Moving models back to CPU...")
|
| 361 |
+
MODEL_SAM.to("cpu")
|
| 362 |
PLM.to("cpu")
|
| 363 |
+
if predictor:
|
| 364 |
+
del predictor
|
| 365 |
+
torch.cuda.empty_cache()
|
| 366 |
|
| 367 |
# ----------------- Gradio UI -----------------
|
| 368 |
|