aadarsh99 commited on
Commit
97c4d82
·
1 Parent(s): 07bd6be

added app.py

Browse files
Files changed (1) hide show
  1. app.py +376 -0
app.py ADDED
@@ -0,0 +1,376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import hashlib
4
+ import sys
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import gradio as gr
11
+ from PIL import Image, ImageFilter, ImageChops, ImageDraw
12
+ from huggingface_hub import hf_hub_download
13
+
14
+ # --- IMPORT YOUR CUSTOM MODULES ---
15
+ # Ensure the 'sam2' folder and 'plm_adapter_...' file are uploaded to your Space
16
+ from sam2.build_sam import build_sam2
17
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
18
+ from sam2.modeling.sam.mask_decoder import MaskDecoder
19
+ from plm_adapter_lora_with_image_input_only_text_positions import PLMLanguageAdapter
20
+
21
+ # ----------------- Configuration -----------------
22
+ # UPDATE THESE TO MATCH YOUR HF REPO IF YOU STORE WEIGHTS THERE
23
+ HF_REPO_ID = "your-username/your-model-repo"
24
+ SAM2_CONFIG = "sam2_hiera_l.yaml"
25
+
26
+ # Checkpoint filenames (assumed to be in the root or downloaded)
27
+ BASE_CKPT_NAME = "sam2_hiera_large.pt"
28
+ FINAL_CKPT_NAME = "fine_tuned_sam2_batched_100000.torch" # Update with your filename
29
+ PLM_CKPT_NAME = "fine_tuned_sam2_batched_plm_100000.torch" # Update with your filename
30
+ LORA_CKPT_NAME = "lora_plm_adapter_100000" # Set filename if you use LoRA, else None
31
+
32
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
33
+ SQUARE_DIM = 1024
34
+
35
+ logging.basicConfig(level=logging.INFO)
36
+
37
+ # ----------------- Overlay Style Helpers -----------------
38
+
39
+ EDGE_COLORS_HEX = ["#3A86FF", "#FF006E", "#43AA8B", "#F3722C", "#8338EC", "#90BE6D"]
40
+
41
+ def _hex_to_rgb(h: str):
42
+ h = h.lstrip("#")
43
+ return tuple(int(h[i : i + 2], 16) for i in (0, 2, 4))
44
+
45
+ EDGE_COLORS = [_hex_to_rgb(h) for h in EDGE_COLORS_HEX]
46
+
47
+ def stable_color(key: str):
48
+ # Use a fixed key if simple color is desired
49
+ h = int(hashlib.sha256(str(key).encode("utf-8")).hexdigest(), 16)
50
+ return EDGE_COLORS[h % len(EDGE_COLORS)]
51
+
52
+ def tint(rgb, amt: float = 0.1):
53
+ return tuple(int(255 - (255 - c) * (1 - amt)) for c in rgb)
54
+
55
+ def edge_map(mask_bool: np.ndarray, width_px: int = 2) -> Image.Image:
56
+ m = Image.fromarray((mask_bool.astype(np.uint8) * 255), "L")
57
+ edges = ImageChops.difference(
58
+ m.filter(ImageFilter.MaxFilter(3)), m.filter(ImageFilter.MinFilter(3))
59
+ )
60
+ for _ in range(max(0, width_px - 1)):
61
+ edges = edges.filter(ImageFilter.MaxFilter(3))
62
+ return edges.point(lambda p: 255 if p > 0 else 0)
63
+
64
+ def _apply_rounded_corners(img_rgb: Image.Image, radius: int) -> Image.Image:
65
+ w, h = img_rgb.size
66
+ mask = Image.new("L", (w, h), 0)
67
+ ImageDraw.Draw(mask).rounded_rectangle([0, 0, w - 1, h - 1], radius=radius, fill=255)
68
+ bg = Image.new("RGB", (w, h), "white")
69
+ img_rgba = img_rgb.convert("RGBA")
70
+ img_rgba.putalpha(mask)
71
+ bg.paste(img_rgba.convert("RGB"), (0, 0), mask)
72
+ return bg
73
+
74
+ def make_overlay(rgb: np.ndarray, mask: np.ndarray, key: str = "mask") -> Image.Image:
75
+ base = Image.fromarray(rgb.astype(np.uint8)).convert("RGB")
76
+ H, W = mask.shape[:2]
77
+ if base.size != (W, H):
78
+ base = base.resize((W, H), Image.BICUBIC)
79
+
80
+ base_rgba = base.convert("RGBA")
81
+ mask_bool = mask > 0
82
+
83
+ color = stable_color(key)
84
+ fill_rgb = tint(color, 0.1)
85
+ alpha_fill = 0.7
86
+ edge_width = 2
87
+
88
+ a = int(round(alpha_fill * 255))
89
+ tgt_w, tgt_h = base_rgba.size
90
+
91
+ fill_layer = Image.new("RGBA", (tgt_w, tgt_h), fill_rgb + (0,))
92
+ fill_alpha = Image.fromarray((mask_bool.astype(np.uint8) * a), "L")
93
+ fill_layer.putalpha(fill_alpha)
94
+
95
+ edgesL = edge_map(mask_bool, width_px=edge_width)
96
+ stroke = Image.new("RGBA", (tgt_w, tgt_h), color + (0,))
97
+ stroke.putalpha(edgesL)
98
+
99
+ out = Image.alpha_composite(base_rgba, fill_layer)
100
+ out = Image.alpha_composite(out, stroke)
101
+ out = out.convert("RGB")
102
+ return _apply_rounded_corners(out, max(12, int(0.06 * min(out.size))))
103
+
104
+ def make_attn_overlay(rgb: np.ndarray, attn: np.ndarray, alpha: float = 0.6) -> Image.Image:
105
+ h, w = rgb.shape[:2]
106
+ ah, aw = attn.shape
107
+ if (ah, aw) != (h, w):
108
+ attn_resized = cv2.resize(attn.astype(np.float32), (w, h), interpolation=cv2.INTER_LINEAR)
109
+ else:
110
+ attn_resized = attn.astype(np.float32)
111
+
112
+ attn_resized = attn_resized - attn_resized.min()
113
+ denom = attn_resized.max()
114
+ if denom < 1e-6: denom = 1e-6
115
+ attn_norm = (attn_resized / denom * 255.0).clip(0, 255).astype(np.uint8)
116
+
117
+ heatmap_bgr = cv2.applyColorMap(attn_norm, cv2.COLORMAP_JET)
118
+ heatmap_rgb = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB)
119
+
120
+ rgb_f = rgb.astype(np.float32)
121
+ heat_f = heatmap_rgb.astype(np.float32)
122
+ blended = (1.0 - alpha) * rgb_f + alpha * heat_f
123
+ blended = blended.clip(0, 255).astype(np.uint8)
124
+
125
+ return Image.fromarray(blended, mode="RGB")
126
+
127
+ # ----------------- Image Processing Helpers -----------------
128
+
129
+ def _resize_pad_square(arr: np.ndarray, max_dim: int, *, is_mask: bool) -> np.ndarray:
130
+ h, w = arr.shape[:2]
131
+ scale = float(max_dim) / float(max(h, w))
132
+ new_w = max(1, int(round(w * scale)))
133
+ new_h = max(1, int(round(h * scale)))
134
+
135
+ if is_mask:
136
+ interp = cv2.INTER_NEAREST
137
+ else:
138
+ interp = cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LINEAR
139
+
140
+ arr = cv2.resize(arr, (new_w, new_h), interpolation=interp)
141
+
142
+ pad_w = max_dim - new_w
143
+ pad_h = max_dim - new_h
144
+ left = pad_w // 2
145
+ right = pad_w - left
146
+ top = pad_h // 2
147
+ bottom = pad_h - top
148
+
149
+ border_val = 0 if is_mask else (0, 0, 0)
150
+ arr = cv2.copyMakeBorder(
151
+ arr, top, bottom, left, right, borderType=cv2.BORDER_CONSTANT, value=border_val
152
+ )
153
+ return np.ascontiguousarray(arr)
154
+
155
+ def _resize_pad_square_meta(h: int, w: int, max_dim: int):
156
+ scale = float(max_dim) / float(max(h, w))
157
+ new_w = max(1, int(round(w * scale)))
158
+ new_h = max(1, int(round(h * scale)))
159
+ pad_w = max_dim - new_w
160
+ pad_h = max_dim - new_h
161
+ left = pad_w // 2
162
+ right = pad_w - left
163
+ top = pad_h // 2
164
+ bottom = pad_h - top
165
+ return {
166
+ "scale": scale, "new_w": new_w, "new_h": new_h,
167
+ "left": left, "right": right, "top": top, "bottom": bottom,
168
+ }
169
+
170
+ def _unpad_and_resize_pred_to_gt(logit_sq: torch.Tensor, meta: dict, out_hw: tuple[int, int]) -> torch.Tensor:
171
+ top, left = meta["top"], meta["left"]
172
+ nh, nw = meta["new_h"], meta["new_w"]
173
+ crop = logit_sq[top : top + nh, left : left + nw]
174
+ crop = crop.unsqueeze(0).unsqueeze(0)
175
+ up = F.interpolate(crop, size=out_hw, mode="bilinear", align_corners=False)
176
+ return up[0, 0]
177
+
178
+ # ----------------- Model Logic -----------------
179
+
180
+ def get_text_to_image_attention(decoder: MaskDecoder):
181
+ two_way = decoder.transformer
182
+ attn_blocks = []
183
+ for blk in two_way.layers:
184
+ a = blk.cross_attn_token_to_image.last_attn
185
+ if a is not None:
186
+ attn_blocks.append(a)
187
+ final = two_way.final_attn_token_to_image.last_attn
188
+ if final is not None:
189
+ attn_blocks.append(final)
190
+
191
+ if not attn_blocks:
192
+ return None
193
+
194
+ attn = torch.stack(attn_blocks, dim=0)
195
+ s = 1 if decoder.pred_obj_scores else 0
196
+ n_output_tokens = s + 1 + decoder.num_mask_tokens
197
+ text_attn = attn[..., n_output_tokens:, :]
198
+ return text_attn
199
+
200
+ def download_model_if_needed(filename):
201
+ """Checks local disk, else downloads from HF Hub."""
202
+ if os.path.exists(filename):
203
+ return filename
204
+ try:
205
+ print(f"Downloading {filename} from {HF_REPO_ID}...")
206
+ path = hf_hub_download(repo_id=HF_REPO_ID, filename=filename)
207
+ return path
208
+ except Exception as e:
209
+ print(f"Could not download {filename}. Ensure it exists locally or in the HF repo.")
210
+ # Fallback for Space: if files are uploaded directly to the Files tab,
211
+ # they are in the current working directory.
212
+ if os.path.exists(filename):
213
+ return filename
214
+ raise e
215
+
216
+ def load_models():
217
+ print("Loading models...")
218
+
219
+ # 1. Base SAM2 Model
220
+ base_ckpt_path = download_model_if_needed(BASE_CKPT_NAME)
221
+ model = build_sam2(SAM2_CONFIG, base_ckpt_path, device=DEVICE)
222
+ predictor = SAM2ImagePredictor(model)
223
+ predictor.model.eval()
224
+
225
+ # 2. Fine-tuned Weights
226
+ final_ckpt_path = download_model_if_needed(FINAL_CKPT_NAME)
227
+ sd = torch.load(final_ckpt_path, map_location=DEVICE)
228
+ predictor.model.load_state_dict(sd.get("model", sd), strict=True)
229
+
230
+ # 3. PLM Adapter
231
+ C = predictor.model.sam_mask_decoder.transformer_dim
232
+ plm = PLMLanguageAdapter(
233
+ model_name="Qwen/Qwen2.5-VL-3B-Instruct",
234
+ transformer_dim=C,
235
+ n_sparse_tokens=0,
236
+ use_dense_bias=True,
237
+ use_lora=True,
238
+ lora_r=16,
239
+ lora_alpha=32,
240
+ lora_dropout=0.05,
241
+ dtype=torch.bfloat16,
242
+ device=DEVICE,
243
+ ).to(DEVICE)
244
+ plm.eval()
245
+
246
+ plm_ckpt_path = download_model_if_needed(PLM_CKPT_NAME)
247
+ plm_sd = torch.load(plm_ckpt_path, map_location=DEVICE)
248
+ plm.load_state_dict(plm_sd["plm"], strict=True)
249
+
250
+ if LORA_CKPT_NAME:
251
+ lora_path = download_model_if_needed(LORA_CKPT_NAME)
252
+ plm.load_lora(lora_path)
253
+
254
+ print("Models loaded successfully.")
255
+ return predictor, plm
256
+
257
+ # Initialize global models
258
+ try:
259
+ PREDICTOR, PLM = load_models()
260
+ except Exception as e:
261
+ print(f"Error loading models: {e}")
262
+ print("Please check your checkpoint filenames and HF_REPO_ID in the script.")
263
+ PREDICTOR, PLM = None, None
264
+
265
+ @torch.no_grad()
266
+ def run_prediction(image_pil, text_prompt):
267
+ if PREDICTOR is None or PLM is None:
268
+ return None, None, None
269
+
270
+ if image_pil is None or not text_prompt:
271
+ return None, None, None
272
+
273
+ # Preprocess
274
+ rgb_orig = np.array(image_pil.convert("RGB"))
275
+ Hgt, Wgt = rgb_orig.shape[:2]
276
+ meta = _resize_pad_square_meta(Hgt, Wgt, SQUARE_DIM)
277
+ rgb_sq = _resize_pad_square(rgb_orig, SQUARE_DIM, is_mask=False)
278
+
279
+ PREDICTOR.set_image(rgb_sq)
280
+ image_emb = PREDICTOR._features["image_embed"][-1].unsqueeze(0)
281
+ hi = [lvl[-1].unsqueeze(0) for lvl in PREDICTOR._features["high_res_feats"]]
282
+ _, _, H_feat, W_feat = image_emb.shape
283
+
284
+ # PLM Inference
285
+ # Note: PLM expects a path list for 'images', but the Qwen adapter likely handles
286
+ # the internal logic. If your PLM adapter strictly requires disk paths,
287
+ # save 'image_pil' to a temp file here.
288
+ # Assuming PLM adapter needs a placeholder path or we save temp:
289
+ temp_path = "temp_input.jpg"
290
+ image_pil.save(temp_path)
291
+
292
+ sp, dp = PLM([text_prompt], H_feat, W_feat, [temp_path])
293
+
294
+ dec = PREDICTOR.model.sam_mask_decoder
295
+ dev, dtype = next(dec.parameters()).device, next(dec.parameters()).dtype
296
+ image_pe = PREDICTOR.model.sam_prompt_encoder.get_dense_pe().to(dev, dtype)
297
+ image_emb = image_emb.to(dev, dtype)
298
+ hi = [h.to(dev, dtype) for h in hi]
299
+ sp, dp = sp.to(dev, dtype), dp.to(dev, dtype)
300
+
301
+ # SAM2 Decoding
302
+ low, scores, _, _ = dec(
303
+ image_embeddings=image_emb,
304
+ image_pe=image_pe,
305
+ sparse_prompt_embeddings=sp,
306
+ dense_prompt_embeddings=dp,
307
+ multimask_output=True,
308
+ repeat_image=False,
309
+ high_res_features=hi,
310
+ )
311
+
312
+ logits_sq = PREDICTOR._transforms.postprocess_masks(low, (SQUARE_DIM, SQUARE_DIM))
313
+ best = scores.argmax(dim=1).item()
314
+ logit_sq = logits_sq[0, best]
315
+ logit_gt = _unpad_and_resize_pred_to_gt(logit_sq, meta, (Hgt, Wgt))
316
+
317
+ prob = torch.sigmoid(logit_gt)
318
+ mask = (prob > 0.5).cpu().numpy().astype(np.uint8) * 255
319
+
320
+ # Visualization: Overlay
321
+ overlay_img = make_overlay(rgb_orig, mask, key=text_prompt)
322
+
323
+ # Visualization: Attention
324
+ text_attn = get_text_to_image_attention(dec)
325
+ attn_overlay_img = None
326
+
327
+ if text_attn is not None:
328
+ L_layer, B, H_heads, N_text, N_img = text_attn.shape
329
+ attn_flat = text_attn.mean(dim=(0, 2, 3)) # Mean over layers, heads, text
330
+ global_flat = attn_flat[0]
331
+ a = global_flat.view(H_feat, W_feat)
332
+
333
+ # Upsample attention
334
+ a_sq = F.interpolate(
335
+ a.unsqueeze(0).unsqueeze(0),
336
+ size=(SQUARE_DIM, SQUARE_DIM),
337
+ mode="bilinear",
338
+ align_corners=False,
339
+ )[0, 0]
340
+
341
+ a_gt = _unpad_and_resize_pred_to_gt(a_sq, meta, (Hgt, Wgt))
342
+ global_attn_orig = a_gt.cpu().numpy()
343
+ attn_overlay_img = make_attn_overlay(rgb_orig, global_attn_orig)
344
+
345
+ # Return list of images for Gallery or individual blocks
346
+ # Mask as an image
347
+ mask_img = Image.fromarray(mask, mode="L")
348
+
349
+ return overlay_img, mask_img, attn_overlay_img
350
+
351
+ # ----------------- Gradio UI -----------------
352
+
353
+ with gr.Blocks(title="SAM2 + PLM Interactive Segmentation") as demo:
354
+ gr.Markdown("# SAM2 + PLM Interactive Segmentation")
355
+ gr.Markdown("Enter a text prompt to segment objects in the image.")
356
+
357
+ with gr.Row():
358
+ with gr.Column():
359
+ input_image = gr.Image(type="pil", label="Input Image")
360
+ text_prompt = gr.Textbox(label="Text Prompt", placeholder="e.g., 'the red car'")
361
+ run_btn = gr.Button("Segment", variant="primary")
362
+
363
+ with gr.Column():
364
+ out_overlay = gr.Image(label="Segmentation Overlay", type="pil")
365
+ with gr.Row():
366
+ out_mask = gr.Image(label="Binary Mask", type="pil")
367
+ out_attn = gr.Image(label="Attention Heatmap", type="pil")
368
+
369
+ run_btn.click(
370
+ fn=run_prediction,
371
+ inputs=[input_image, text_prompt],
372
+ outputs=[out_overlay, out_mask, out_attn]
373
+ )
374
+
375
+ if __name__ == "__main__":
376
+ demo.launch()