aadarsh99 committed on
Commit
c84ea63
·
1 Parent(s): 96c10ec

update app

Browse files
Files changed (1) hide show
  1. app.py +117 -131
app.py CHANGED
@@ -20,7 +20,10 @@ from sam2.sam2_image_predictor import SAM2ImagePredictor
20
  from plm_adapter_lora_with_image_input_only_text_positions import PLMLanguageAdapter
21
 
22
  # ----------------- Configuration -----------------
23
- HF_REPO_ID = "aadarsh99/ConvSeg-Stage1"
 
 
 
24
  SAM2_CONFIG = "sam2_hiera_l.yaml"
25
 
26
  BASE_CKPT_NAME = "sam2_hiera_large.pt"
@@ -31,24 +34,25 @@ LORA_CKPT_NAME = None
31
  SQUARE_DIM = 1024
32
  logging.basicConfig(level=logging.INFO)
33
 
34
- MODEL_SAM_CPU = None
35
- PLM_CPU = None
 
 
 
36
 
37
- # ----------------- Helper Functions -----------------
38
-
39
- def download_if_needed(filename):
40
- if os.path.exists(filename):
41
- return filename
42
  try:
43
- return hf_hub_download(repo_id=HF_REPO_ID, filename=filename)
 
44
  except Exception as e:
45
- raise FileNotFoundError(f"Could not find {filename} in HF repo {HF_REPO_ID}. Error: {e}")
46
 
 
 
47
  def _hex_to_rgb(h: str):
48
  h = h.lstrip("#")
49
  return tuple(int(h[i : i + 2], 16) for i in (0, 2, 4))
50
-
51
- EDGE_COLORS_HEX = ["#3A86FF", "#FF006E", "#43AA8B", "#F3722C", "#8338EC", "#90BE6D"]
52
  EDGE_COLORS = [_hex_to_rgb(h) for h in EDGE_COLORS_HEX]
53
 
54
  def stable_color(key: str):
@@ -58,79 +62,42 @@ def stable_color(key: str):
58
  def tint(rgb, amt: float = 0.1):
59
  return tuple(int(255 - (255 - c) * (1 - amt)) for c in rgb)
60
 
61
- def edge_map(mask_bool: np.ndarray, width_px: int = 2) -> Image.Image:
62
- m = Image.fromarray((mask_bool.astype(np.uint8) * 255), "L")
63
- edges = ImageChops.difference(m.filter(ImageFilter.MaxFilter(3)), m.filter(ImageFilter.MinFilter(3)))
64
- for _ in range(max(0, width_px - 1)):
65
- edges = edges.filter(ImageFilter.MaxFilter(3))
66
- return edges.point(lambda p: 255 if p > 0 else 0)
67
-
68
- def _apply_rounded_corners(img_rgb: Image.Image, radius: int) -> Image.Image:
69
- w, h = img_rgb.size
70
- mask = Image.new("L", (w, h), 0)
71
- ImageDraw.Draw(mask).rounded_rectangle([0, 0, w - 1, h - 1], radius=radius, fill=255)
72
- bg = Image.new("RGB", (w, h), "white")
73
- img_rgba = img_rgb.convert("RGBA")
74
- img_rgba.putalpha(mask)
75
- bg.paste(img_rgba.convert("RGB"), (0, 0), mask)
76
- return bg
77
-
78
  def make_overlay(rgb: np.ndarray, mask: np.ndarray, key: str = "mask") -> Image.Image:
79
  base = Image.fromarray(rgb.astype(np.uint8)).convert("RGB")
80
- H, W = mask.shape[:2]
81
- if base.size != (W, H):
82
- base = base.resize((W, H), Image.BICUBIC)
83
  base_rgba = base.convert("RGBA")
84
  mask_bool = mask > 0
85
  color = stable_color(key)
86
  fill_rgb = tint(color, 0.1)
87
- fill_layer = Image.new("RGBA", base_rgba.size, fill_rgb + (0,))
 
88
  fill_alpha = Image.fromarray((mask_bool.astype(np.uint8) * 178), "L")
89
  fill_layer.putalpha(fill_alpha)
