update app
Browse files
app.py
CHANGED
|
@@ -45,7 +45,6 @@ def _hex_to_rgb(h: str):
|
|
| 45 |
EDGE_COLORS = [_hex_to_rgb(h) for h in EDGE_COLORS_HEX]
|
| 46 |
|
| 47 |
def stable_color(key: str):
|
| 48 |
-
# Use a fixed key if simple color is desired
|
| 49 |
h = int(hashlib.sha256(str(key).encode("utf-8")).hexdigest(), 16)
|
| 50 |
return EDGE_COLORS[h % len(EDGE_COLORS)]
|
| 51 |
|
|
@@ -197,34 +196,23 @@ def get_text_to_image_attention(decoder: MaskDecoder):
|
|
| 197 |
text_attn = attn[..., n_output_tokens:, :]
|
| 198 |
return text_attn
|
| 199 |
|
| 200 |
-
def download_model_if_needed(filename):
|
| 201 |
-
"""Checks local disk, else downloads from HF Hub."""
|
| 202 |
-
if os.path.exists(filename):
|
| 203 |
-
return filename
|
| 204 |
-
try:
|
| 205 |
-
print(f"Downloading {filename} from {HF_REPO_ID}...")
|
| 206 |
-
path = hf_hub_download(repo_id=HF_REPO_ID, filename=filename)
|
| 207 |
-
return path
|
| 208 |
-
except Exception as e:
|
| 209 |
-
print(f"Could not download {filename}. Ensure it exists locally or in the HF repo.")
|
| 210 |
-
# Fallback for Space: if files are uploaded directly to the Files tab,
|
| 211 |
-
# they are in the current working directory.
|
| 212 |
-
if os.path.exists(filename):
|
| 213 |
-
return filename
|
| 214 |
-
raise e
|
| 215 |
-
|
| 216 |
def load_models():
|
| 217 |
-
print("Loading models...")
|
| 218 |
|
| 219 |
# 1. Base SAM2 Model
|
| 220 |
-
|
| 221 |
-
|
|
|
|
|
|
|
|
|
|
| 222 |
predictor = SAM2ImagePredictor(model)
|
| 223 |
predictor.model.eval()
|
| 224 |
|
| 225 |
# 2. Fine-tuned Weights
|
| 226 |
-
|
| 227 |
-
|
|
|
|
|
|
|
| 228 |
predictor.model.load_state_dict(sd.get("model", sd), strict=True)
|
| 229 |
|
| 230 |
# 3. PLM Adapter
|
|
@@ -239,30 +227,30 @@ def load_models():
|
|
| 239 |
lora_alpha=32,
|
| 240 |
lora_dropout=0.05,
|
| 241 |
dtype=torch.bfloat16,
|
| 242 |
-
device=
|
| 243 |
-
).to(
|
| 244 |
plm.eval()
|
| 245 |
|
| 246 |
-
|
| 247 |
-
|
|
|
|
|
|
|
| 248 |
plm.load_state_dict(plm_sd["plm"], strict=True)
|
| 249 |
|
| 250 |
-
if LORA_CKPT_NAME:
|
| 251 |
-
|
| 252 |
-
plm.load_lora(lora_path)
|
| 253 |
|
| 254 |
-
print("Models loaded successfully.")
|
| 255 |
return predictor, plm
|
| 256 |
|
| 257 |
-
# Initialize global models
|
| 258 |
try:
|
| 259 |
PREDICTOR, PLM = load_models()
|
| 260 |
except Exception as e:
|
| 261 |
print(f"Error loading models: {e}")
|
| 262 |
-
print("Please check your checkpoint filenames and HF_REPO_ID in the script.")
|
| 263 |
PREDICTOR, PLM = None, None
|
| 264 |
|
| 265 |
-
@
|
| 266 |
def run_prediction(image_pil, text_prompt):
|
| 267 |
if PREDICTOR is None or PLM is None:
|
| 268 |
return None, None, None
|
|
@@ -270,83 +258,95 @@ def run_prediction(image_pil, text_prompt):
|
|
| 270 |
if image_pil is None or not text_prompt:
|
| 271 |
return None, None, None
|
| 272 |
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
PREDICTOR.set_image(rgb_sq)
|
| 280 |
-
image_emb = PREDICTOR._features["image_embed"][-1].unsqueeze(0)
|
| 281 |
-
hi = [lvl[-1].unsqueeze(0) for lvl in PREDICTOR._features["high_res_feats"]]
|
| 282 |
-
_, _, H_feat, W_feat = image_emb.shape
|
| 283 |
-
|
| 284 |
-
# PLM Inference
|
| 285 |
-
# Note: PLM expects a path list for 'images', but the Qwen adapter likely handles
|
| 286 |
-
# the internal logic. If your PLM adapter strictly requires disk paths,
|
| 287 |
-
# save 'image_pil' to a temp file here.
|
| 288 |
-
# Assuming PLM adapter needs a placeholder path or we save temp:
|
| 289 |
-
temp_path = "temp_input.jpg"
|
| 290 |
-
image_pil.save(temp_path)
|
| 291 |
-
|
| 292 |
-
sp, dp = PLM([text_prompt], H_feat, W_feat, [temp_path])
|
| 293 |
-
|
| 294 |
-
dec = PREDICTOR.model.sam_mask_decoder
|
| 295 |
-
dev, dtype = next(dec.parameters()).device, next(dec.parameters()).dtype
|
| 296 |
-
image_pe = PREDICTOR.model.sam_prompt_encoder.get_dense_pe().to(dev, dtype)
|
| 297 |
-
image_emb = image_emb.to(dev, dtype)
|
| 298 |
-
hi = [h.to(dev, dtype) for h in hi]
|
| 299 |
-
sp, dp = sp.to(dev, dtype), dp.to(dev, dtype)
|
| 300 |
-
|
| 301 |
-
# SAM2 Decoding
|
| 302 |
-
low, scores, _, _ = dec(
|
| 303 |
-
image_embeddings=image_emb,
|
| 304 |
-
image_pe=image_pe,
|
| 305 |
-
sparse_prompt_embeddings=sp,
|
| 306 |
-
dense_prompt_embeddings=dp,
|
| 307 |
-
multimask_output=True,
|
| 308 |
-
repeat_image=False,
|
| 309 |
-
high_res_features=hi,
|
| 310 |
-
)
|
| 311 |
-
|
| 312 |
-
logits_sq = PREDICTOR._transforms.postprocess_masks(low, (SQUARE_DIM, SQUARE_DIM))
|
| 313 |
-
best = scores.argmax(dim=1).item()
|
| 314 |
-
logit_sq = logits_sq[0, best]
|
| 315 |
-
logit_gt = _unpad_and_resize_pred_to_gt(logit_sq, meta, (Hgt, Wgt))
|
| 316 |
-
|
| 317 |
-
prob = torch.sigmoid(logit_gt)
|
| 318 |
-
mask = (prob > 0.5).cpu().numpy().astype(np.uint8) * 255
|
| 319 |
-
|
| 320 |
-
# Visualization: Overlay
|
| 321 |
-
overlay_img = make_overlay(rgb_orig, mask, key=text_prompt)
|
| 322 |
-
|
| 323 |
-
# Visualization: Attention
|
| 324 |
-
text_attn = get_text_to_image_attention(dec)
|
| 325 |
-
attn_overlay_img = None
|
| 326 |
-
|
| 327 |
-
if text_attn is not None:
|
| 328 |
-
L_layer, B, H_heads, N_text, N_img = text_attn.shape
|
| 329 |
-
attn_flat = text_attn.mean(dim=(0, 2, 3)) # Mean over layers, heads, text
|
| 330 |
-
global_flat = attn_flat[0]
|
| 331 |
-
a = global_flat.view(H_feat, W_feat)
|
| 332 |
|
| 333 |
-
#
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 348 |
|
| 349 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
|
| 351 |
# ----------------- Gradio UI -----------------
|
| 352 |
|
|
|
|
| 45 |
EDGE_COLORS = [_hex_to_rgb(h) for h in EDGE_COLORS_HEX]
|
| 46 |
|
| 47 |
def stable_color(key: str):
|
|
|
|
| 48 |
h = int(hashlib.sha256(str(key).encode("utf-8")).hexdigest(), 16)
|
| 49 |
return EDGE_COLORS[h % len(EDGE_COLORS)]
|
| 50 |
|
|
|
|
| 196 |
text_attn = attn[..., n_output_tokens:, :]
|
| 197 |
return text_attn
|
| 198 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
def load_models():
|
| 200 |
+
print("Loading models on CPU...")
|
| 201 |
|
| 202 |
# 1. Base SAM2 Model
|
| 203 |
+
# We assume files are present locally (uploaded via CLI or LFS)
|
| 204 |
+
if not os.path.exists(BASE_CKPT_NAME):
|
| 205 |
+
raise FileNotFoundError(f"{BASE_CKPT_NAME} not found")
|
| 206 |
+
|
| 207 |
+
model = build_sam2(SAM2_CONFIG, BASE_CKPT_NAME, device="cpu")
|
| 208 |
predictor = SAM2ImagePredictor(model)
|
| 209 |
predictor.model.eval()
|
| 210 |
|
| 211 |
# 2. Fine-tuned Weights
|
| 212 |
+
if not os.path.exists(FINAL_CKPT_NAME):
|
| 213 |
+
raise FileNotFoundError(f"{FINAL_CKPT_NAME} not found")
|
| 214 |
+
|
| 215 |
+
sd = torch.load(FINAL_CKPT_NAME, map_location="cpu")
|
| 216 |
predictor.model.load_state_dict(sd.get("model", sd), strict=True)
|
| 217 |
|
| 218 |
# 3. PLM Adapter
|
|
|
|
| 227 |
lora_alpha=32,
|
| 228 |
lora_dropout=0.05,
|
| 229 |
dtype=torch.bfloat16,
|
| 230 |
+
device="cpu",
|
| 231 |
+
).to("cpu")
|
| 232 |
plm.eval()
|
| 233 |
|
| 234 |
+
if not os.path.exists(PLM_CKPT_NAME):
|
| 235 |
+
raise FileNotFoundError(f"{PLM_CKPT_NAME} not found")
|
| 236 |
+
|
| 237 |
+
plm_sd = torch.load(PLM_CKPT_NAME, map_location="cpu")
|
| 238 |
plm.load_state_dict(plm_sd["plm"], strict=True)
|
| 239 |
|
| 240 |
+
if LORA_CKPT_NAME and os.path.exists(LORA_CKPT_NAME):
|
| 241 |
+
plm.load_lora(LORA_CKPT_NAME)
|
|
|
|
| 242 |
|
| 243 |
+
print("Models loaded successfully (CPU).")
|
| 244 |
return predictor, plm
|
| 245 |
|
| 246 |
+
# Initialize global models on CPU
|
| 247 |
try:
|
| 248 |
PREDICTOR, PLM = load_models()
|
| 249 |
except Exception as e:
|
| 250 |
print(f"Error loading models: {e}")
|
|
|
|
| 251 |
PREDICTOR, PLM = None, None
|
| 252 |
|
| 253 |
+
@spaces.GPU # <--- REQUIRED FOR ZEROGPU
|
| 254 |
def run_prediction(image_pil, text_prompt):
|
| 255 |
if PREDICTOR is None or PLM is None:
|
| 256 |
return None, None, None
|
|
|
|
| 258 |
if image_pil is None or not text_prompt:
|
| 259 |
return None, None, None
|
| 260 |
|
| 261 |
+
try:
|
| 262 |
+
# 1. Move models to GPU for this inference session
|
| 263 |
+
print("Moving models to CUDA...")
|
| 264 |
+
PREDICTOR.model.to("cuda")
|
| 265 |
+
PLM.to("cuda")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
|
| 267 |
+
# 2. Preprocess
|
| 268 |
+
rgb_orig = np.array(image_pil.convert("RGB"))
|
| 269 |
+
Hgt, Wgt = rgb_orig.shape[:2]
|
| 270 |
+
meta = _resize_pad_square_meta(Hgt, Wgt, SQUARE_DIM)
|
| 271 |
+
rgb_sq = _resize_pad_square(rgb_orig, SQUARE_DIM, is_mask=False)
|
| 272 |
+
|
| 273 |
+
# 3. SAM2 Image Encoding
|
| 274 |
+
# set_image puts features on the model's device (now cuda)
|
| 275 |
+
PREDICTOR.set_image(rgb_sq)
|
| 276 |
+
image_emb = PREDICTOR._features["image_embed"][-1].unsqueeze(0)
|
| 277 |
+
hi = [lvl[-1].unsqueeze(0) for lvl in PREDICTOR._features["high_res_feats"]]
|
| 278 |
+
_, _, H_feat, W_feat = image_emb.shape
|
| 279 |
+
|
| 280 |
+
# 4. PLM Inference
|
| 281 |
+
temp_path = "temp_input.jpg"
|
| 282 |
+
image_pil.save(temp_path)
|
| 283 |
|
| 284 |
+
sp, dp = PLM([text_prompt], H_feat, W_feat, [temp_path])
|
| 285 |
+
|
| 286 |
+
# 5. Prepare SAM2 Decoder inputs (ensure they are on CUDA)
|
| 287 |
+
dec = PREDICTOR.model.sam_mask_decoder
|
| 288 |
+
dev = next(dec.parameters()).device # should be cuda now
|
| 289 |
+
dtype = next(dec.parameters()).dtype
|
| 290 |
+
|
| 291 |
+
image_pe = PREDICTOR.model.sam_prompt_encoder.get_dense_pe().to(dev, dtype)
|
| 292 |
+
image_emb = image_emb.to(dev, dtype)
|
| 293 |
+
hi = [h.to(dev, dtype) for h in hi]
|
| 294 |
+
sp, dp = sp.to(dev, dtype), dp.to(dev, dtype)
|
| 295 |
+
|
| 296 |
+
# 6. SAM2 Decoding
|
| 297 |
+
low, scores, _, _ = dec(
|
| 298 |
+
image_embeddings=image_emb,
|
| 299 |
+
image_pe=image_pe,
|
| 300 |
+
sparse_prompt_embeddings=sp,
|
| 301 |
+
dense_prompt_embeddings=dp,
|
| 302 |
+
multimask_output=True,
|
| 303 |
+
repeat_image=False,
|
| 304 |
+
high_res_features=hi,
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
logits_sq = PREDICTOR._transforms.postprocess_masks(low, (SQUARE_DIM, SQUARE_DIM))
|
| 308 |
+
best = scores.argmax(dim=1).item()
|
| 309 |
+
logit_sq = logits_sq[0, best]
|
| 310 |
+
logit_gt = _unpad_and_resize_pred_to_gt(logit_sq, meta, (Hgt, Wgt))
|
| 311 |
+
|
| 312 |
+
prob = torch.sigmoid(logit_gt)
|
| 313 |
+
mask = (prob > 0.5).cpu().numpy().astype(np.uint8) * 255
|
| 314 |
+
|
| 315 |
+
# 7. Visualization
|
| 316 |
+
overlay_img = make_overlay(rgb_orig, mask, key=text_prompt)
|
| 317 |
+
|
| 318 |
+
# Attention
|
| 319 |
+
text_attn = get_text_to_image_attention(dec)
|
| 320 |
+
attn_overlay_img = None
|
| 321 |
+
|
| 322 |
+
if text_attn is not None:
|
| 323 |
+
# Move attn back to CPU for numpy processing
|
| 324 |
+
text_attn = text_attn.cpu()
|
| 325 |
+
attn_flat = text_attn.mean(dim=(0, 2, 3))
|
| 326 |
+
global_flat = attn_flat[0]
|
| 327 |
+
a = global_flat.view(H_feat, W_feat)
|
| 328 |
+
|
| 329 |
+
a_sq = F.interpolate(
|
| 330 |
+
a.unsqueeze(0).unsqueeze(0),
|
| 331 |
+
size=(SQUARE_DIM, SQUARE_DIM),
|
| 332 |
+
mode="bilinear",
|
| 333 |
+
align_corners=False,
|
| 334 |
+
)[0, 0]
|
| 335 |
+
|
| 336 |
+
a_gt = _unpad_and_resize_pred_to_gt(a_sq, meta, (Hgt, Wgt))
|
| 337 |
+
global_attn_orig = a_gt.numpy()
|
| 338 |
+
attn_overlay_img = make_attn_overlay(rgb_orig, global_attn_orig)
|
| 339 |
+
|
| 340 |
+
mask_img = Image.fromarray(mask, mode="L")
|
| 341 |
+
|
| 342 |
+
return overlay_img, mask_img, attn_overlay_img
|
| 343 |
|
| 344 |
+
finally:
|
| 345 |
+
# Cleanup: Move models back to CPU to free GPU memory for other users/sessions
|
| 346 |
+
# This is courteous in ZeroGPU environment
|
| 347 |
+
print("Moving models back to CPU...")
|
| 348 |
+
PREDICTOR.model.to("cpu")
|
| 349 |
+
PLM.to("cpu")
|
| 350 |
|
| 351 |
# ----------------- Gradio UI -----------------
|
| 352 |
|