aadarsh99 commited on
Commit
7e46975
·
1 Parent(s): c84ea63

update app

Browse files
Files changed (1) hide show
  1. app.py +62 -108
app.py CHANGED
@@ -29,47 +29,34 @@ SAM2_CONFIG = "sam2_hiera_l.yaml"
29
  BASE_CKPT_NAME = "sam2_hiera_large.pt"
30
  FINAL_CKPT_NAME = "fine_tuned_sam2_batched_100000.torch"
31
  PLM_CKPT_NAME = "fine_tuned_sam2_batched_plm_100000.torch"
32
- LORA_CKPT_NAME = None
33
 
34
  SQUARE_DIM = 1024
35
  logging.basicConfig(level=logging.INFO)
36
 
37
- # ----------------- Globals (Ram Cache) -----------------
38
  MODEL_CACHE = {
39
  "Stage 1": {"sam": None, "plm": None},
40
  "Stage 2": {"sam": None, "plm": None}
41
  }
42
 
43
- # ----------------- Helper: Download Logic -----------------
44
  def download_if_needed(repo_id, filename):
45
  try:
46
- logging.info(f"Downloading {filename} from {repo_id}...")
47
  return hf_hub_download(repo_id=repo_id, filename=filename)
48
  except Exception as e:
49
  raise FileNotFoundError(f"Could not find {filename} in {repo_id}. Error: {e}")
50
 
51
- # ----------------- Overlay & Heatmap Helpers -----------------
52
- EDGE_COLORS_HEX = ["#3A86FF", "#FF006E", "#43AA8B", "#F3722C", "#8338EC", "#90BE6D"]
53
- def _hex_to_rgb(h: str):
54
- h = h.lstrip("#")
55
- return tuple(int(h[i : i + 2], 16) for i in (0, 2, 4))
56
- EDGE_COLORS = [_hex_to_rgb(h) for h in EDGE_COLORS_HEX]
57
-
58
  def stable_color(key: str):
59
  h = int(hashlib.sha256(str(key).encode("utf-8")).hexdigest(), 16)
60
- return EDGE_COLORS[h % len(EDGE_COLORS)]
61
-
62
- def tint(rgb, amt: float = 0.1):
63
- return tuple(int(255 - (255 - c) * (1 - amt)) for c in rgb)
64
 
65
  def make_overlay(rgb: np.ndarray, mask: np.ndarray, key: str = "mask") -> Image.Image:
66
- base = Image.fromarray(rgb.astype(np.uint8)).convert("RGB")
67
- base_rgba = base.convert("RGBA")
68
  mask_bool = mask > 0
69
  color = stable_color(key)
70
- fill_rgb = tint(color, 0.1)
71
 
72
- fill_layer = Image.new("RGBA", base.size, fill_rgb + (0,))
73
  fill_alpha = Image.fromarray((mask_bool.astype(np.uint8) * 178), "L")
74
  fill_layer.putalpha(fill_alpha)
75
 
@@ -78,146 +65,113 @@ def make_overlay(rgb: np.ndarray, mask: np.ndarray, key: str = "mask") -> Image.
78
  stroke = Image.new("RGBA", base.size, color + (0,))
79
  stroke.putalpha(edges)
80
 
81
- out = Image.alpha_composite(base_rgba, fill_layer)
82
- out = Image.alpha_composite(out, stroke)
83
- return out.convert("RGB")
84
 
85
- # ----------------- Model Loading (CPU Caching) -----------------
86
  def ensure_models_loaded(stage):
87
  global MODEL_CACHE
88
- if MODEL_CACHE[stage]["sam"] is not None:
89
- return
90
-
91
  repo_id = REPO_MAP[stage]
92
- logging.info(f"Loading {stage} models from {repo_id}...")
93
-
94
  base_path = download_if_needed(repo_id, BASE_CKPT_NAME)
95
  model = build_sam2(SAM2_CONFIG, base_path, device="cpu")
96
-
97
- final_path = download_if_needed(repo_id, FINAL_CKPT_NAME)
98
- sd = torch.load(final_path, map_location="cpu")
99
  model.load_state_dict(sd.get("model", sd), strict=True)