90
- edgesL = edge_map(mask_bool, width_px=2)
91
- stroke = Image.new("RGBA", base_rgba.size, color + (0,))
92
- stroke.putalpha(edgesL)
93
- out = Image.alpha_composite(base_rgba, fill_layer)
94
- out = Image.alpha_composite(out, stroke)
95
- return _apply_rounded_corners(out.convert("RGB"), max(12, int(0.06 * min(out.size))))
96
 
97
- # ----------------- Image Processing -----------------
98
-
99
- def _resize_pad_square(arr: np.ndarray, max_dim: int, *, is_mask: bool) -> np.ndarray:
100
- h, w = arr.shape[:2]
101
- scale = float(max_dim) / float(max(h, w))
102
- new_w, new_h = max(1, int(round(w * scale))), max(1, int(round(h * scale)))
103
- interp = cv2.INTER_NEAREST if is_mask else (cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LINEAR)
104
- arr = cv2.resize(arr, (new_w, new_h), interpolation=interp)
105
- pad_w, pad_h = max_dim - new_w, max_dim - new_h
106
- left, top = pad_w // 2, pad_h // 2
107
- return np.ascontiguousarray(cv2.copyMakeBorder(arr, top, pad_h - top, left, pad_w - left, cv2.BORDER_CONSTANT, value=0))
108
-
109
- def _resize_pad_square_meta(h: int, w: int, max_dim: int):
110
- scale = float(max_dim) / float(max(h, w))
111
- new_w, new_h = max(1, int(round(w * scale))), max(1, int(round(h * scale)))
112
- return {"scale": scale, "new_w": new_w, "new_h": new_h, "left": (max_dim - new_w) // 2, "top": (max_dim - new_h) // 2}
113
-
114
- def _unpad_and_resize_pred_to_gt(logit_sq: torch.Tensor, meta: dict, out_hw: tuple[int, int]) -> torch.Tensor:
115
- top, left = meta["top"], meta["left"]
116
- nh, nw = meta["new_h"], meta["new_w"]
117
- crop = logit_sq[top : top + nh, left : left + nw].unsqueeze(0).unsqueeze(0)
118
- return F.interpolate(crop, size=out_hw, mode="bilinear", align_corners=False)[0, 0]
119
 
120
- # ----------------- Prediction Logic -----------------
 
 
121
 
122
- def ensure_models_loaded_on_cpu():
123
- global MODEL_SAM_CPU, PLM_CPU
124
- if MODEL_SAM_CPU is not None and PLM_CPU is not None:
 
125
  return
126
- logging.info("Loading models into CPU RAM...")
127
- base_path = download_if_needed(BASE_CKPT_NAME)
 
 
 
128
  model = build_sam2(SAM2_CONFIG, base_path, device="cpu")
129
- final_path = download_if_needed(FINAL_CKPT_NAME)
 
130
  sd = torch.load(final_path, map_location="cpu")
131
  model.load_state_dict(sd.get("model", sd), strict=True)
132
- MODEL_SAM_CPU = model
133
-
134
  plm = PLMLanguageAdapter(
135
  model_name="Qwen/Qwen2.5-VL-3B-Instruct",
136
  transformer_dim=model.sam_mask_decoder.transformer_dim,
@@ -138,99 +105,118 @@ def ensure_models_loaded_on_cpu():
138
  lora_r=16, lora_alpha=32, lora_dropout=0.05,
139
  dtype=torch.bfloat16, device="cpu",
140
  )
141
- plm_path = download_if_needed(PLM_CKPT_NAME)
142
  plm_sd = torch.load(plm_path, map_location="cpu")
143
  plm.load_state_dict(plm_sd["plm"], strict=True)
144
  plm.eval()
145
- PLM_CPU = plm
146
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  @spaces.GPU(duration=120)
148
- def run_prediction(image_pil, text_prompt, threshold):
149
  if image_pil is None or not text_prompt:
150
  return None, None
151
 
152
- ensure_models_loaded_on_cpu()
153
- MODEL_SAM_CPU.to("cuda")
154
- PLM_CPU.to("cuda")
 
 
 
155
 
156
- predictor = None
157
  try:
158
- predictor = SAM2ImagePredictor(MODEL_SAM_CPU)
159
- rgb_orig = np.array(image_pil.convert("RGB"))
160
- Hgt, Wgt = rgb_orig.shape[:2]
161
- meta = _resize_pad_square_meta(Hgt, Wgt, SQUARE_DIM)
162
- rgb_sq = _resize_pad_square(rgb_orig, SQUARE_DIM, is_mask=False)
163
-
164
- predictor.set_image(rgb_sq)
165
- image_emb = predictor._features["image_embed"][-1].unsqueeze(0)
166
- hi = [lvl[-1].unsqueeze(0) for lvl in predictor._features["high_res_feats"]]
167
-
168
- temp_path = "temp_input.jpg"
169
- image_pil.save(temp_path)
170
- sp, dp = PLM_CPU([text_prompt], image_emb.shape[2], image_emb.shape[3], [temp_path])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
 
172
- dec = predictor.model.sam_mask_decoder
173
- dev, dtype = next(dec.parameters()).device, next(dec.parameters()).dtype
174
-
175
- low, scores, _, _ = dec(
176
- image_embeddings=image_emb.to(dev, dtype),
177
- image_pe=predictor.model.sam_prompt_encoder.get_dense_pe().to(dev, dtype),
178
- sparse_prompt_embeddings=sp.to(dev, dtype),
179
- dense_prompt_embeddings=dp.to(dev, dtype),
180
- multimask_output=True, repeat_image=False,
181
- high_res_features=[h.to(dev, dtype) for h in hi],
182
- )
183
-
184
- logits_sq = predictor._transforms.postprocess_masks(low, (SQUARE_DIM, SQUARE_DIM))
185
- logit_gt = _unpad_and_resize_pred_to_gt(logits_sq[0, scores.argmax(dim=1).item()], meta, (Hgt, Wgt))
186
 
187
- # 1. Calculate Probabilities (Heatmap)
188
- prob = torch.sigmoid(logit_gt).cpu().numpy()
189
-
190
- # 2. Apply dynamic threshold for overlay
191
  mask = (prob > threshold).astype(np.uint8) * 255
192
  overlay_img = make_overlay(rgb_orig, mask, key=text_prompt)
193
 
194
- # 3. Create Heatmap Visualization
195
- # Scale 0.0-1.0 to 0-255
196
- prob_uint8 = (prob * 255).astype(np.uint8)
197
- heatmap_color = cv2.applyColorMap(prob_uint8, cv2.COLORMAP_JET)
198
- heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)
199
- heatmap_pil = Image.fromarray(heatmap_color)
200
-
201
- return overlay_img, heatmap_pil
202
 
203
- except Exception as e:
204
  traceback.print_exc()
205
- raise e
206
  finally:
207
- MODEL_SAM_CPU.to("cpu")
208
- PLM_CPU.to("cpu")
209
- if predictor: del predictor
210
  torch.cuda.empty_cache()
211
 
212
  # ----------------- Gradio UI -----------------
213
-
214
- with gr.Blocks(title="SAM2 + PLM Segmentation") as demo:
215
  gr.Markdown("# SAM2 + PLM Interactive Segmentation")
216
 
217
  with gr.Row():
218
  with gr.Column():
219
  input_image = gr.Image(type="pil", label="Input Image")
220
- text_prompt = gr.Textbox(label="Text Prompt", placeholder="e.g., 'the red car'")
221
- threshold_slider = gr.Slider(
222
- minimum=0.0, maximum=1.0, value=0.5, step=0.01,
223
- label="Confidence Threshold", info="Adjust to include more/less of the object"
224
- )
225
- run_btn = gr.Button("Segment", variant="primary")
 
226
 
227
  with gr.Column():
228
  out_overlay = gr.Image(label="Segmentation Overlay", type="pil")
229
- out_heatmap = gr.Image(label="Probability Heatmap", type="pil")
230
 
231
  run_btn.click(
232
  fn=run_prediction,
233
- inputs=[input_image, text_prompt, threshold_slider],
234
  outputs=[out_overlay, out_heatmap]
235
  )
236
 
 
20
  from plm_adapter_lora_with_image_input_only_text_positions import PLMLanguageAdapter
21
 
22
  # ----------------- Configuration -----------------