100
-
101
- plm = PLMLanguageAdapter(
102
- model_name="Qwen/Qwen2.5-VL-3B-Instruct",
103
- transformer_dim=model.sam_mask_decoder.transformer_dim,
104
- n_sparse_tokens=0, use_dense_bias=True, use_lora=True,
105
- lora_r=16, lora_alpha=32, lora_dropout=0.05,
106
- dtype=torch.bfloat16, device="cpu",
107
- )
108
- plm_path = download_if_needed(repo_id, PLM_CKPT_NAME)
109
- plm_sd = torch.load(plm_path, map_location="cpu")
110
- plm.load_state_dict(plm_sd["plm"], strict=True)
111
  plm.eval()
 
112
 
113
- MODEL_CACHE[stage]["sam"] = model
114
- MODEL_CACHE[stage]["plm"] = plm
115
-
116
- def _resize_pad_square(arr, max_dim):
117
- h, w = arr.shape[:2]
118
- scale = float(max_dim) / float(max(h, w))
119
- nw, nh = max(1, int(round(w * scale))), max(1, int(round(h * scale)))
120
- arr = cv2.resize(arr, (nw, nh), interpolation=cv2.INTER_LINEAR)
121
- pad_w, pad_h = max_dim - nw, max_dim - nh
122
- return cv2.copyMakeBorder(arr, pad_h//2, pad_h-pad_h//2, pad_w//2, pad_w-pad_w//2, cv2.BORDER_CONSTANT, value=0)
123
 
124
- # ----------------- Main Prediction -----------------
125
  @spaces.GPU(duration=120)
126
  def run_prediction(image_pil, text_prompt, threshold, stage_choice):
127
  if image_pil is None or not text_prompt:
128
- return None, None
129
 
130
  ensure_models_loaded(stage_choice)
131
- sam_model = MODEL_CACHE[stage_choice]["sam"]
132
- plm_model = MODEL_CACHE[stage_choice]["plm"]
133
-
134
- sam_model.to("cuda")
135
- plm_model.to("cuda")
136
 
137
  try:
138
- # 1. Use Inference Mode to avoid grad errors and save memory
139
  with torch.inference_mode():
140
  predictor = SAM2ImagePredictor(sam_model)
141
  rgb_orig = np.array(image_pil.convert("RGB"))
142
- Hgt, Wgt = rgb_orig.shape[:2]
143
-
144
- # Setup crop/padding metadata
145
- scale = SQUARE_DIM / max(Hgt, Wgt)
146
- nw, nh = int(Wgt * scale), int(Hgt * scale)
147
  top, left = (SQUARE_DIM - nh) // 2, (SQUARE_DIM - nw) // 2
148
 
149
- rgb_sq = _resize_pad_square(rgb_orig, SQUARE_DIM)
 
150
  predictor.set_image(rgb_sq)
151
-
152
  image_emb = predictor._features["image_embed"][-1].unsqueeze(0)
153
  hi = [lvl[-1].unsqueeze(0) for lvl in predictor._features["high_res_feats"]]
154
 
155
- # PLM Inference
156
- temp_path = "temp_input.jpg"
157
  image_pil.save(temp_path)
158
  sp, dp = plm_model([text_prompt], image_emb.shape[2], image_emb.shape[3], [temp_path])
159
-
160
- # SAM2 Decoding
161
- dec = predictor.model.sam_mask_decoder
162
- dev, dtype = next(dec.parameters()).device, next(dec.parameters()).dtype
163
- low, scores, _, _ = dec(
164
- image_embeddings=image_emb.to(dev, dtype),
165
- image_pe=predictor.model.sam_prompt_encoder.get_dense_pe().to(dev, dtype),
166
- sparse_prompt_embeddings=sp.to(dev, dtype),
167
- dense_prompt_embeddings=dp.to(dev, dtype),
168
- multimask_output=True, repeat_image=False,
169
- high_res_features=[h.to(dev, dtype) for h in hi],
170
  )
171
 
172
- # Postprocess to full image size
173
- logits_sq = predictor._transforms.postprocess_masks(low, (SQUARE_DIM, SQUARE_DIM))
174
- best_idx = scores.argmax(dim=1).item()
175
- logit_crop = logits_sq[0, best_idx, top:top+nh, left:left+nw].unsqueeze(0).unsqueeze(0)
176
- logit_full = F.interpolate(logit_crop, size=(Hgt, Wgt), mode="bilinear", align_corners=False)[0, 0]
177
-
178
- # FIX: Detach and convert to float before moving to cpu/numpy
179
  prob = torch.sigmoid(logit_full).float().detach().cpu().numpy()
180
 
181
- # 2. Visualizations
182
- heatmap_cv = cv2.applyColorMap((prob * 255).astype(np.uint8), cv2.COLORMAP_JET)
183
- heatmap_rgb = cv2.cvtColor(heatmap_cv, cv2.COLOR_BGR2RGB)
184
 
185
- mask = (prob > threshold).astype(np.uint8) * 255
186
- overlay_img = make_overlay(rgb_orig, mask, key=text_prompt)
187
 
188
- return overlay_img, Image.fromarray(heatmap_rgb)
189
-
190
- except Exception:
191
- traceback.print_exc()
192
- return None, None
193
  finally:
194
- sam_model.to("cpu")
195
- plm_model.to("cpu")
196
  torch.cuda.empty_cache()
197
 
 
 
 
 
 
 
 
 
198
  # ----------------- Gradio UI -----------------
199
- with gr.Blocks(title="SAM2 + PLM Multi-Stage") as demo:
200
- gr.Markdown("# SAM2 + PLM Interactive Segmentation")
 
 
 
201
 
202
  with gr.Row():
203
  with gr.Column():
204
  input_image = gr.Image(type="pil", label="Input Image")
205
- text_prompt = gr.Textbox(label="Text Prompt", placeholder="e.g., 'the surgical tool'")
206
-
207
  with gr.Row():
208
- stage_select = gr.Radio(choices=["Stage 1", "Stage 2"], value="Stage 1", label="Model Stage")
209
- threshold_slider = gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Confidence Threshold")
210
-
211
  run_btn = gr.Button("Run Inference", variant="primary")
212
 
213
  with gr.Column():
214
  out_overlay = gr.Image(label="Segmentation Overlay", type="pil")
215
- out_heatmap = gr.Image(label="Pixel-wise Probability Heatmap", type="pil")
216
 
 
217
  run_btn.click(
218
  fn=run_prediction,
219
  inputs=[input_image, text_prompt, threshold_slider, stage_select],
220
- outputs=[out_overlay, out_heatmap]
 
 
 
 
 
 
 
221
  )
222
 
223
  if __name__ == "__main__":
 
29
  BASE_CKPT_NAME = "sam2_hiera_large.pt"
30
  FINAL_CKPT_NAME = "fine_tuned_sam2_batched_100000.torch"
31
  PLM_CKPT_NAME = "fine_tuned_sam2_batched_plm_100000.torch"
 
32
 
33
  SQUARE_DIM = 1024
34
  logging.basicConfig(level=logging.INFO)
35
 
 
36
  MODEL_CACHE = {
37
  "Stage 1": {"sam": None, "plm": None},
38
  "Stage 2": {"sam": None, "plm": None}
39
  }
40
 
41
+ # ----------------- Helper Functions -----------------
42
  def download_if_needed(repo_id, filename):
43
  try:
 
44
  return hf_hub_download(repo_id=repo_id, filename=filename)
45
  except Exception as e:
46
  raise FileNotFoundError(f"Could not find {filename} in {repo_id}. Error: {e}")
47
 
 
 
 
 
 
 
 
48
  def stable_color(key: str):
49
  h = int(hashlib.sha256(str(key).encode("utf-8")).hexdigest(), 16)
50
+ EDGE_COLORS_HEX = ["#3A86FF", "#FF006E", "#43AA8B", "#F3722C", "#8338EC", "#90BE6D"]
51
+ colors = [tuple(int(h.lstrip("#")[i:i+2], 16) for i in (0, 2, 4)) for h in EDGE_COLORS_HEX]
52
+ return colors[h % len(colors)]
 
53
 
54
  def make_overlay(rgb: np.ndarray, mask: np.ndarray, key: str = "mask") -> Image.Image:
55
+ base = Image.fromarray(rgb.astype(np.uint8)).convert("RGB").convert("RGBA")
 
56
  mask_bool = mask > 0
57
  color = stable_color(key)
 
58
 
59
+ fill_layer = Image.new("RGBA", base.size, color + (0,))
60
  fill_alpha = Image.fromarray((mask_bool.astype(np.uint8) * 178), "L")
61
  fill_layer.putalpha(fill_alpha)
62
 
 
65
  stroke = Image.new("RGBA", base.size, color + (0,))
66
  stroke.putalpha(edges)
67
 
68
+ return Image.alpha_composite(base, fill_layer).alpha_composite(stroke).convert("RGB")
 
 
69
 
 
70
  def ensure_models_loaded(stage):
71
  global MODEL_CACHE
72
+ if MODEL_CACHE[stage]["sam"] is not None: return
 
 
73
  repo_id = REPO_MAP[stage]
 
 
74
  base_path = download_if_needed(repo_id, BASE_CKPT_NAME)
75
  model = build_sam2(SAM2_CONFIG, base_path, device="cpu")
76
+ sd = torch.load(download_if_needed(repo_id, FINAL_CKPT_NAME), map_location="cpu")
 
 
77
  model.load_state_dict(sd.get("model", sd), strict=True)
78
+ plm = PLMLanguageAdapter(model_name="Qwen/Qwen2.5-VL-3B-Instruct", transformer_dim=model.sam_mask_decoder.transformer_dim, n_sparse_tokens=0, use_dense_bias=True, use_lora=True, lora_r=16, lora_alpha=32, lora_dropout=0.05, dtype=torch.bfloat16, device="cpu")
79
+ plm.load_state_dict(torch.load(download_if_needed(repo_id, PLM_CKPT_NAME), map_location="cpu")["plm"], strict=True)
 
 
 
 
 
 
 
 
 
80
  plm.eval()
81
+ MODEL_CACHE[stage]["sam"], MODEL_CACHE[stage]["plm"] = model, plm
82
 
83
+ # ----------------- Core Logic -----------------
 
 
 
 
 
 
 
 
 
84
 
 
85
  @spaces.GPU(duration=120)
86
  def run_prediction(image_pil, text_prompt, threshold, stage_choice):
87
  if image_pil is None or not text_prompt:
88
+ return None, None, None
89
 
90
  ensure_models_loaded(stage_choice)
91
+ sam_model, plm_model = MODEL_CACHE[stage_choice]["sam"], MODEL_CACHE[stage_choice]["plm"]
92
+ sam_model.to("cuda"), plm_model.to("cuda")
 
 
 
93
 
94
  try:
 
95
  with torch.inference_mode():
96
  predictor = SAM2ImagePredictor(sam_model)
97
  rgb_orig = np.array(image_pil.convert("RGB"))
98
+ H, W = rgb_orig.shape[:2]
99
+ scale = SQUARE_DIM / max(H, W)
100
+ nw, nh = int(W * scale), int(H * scale)
 
 
101
  top, left = (SQUARE_DIM - nh) // 2, (SQUARE_DIM - nw) // 2
102
 
103
+ # Preprocess & Encode
104
+ rgb_sq = cv2.copyMakeBorder(cv2.resize(rgb_orig, (nw, nh)), top, SQUARE_DIM-nh-top, left, SQUARE_DIM-nw-left, cv2.BORDER_CONSTANT, value=0)
105
  predictor.set_image(rgb_sq)
 
106
  image_emb = predictor._features["image_embed"][-1].unsqueeze(0)
107
  hi = [lvl[-1].unsqueeze(0) for lvl in predictor._features["high_res_feats"]]
108
 
109
+ # PLM & SAM2 Decoder
110
+ temp_path = "temp.jpg"
111
  image_pil.save(temp_path)
112
  sp, dp = plm_model([text_prompt], image_emb.shape[2], image_emb.shape[3], [temp_path])
113
+ low, scores, _, _ = sam_model.sam_mask_decoder(
114
+ image_embeddings=image_emb.to("cuda"), image_pe=sam_model.sam_prompt_encoder.get_dense_pe().to("cuda"),
115
+ sparse_prompt_embeddings=sp.to("cuda"), dense_prompt_embeddings=dp.to("cuda"),
116
+ multimask_output=True, repeat_image=False, high_res_features=[h.to("cuda") for h in hi]
 
 
 
 
 
 
 
117
  )
118
 
119
+ # Postprocess to full size
120
+ logits = predictor._transforms.postprocess_masks(low, (SQUARE_DIM, SQUARE_DIM))
121
+ logit_crop = logits[0, scores.argmax().item(), top:top+nh, left:left+nw].unsqueeze(0).unsqueeze(0)
122
+ logit_full = F.interpolate(logit_crop, size=(H, W), mode="bilinear", align_corners=False)[0, 0]
 
 
 
123
  prob = torch.sigmoid(logit_full).float().detach().cpu().numpy()
124
 
125
+ # Initial visualization
126
+ heatmap = cv2.applyColorMap((prob * 255).astype(np.uint8), cv2.COLORMAP_JET)
127
+ overlay = make_overlay(rgb_orig, (prob > threshold).astype(np.uint8) * 255, key=text_prompt)
128
 
129
+ return overlay, Image.fromarray(cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)), prob
 