23
# Maps each selectable model stage (the labels offered in the UI) to the
# Hugging Face Hub repository its checkpoints are downloaded from.
REPO_MAP = {
    "Stage 1": "aadarsh99/ConvSeg-Stage1",
    "Stage 2": "aadarsh99/ConvSeg-Stage2",
}
27
  SAM2_CONFIG = "sam2_hiera_l.yaml"
28
 
29
  BASE_CKPT_NAME = "sam2_hiera_large.pt"
 
34
  SQUARE_DIM = 1024
35
  logging.basicConfig(level=logging.INFO)
36
 
37
# ----------------- Globals (RAM cache) -----------------
# One slot pair per stage. A slot stays None until that stage's models have
# been loaded, after which the loaded objects are kept here for reuse.
MODEL_CACHE = {
    stage_name: {"sam": None, "plm": None}
    for stage_name in ("Stage 1", "Stage 2")
}
42
 
43
+ # ----------------- Helper: Download Logic -----------------
44
def download_if_needed(repo_id, filename):
    """Fetch ``filename`` from the Hub repository ``repo_id``.

    Args:
        repo_id: Hugging Face Hub repository identifier.
        filename: Name of the file inside that repository.

    Returns:
        The value returned by ``hf_hub_download`` for the requested file.

    Raises:
        FileNotFoundError: If the download fails for any reason. The original
            exception is attached as ``__cause__``.
    """
    logging.info(f"Downloading {filename} from {repo_id}...")
    try:
        return hf_hub_download(repo_id=repo_id, filename=filename)
    except Exception as e:
        # Chain explicitly so the real failure stays attached to the error
        # callers see, instead of surviving only as flattened message text.
        raise FileNotFoundError(f"Could not find {filename} in {repo_id}. Error: {e}") from e
50
 
51
+ # ----------------- Overlay & Heatmap Helpers -----------------
52
+ EDGE_COLORS_HEX = ["#3A86FF", "#FF006E", "#43AA8B", "#F3722C", "#8338EC", "#90BE6D"]
53
  def _hex_to_rgb(h: str):
54
  h = h.lstrip("#")
55
  return tuple(int(h[i : i + 2], 16) for i in (0, 2, 4))
 
 
56
  EDGE_COLORS = [_hex_to_rgb(h) for h in EDGE_COLORS_HEX]
57
 
58
  def stable_color(key: str):
 
62
  def tint(rgb, amt: float = 0.1):
63
  return tuple(int(255 - (255 - c) * (1 - amt)) for c in rgb)
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  def make_overlay(rgb: np.ndarray, mask: np.ndarray, key: str = "mask") -> Image.Image:
66
  base = Image.fromarray(rgb.astype(np.uint8)).convert("RGB")
 
 
 
67
  base_rgba = base.convert("RGBA")
68
  mask_bool = mask > 0
69
  color = stable_color(key)
70
  fill_rgb = tint(color, 0.1)
71
+
72
+ fill_layer = Image.new("RGBA", base.size, fill_rgb + (0,))
73
  fill_alpha = Image.fromarray((mask_bool.astype(np.uint8) * 178), "L")
74
  fill_layer.putalpha(fill_alpha)
 
 
 
 
 
 
75
 
76
+ m = Image.fromarray((mask_bool.astype(np.uint8) * 255), "L")
77
+ edges = ImageChops.difference(m.filter(ImageFilter.MaxFilter(3)), m.filter(ImageFilter.MinFilter(3)))
78
+ stroke = Image.new("RGBA", base.size, color + (0,))
79
+ stroke.putalpha(edges)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
+ out = Image.alpha_composite(base_rgba, fill_layer)
82
+ out = Image.alpha_composite(out, stroke)
83
+ return out.convert("RGB")
84
 
85
+ # ----------------- Model Loading (CPU Caching) -----------------
86
+ def ensure_models_loaded(stage):
87
+ global MODEL_CACHE
88
+ if MODEL_CACHE[stage]["sam"] is not None:
89
  return
90
+
91
+ repo_id = REPO_MAP[stage]
92
+ logging.info(f"Loading {stage} models from {repo_id}...")
93
+
94
+ base_path = download_if_needed(repo_id, BASE_CKPT_NAME)
95
  model = build_sam2(SAM2_CONFIG, base_path, device="cpu")
96
+
97
+ final_path = download_if_needed(repo_id, FINAL_CKPT_NAME)
98
  sd = torch.load(final_path, map_location="cpu")
99
  model.load_state_dict(sd.get("model", sd), strict=True)
100
+
 
101
  plm = PLMLanguageAdapter(
102
  model_name="Qwen/Qwen2.5-VL-3B-Instruct",
103
  transformer_dim=model.sam_mask_decoder.transformer_dim,
 
105
  lora_r=16, lora_alpha=32, lora_dropout=0.05,
106
  dtype=torch.bfloat16, device="cpu",
107
  )
108
+ plm_path = download_if_needed(repo_id, PLM_CKPT_NAME)
109
  plm_sd = torch.load(plm_path, map_location="cpu")
110
  plm.load_state_dict(plm_sd["plm"], strict=True)
111
  plm.eval()
 
112
 
113
+ MODEL_CACHE[stage]["sam"] = model
114
+ MODEL_CACHE[stage]["plm"] = plm
115
+
116
def _resize_pad_square(arr, max_dim):
    """Scale ``arr`` so its longer side equals ``max_dim``, then zero-pad to a square.

    The aspect ratio is kept (bilinear resize) and the resized content is
    centered in a ``max_dim`` x ``max_dim`` canvas padded with zeros.
    """
    src_h, src_w = arr.shape[:2]
    factor = float(max_dim) / float(max(src_h, src_w))

    # Clamp to 1 so an extremely thin image never collapses to a zero side.
    out_w = max(1, int(round(src_w * factor)))
    out_h = max(1, int(round(src_h * factor)))
    resized = cv2.resize(arr, (out_w, out_h), interpolation=cv2.INTER_LINEAR)

    # Split the leftover evenly; an odd remainder goes to the bottom/right.
    extra_w = max_dim - out_w
    extra_h = max_dim - out_h
    top = extra_h // 2
    left = extra_w // 2
    return cv2.copyMakeBorder(
        resized,
        top,
        extra_h - top,
        left,
        extra_w - left,
        cv2.BORDER_CONSTANT,
        value=0,
    )