130
 
 
 
 
 
 
131
  finally:
132
+ sam_model.to("cpu"), plm_model.to("cpu")
 
133
  torch.cuda.empty_cache()
134
 
135
+ def update_threshold_ui(image_pil, text_prompt, threshold, cached_prob):
136
+ """Updates the overlay instantly without rerunning the GPU model."""
137
+ if image_pil is None or cached_prob is None:
138
+ return None
139
+ rgb_orig = np.array(image_pil.convert("RGB"))
140
+ mask = (cached_prob > threshold).astype(np.uint8) * 255
141
+ return make_overlay(rgb_orig, mask, key=text_prompt)
142
+
143
  # ----------------- Gradio UI -----------------
144
+
145
+ with gr.Blocks(title="SAM2 + PLM Interactive") as demo:
146
+ prob_state = gr.State() # Caches the probability map
147
+
148
+ gr.Markdown("# SAM2 + PLM Segmentation\n*Change the model/prompt and click **Run Inference**. Then, adjust the **Threshold** slider for instant mask updates.*")
149
 
150
  with gr.Row():
151
  with gr.Column():
152
  input_image = gr.Image(type="pil", label="Input Image")
153
+ text_prompt = gr.Textbox(label="Text Prompt", placeholder="e.g., 'the blue scissors'")
 
154
  with gr.Row():
155
+ stage_select = gr.Radio(choices=["Stage 1", "Stage 2"], value="Stage 1", label="Model")
156
+ threshold_slider = gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Threshold")
 
157
  run_btn = gr.Button("Run Inference", variant="primary")
158
 
159
  with gr.Column():
160
  out_overlay = gr.Image(label="Segmentation Overlay", type="pil")
161
+ out_heatmap = gr.Image(label="Probability Heatmap", type="pil")
162
 
163
+ # 1. Clicking the button runs the heavy inference
164
  run_btn.click(
165
  fn=run_prediction,
166
  inputs=[input_image, text_prompt, threshold_slider, stage_select],
167
+ outputs=[out_overlay, out_heatmap, prob_state]
168
+ )
169
+
170
+ # 2. Moving the slider triggers only the lightweight update
171
+ threshold_slider.change(
172
+ fn=update_threshold_ui,
173
+ inputs=[input_image, text_prompt, threshold_slider, prob_state],
174
+ outputs=[out_overlay]
175
  )
176
 
177
  if __name__ == "__main__":