123
+
124
# ----------------- Main Prediction -----------------
@spaces.GPU(duration=120)
def run_prediction(image_pil, text_prompt, threshold, stage_choice):
    """Segment ``image_pil`` according to ``text_prompt`` using the chosen stage.

    Args:
        image_pil: Input PIL image, or ``None`` when nothing was provided.
        text_prompt: Text describing what to segment.
        threshold: Probability cut-off used to binarize the mask for the overlay.
        stage_choice: Key into ``MODEL_CACHE`` selecting which models to run.

    Returns:
        ``(overlay, heatmap)`` as PIL images. ``(None, None)`` is returned when
        an input is missing or when inference fails (the traceback is printed).
    """
    import contextlib
    import os
    import tempfile

    if image_pil is None or not text_prompt:
        return None, None

    ensure_models_loaded(stage_choice)
    sam_model = MODEL_CACHE[stage_choice]["sam"]
    plm_model = MODEL_CACHE[stage_choice]["plm"]

    temp_path = None
    try:
        # Moved inside the try so that a failure while transferring the second
        # model still sends the first one back to the CPU in ``finally``.
        sam_model.to("cuda")
        plm_model.to("cuda")

        # Inference mode: no autograd bookkeeping is needed for prediction.
        with torch.inference_mode():
            predictor = SAM2ImagePredictor(sam_model)
            rgb_image = image_pil.convert("RGB")
            rgb_orig = np.array(rgb_image)
            Hgt, Wgt = rgb_orig.shape[:2]

            rgb_sq = _resize_pad_square(rgb_orig, SQUARE_DIM)

            # Crop metadata must reproduce _resize_pad_square's arithmetic
            # exactly (round + clamp to 1). Truncating instead can shift and
            # shrink the un-padding window by one pixel.
            scale = float(SQUARE_DIM) / float(max(Hgt, Wgt))
            nw = max(1, int(round(Wgt * scale)))
            nh = max(1, int(round(Hgt * scale)))
            top, left = (SQUARE_DIM - nh) // 2, (SQUARE_DIM - nw) // 2

            predictor.set_image(rgb_sq)

            image_emb = predictor._features["image_embed"][-1].unsqueeze(0)
            hi = [lvl[-1].unsqueeze(0) for lvl in predictor._features["high_res_feats"]]

            # PLM inference. The adapter is handed a file path, so write the
            # image to a per-call temp file (a fixed name would be shared by
            # concurrent requests). Save the RGB copy because JPEG cannot
            # store alpha or palette modes.
            fd, temp_path = tempfile.mkstemp(suffix=".jpg")
            os.close(fd)
            rgb_image.save(temp_path)
            sp, dp = plm_model([text_prompt], image_emb.shape[2], image_emb.shape[3], [temp_path])

            # SAM2 decoding: feed everything in the decoder's device and dtype.
            dec = predictor.model.sam_mask_decoder
            first_param = next(dec.parameters())
            dev, dtype = first_param.device, first_param.dtype
            low, scores, _, _ = dec(
                image_embeddings=image_emb.to(dev, dtype),
                image_pe=predictor.model.sam_prompt_encoder.get_dense_pe().to(dev, dtype),
                sparse_prompt_embeddings=sp.to(dev, dtype),
                dense_prompt_embeddings=dp.to(dev, dtype),
                multimask_output=True, repeat_image=False,
                high_res_features=[h.to(dev, dtype) for h in hi],
            )

            # Upsample to the padded square, keep the best-scoring mask, strip
            # the padding, then resize back to the original resolution.
            logits_sq = predictor._transforms.postprocess_masks(low, (SQUARE_DIM, SQUARE_DIM))
            best_idx = scores.argmax(dim=1).item()
            logit_crop = logits_sq[0, best_idx, top:top + nh, left:left + nw].unsqueeze(0).unsqueeze(0)
            logit_full = F.interpolate(logit_crop, size=(Hgt, Wgt), mode="bilinear", align_corners=False)[0, 0]

            # Cast to float32 before leaving torch: NumPy has no bfloat16, and
            # the logits may be in reduced precision depending on the decoder.
            prob = torch.sigmoid(logit_full).float().detach().cpu().numpy()

        # Visualizations: JET-colored probability map and thresholded overlay.
        heatmap_cv = cv2.applyColorMap((prob * 255).astype(np.uint8), cv2.COLORMAP_JET)
        heatmap_rgb = cv2.cvtColor(heatmap_cv, cv2.COLOR_BGR2RGB)

        mask = (prob > threshold).astype(np.uint8) * 255
        overlay_img = make_overlay(rgb_orig, mask, key=text_prompt)

        return overlay_img, Image.fromarray(heatmap_rgb)

    except Exception:
        traceback.print_exc()
        return None, None
    finally:
        # Always release the GPU and the temp file, on success or failure.
        sam_model.to("cpu")
        plm_model.to("cpu")
        if temp_path is not None:
            with contextlib.suppress(OSError):
                os.remove(temp_path)
        torch.cuda.empty_cache()
197
 
198
  # ----------------- Gradio UI -----------------
199
+ with gr.Blocks(title="SAM2 + PLM Multi-Stage") as demo:
 
200
  gr.Markdown("# SAM2 + PLM Interactive Segmentation")
201
 
202
  with gr.Row():
203
  with gr.Column():
204
  input_image = gr.Image(type="pil", label="Input Image")
205
+ text_prompt = gr.Textbox(label="Text Prompt", placeholder="e.g., 'the surgical tool'")
206
+
207
+ with gr.Row():
208
+ stage_select = gr.Radio(choices=["Stage 1", "Stage 2"], value="Stage 1", label="Model Stage")
209
+ threshold_slider = gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Confidence Threshold")
210
+
211
+ run_btn = gr.Button("Run Inference", variant="primary")
212
 
213
  with gr.Column():
214
  out_overlay = gr.Image(label="Segmentation Overlay", type="pil")
215
+ out_heatmap = gr.Image(label="Pixel-wise Probability Heatmap", type="pil")
216
 
217
  run_btn.click(
218
  fn=run_prediction,
219
+ inputs=[input_image, text_prompt, threshold_slider, stage_select],
220
  outputs=[out_overlay, out_heatmap]
221
  )
222