aadarsh99 committed on
Commit
3d4a272
·
1 Parent(s): 7e46975

update app

Browse files
Files changed (1) hide show
  1. app.py +93 -39
app.py CHANGED
@@ -4,6 +4,7 @@ import hashlib
4
  import sys
5
  import traceback
6
  import copy
 
7
 
8
  import cv2
9
  import numpy as np
@@ -41,6 +42,7 @@ MODEL_CACHE = {
41
  # ----------------- Helper Functions -----------------
42
  def download_if_needed(repo_id, filename):
43
  try:
 
44
  return hf_hub_download(repo_id=repo_id, filename=filename)
45
  except Exception as e:
46
  raise FileNotFoundError(f"Could not find {filename} in {repo_id}. Error: {e}")
@@ -48,39 +50,63 @@ def download_if_needed(repo_id, filename):
48
  def stable_color(key: str):
49
  h = int(hashlib.sha256(str(key).encode("utf-8")).hexdigest(), 16)
50
  EDGE_COLORS_HEX = ["#3A86FF", "#FF006E", "#43AA8B", "#F3722C", "#8338EC", "#90BE6D"]
51
- colors = [tuple(int(h.lstrip("#")[i:i+2], 16) for i in (0, 2, 4)) for h in EDGE_COLORS_HEX]
52
  return colors[h % len(colors)]
53
 
54
  def make_overlay(rgb: np.ndarray, mask: np.ndarray, key: str = "mask") -> Image.Image:
55
- base = Image.fromarray(rgb.astype(np.uint8)).convert("RGB").convert("RGBA")
 
56
  mask_bool = mask > 0
57
  color = stable_color(key)
58
 
 
59
  fill_layer = Image.new("RGBA", base.size, color + (0,))
60
- fill_alpha = Image.fromarray((mask_bool.astype(np.uint8) * 178), "L")
61
  fill_layer.putalpha(fill_alpha)
62
 
 
63
  m = Image.fromarray((mask_bool.astype(np.uint8) * 255), "L")
64
  edges = ImageChops.difference(m.filter(ImageFilter.MaxFilter(3)), m.filter(ImageFilter.MinFilter(3)))
65
- stroke = Image.new("RGBA", base.size, color + (0,))
66
- stroke.putalpha(edges)
67
 
68
- return Image.alpha_composite(base, fill_layer).alpha_composite(stroke).convert("RGB")
 
 
 
 
69
 
70
  def ensure_models_loaded(stage):
71
  global MODEL_CACHE
72
- if MODEL_CACHE[stage]["sam"] is not None: return
 
 
73
  repo_id = REPO_MAP[stage]
 
 
 
74
  base_path = download_if_needed(repo_id, BASE_CKPT_NAME)
75
  model = build_sam2(SAM2_CONFIG, base_path, device="cpu")
76
- sd = torch.load(download_if_needed(repo_id, FINAL_CKPT_NAME), map_location="cpu")
 
77
  model.load_state_dict(sd.get("model", sd), strict=True)
78
- plm = PLMLanguageAdapter(model_name="Qwen/Qwen2.5-VL-3B-Instruct", transformer_dim=model.sam_mask_decoder.transformer_dim, n_sparse_tokens=0, use_dense_bias=True, use_lora=True, lora_r=16, lora_alpha=32, lora_dropout=0.05, dtype=torch.bfloat16, device="cpu")
79
- plm.load_state_dict(torch.load(download_if_needed(repo_id, PLM_CKPT_NAME), map_location="cpu")["plm"], strict=True)
 
 
 
 
 
 
 
 
 
 
80
  plm.eval()
 
81
  MODEL_CACHE[stage]["sam"], MODEL_CACHE[stage]["plm"] = model, plm
82
 
83
- # ----------------- Core Logic -----------------
84
 
85
  @spaces.GPU(duration=120)
86
  def run_prediction(image_pil, text_prompt, threshold, stage_choice):
@@ -88,52 +114,77 @@ def run_prediction(image_pil, text_prompt, threshold, stage_choice):
88
  return None, None, None
89
 
90
  ensure_models_loaded(stage_choice)
91
- sam_model, plm_model = MODEL_CACHE[stage_choice]["sam"], MODEL_CACHE[stage_choice]["plm"]
92
- sam_model.to("cuda"), plm_model.to("cuda")
 
 
 
93
 
94
  try:
95
  with torch.inference_mode():
96
  predictor = SAM2ImagePredictor(sam_model)
97
  rgb_orig = np.array(image_pil.convert("RGB"))
98
  H, W = rgb_orig.shape[:2]
 
 
99
  scale = SQUARE_DIM / max(H, W)
100
  nw, nh = int(W * scale), int(H * scale)
101
  top, left = (SQUARE_DIM - nh) // 2, (SQUARE_DIM - nw) // 2
102
 
103
- # Preprocess & Encode
104
- rgb_sq = cv2.copyMakeBorder(cv2.resize(rgb_orig, (nw, nh)), top, SQUARE_DIM-nh-top, left, SQUARE_DIM-nw-left, cv2.BORDER_CONSTANT, value=0)
 
 
105
  predictor.set_image(rgb_sq)
106
  image_emb = predictor._features["image_embed"][-1].unsqueeze(0)
107
  hi = [lvl[-1].unsqueeze(0) for lvl in predictor._features["high_res_feats"]]
108
 
109
- # PLM & SAM2 Decoder
110
- temp_path = "temp.jpg"
111
- image_pil.save(temp_path)
112
- sp, dp = plm_model([text_prompt], image_emb.shape[2], image_emb.shape[3], [temp_path])
113
- low, scores, _, _ = sam_model.sam_mask_decoder(
114
- image_embeddings=image_emb.to("cuda"), image_pe=sam_model.sam_prompt_encoder.get_dense_pe().to("cuda"),
115
- sparse_prompt_embeddings=sp.to("cuda"), dense_prompt_embeddings=dp.to("cuda"),
116
- multimask_output=True, repeat_image=False, high_res_features=[h.to("cuda") for h in hi]
 
 
 
 
 
 
 
 
117
  )
118
 
119
- # Postprocess to full size
120
  logits = predictor._transforms.postprocess_masks(low, (SQUARE_DIM, SQUARE_DIM))
121
- logit_crop = logits[0, scores.argmax().item(), top:top+nh, left:left+nw].unsqueeze(0).unsqueeze(0)
 
122
  logit_full = F.interpolate(logit_crop, size=(H, W), mode="bilinear", align_corners=False)[0, 0]
123
- prob = torch.sigmoid(logit_full).float().detach().cpu().numpy()
 
124
 
125
- # Initial visualization
126
- heatmap = cv2.applyColorMap((prob * 255).astype(np.uint8), cv2.COLORMAP_JET)
127
- overlay = make_overlay(rgb_orig, (prob > threshold).astype(np.uint8) * 255, key=text_prompt)
128
 
129
- return overlay, Image.fromarray(cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)), prob
 
 
 
 
130
 
 
 
 
131
  finally:
132
- sam_model.to("cpu"), plm_model.to("cpu")
 
133
  torch.cuda.empty_cache()
134
 
135
  def update_threshold_ui(image_pil, text_prompt, threshold, cached_prob):
136
- """Updates the overlay instantly without rerunning the GPU model."""
137
  if image_pil is None or cached_prob is None:
138
  return None
139
  rgb_orig = np.array(image_pil.convert("RGB"))
@@ -142,32 +193,35 @@ def update_threshold_ui(image_pil, text_prompt, threshold, cached_prob):
142
 
143
  # ----------------- Gradio UI -----------------
144
 
145
- with gr.Blocks(title="SAM2 + PLM Interactive") as demo:
146
- prob_state = gr.State() # Caches the probability map
147
 
148
- gr.Markdown("# SAM2 + PLM Segmentation\n*Change the model/prompt and click **Run Inference**. Then, adjust the **Threshold** slider for instant mask updates.*")
 
149
 
150
  with gr.Row():
151
  with gr.Column():
152
  input_image = gr.Image(type="pil", label="Input Image")
153
- text_prompt = gr.Textbox(label="Text Prompt", placeholder="e.g., 'the blue scissors'")
 
154
  with gr.Row():
155
- stage_select = gr.Radio(choices=["Stage 1", "Stage 2"], value="Stage 1", label="Model")
156
  threshold_slider = gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Threshold")
 
157
  run_btn = gr.Button("Run Inference", variant="primary")
158
 
159
  with gr.Column():
160
  out_overlay = gr.Image(label="Segmentation Overlay", type="pil")
161
  out_heatmap = gr.Image(label="Probability Heatmap", type="pil")
162
 
163
- # 1. Clicking the button runs the heavy inference
164
  run_btn.click(
165
  fn=run_prediction,
166
  inputs=[input_image, text_prompt, threshold_slider, stage_select],
167
  outputs=[out_overlay, out_heatmap, prob_state]
168
  )
169
 
170
- # 2. Moving the slider triggers only the lightweight update
171
  threshold_slider.change(
172
  fn=update_threshold_ui,
173
  inputs=[input_image, text_prompt, threshold_slider, prob_state],
 
4
  import sys
5
  import traceback
6
  import copy
7
+ import tempfile
8
 
9
  import cv2
10
  import numpy as np
 
42
  # ----------------- Helper Functions -----------------
43
  def download_if_needed(repo_id, filename):
44
  try:
45
+ logging.info(f"Checking {filename} in {repo_id}...")
46
  return hf_hub_download(repo_id=repo_id, filename=filename)
47
  except Exception as e:
48
  raise FileNotFoundError(f"Could not find {filename} in {repo_id}. Error: {e}")
 
50
  def stable_color(key: str):
51
  h = int(hashlib.sha256(str(key).encode("utf-8")).hexdigest(), 16)
52
  EDGE_COLORS_HEX = ["#3A86FF", "#FF006E", "#43AA8B", "#F3722C", "#8338EC", "#90BE6D"]
53
+ colors = [tuple(int(c.lstrip("#")[i:i+2], 16) for i in (0, 2, 4)) for c in EDGE_COLORS_HEX]
54
  return colors[h % len(colors)]
55
 
56
  def make_overlay(rgb: np.ndarray, mask: np.ndarray, key: str = "mask") -> Image.Image:
57
+ # Convert base to RGBA
58
+ base = Image.fromarray(rgb.astype(np.uint8)).convert("RGBA")
59
  mask_bool = mask > 0
60
  color = stable_color(key)
61
 
62
+ # Create fill layer (Semi-transparent)
63
  fill_layer = Image.new("RGBA", base.size, color + (0,))
64
+ fill_alpha = Image.fromarray((mask_bool.astype(np.uint8) * 140), "L")
65
  fill_layer.putalpha(fill_alpha)
66
 
67
+ # Create stroke/edge layer
68
  m = Image.fromarray((mask_bool.astype(np.uint8) * 255), "L")
69
  edges = ImageChops.difference(m.filter(ImageFilter.MaxFilter(3)), m.filter(ImageFilter.MinFilter(3)))
70
+ stroke_layer = Image.new("RGBA", base.size, color + (255,))
71
+ stroke_layer.putalpha(edges)
72
 
73
+ # Composite safely (Module-level returns new images, no in-place None issues)
74
+ out = Image.alpha_composite(base, fill_layer)
75
+ out = Image.alpha_composite(out, stroke_layer)
76
+
77
+ return out.convert("RGB")
78
 
79
  def ensure_models_loaded(stage):
80
  global MODEL_CACHE
81
+ if MODEL_CACHE[stage]["sam"] is not None:
82
+ return
83
+
84
  repo_id = REPO_MAP[stage]
85
+ logging.info(f"Loading {stage} models from {repo_id} into CPU RAM...")
86
+
87
+ # SAM2
88
  base_path = download_if_needed(repo_id, BASE_CKPT_NAME)
89
  model = build_sam2(SAM2_CONFIG, base_path, device="cpu")
90
+ final_path = download_if_needed(repo_id, FINAL_CKPT_NAME)
91
+ sd = torch.load(final_path, map_location="cpu")
92
  model.load_state_dict(sd.get("model", sd), strict=True)
93
+
94
+ # PLM
95
+ plm_path = download_if_needed(repo_id, PLM_CKPT_NAME)
96
+ plm = PLMLanguageAdapter(
97
+ model_name="Qwen/Qwen2.5-VL-3B-Instruct",
98
+ transformer_dim=model.sam_mask_decoder.transformer_dim,
99
+ n_sparse_tokens=0, use_dense_bias=True, use_lora=True,
100
+ lora_r=16, lora_alpha=32, lora_dropout=0.05,
101
+ dtype=torch.bfloat16, device="cpu"
102
+ )
103
+ plm_sd = torch.load(plm_path, map_location="cpu")
104
+ plm.load_state_dict(plm_sd["plm"], strict=True)
105
  plm.eval()
106
+
107
  MODEL_CACHE[stage]["sam"], MODEL_CACHE[stage]["plm"] = model, plm
108
 
109
+ # ----------------- GPU Inference -----------------
110
 
111
  @spaces.GPU(duration=120)
112
  def run_prediction(image_pil, text_prompt, threshold, stage_choice):
 
114
  return None, None, None
115
 
116
  ensure_models_loaded(stage_choice)
117
+ sam_model = MODEL_CACHE[stage_choice]["sam"]
118
+ plm_model = MODEL_CACHE[stage_choice]["plm"]
119
+
120
+ sam_model.to("cuda")
121
+ plm_model.to("cuda")
122
 
123
  try:
124
  with torch.inference_mode():
125
  predictor = SAM2ImagePredictor(sam_model)
126
  rgb_orig = np.array(image_pil.convert("RGB"))
127
  H, W = rgb_orig.shape[:2]
128
+
129
+ # Padding math
130
  scale = SQUARE_DIM / max(H, W)
131
  nw, nh = int(W * scale), int(H * scale)
132
  top, left = (SQUARE_DIM - nh) // 2, (SQUARE_DIM - nw) // 2
133
 
134
+ # Resize & Pad
135
+ rgb_sq = cv2.resize(rgb_orig, (nw, nh), interpolation=cv2.INTER_LINEAR)
136
+ rgb_sq = cv2.copyMakeBorder(rgb_sq, top, SQUARE_DIM-nh-top, left, SQUARE_DIM-nw-left, cv2.BORDER_CONSTANT, value=0)
137
+
138
  predictor.set_image(rgb_sq)
139
  image_emb = predictor._features["image_embed"][-1].unsqueeze(0)
140
  hi = [lvl[-1].unsqueeze(0) for lvl in predictor._features["high_res_feats"]]
141
 
142
+ # PLM adapter
143
+ with tempfile.NamedTemporaryFile(suffix=".jpg") as tmp:
144
+ image_pil.save(tmp.name)
145
+ sp, dp = plm_model([text_prompt], image_emb.shape[2], image_emb.shape[3], [tmp.name])
146
+
147
+ # SAM2 Decoding
148
+ dec = sam_model.sam_mask_decoder
149
+ dev, dtype = next(dec.parameters()).device, next(dec.parameters()).dtype
150
+
151
+ low, scores, _, _ = dec(
152
+ image_embeddings=image_emb.to(dev, dtype),
153
+ image_pe=sam_model.sam_prompt_encoder.get_dense_pe().to(dev, dtype),
154
+ sparse_prompt_embeddings=sp.to(dev, dtype),
155
+ dense_prompt_embeddings=dp.to(dev, dtype),
156
+ multimask_output=True, repeat_image=False,
157
+ high_res_features=[h.to(dev, dtype) for h in hi]
158
  )
159
 
160
+ # Postprocess to original dimensions
161
  logits = predictor._transforms.postprocess_masks(low, (SQUARE_DIM, SQUARE_DIM))
162
+ best_idx = scores.argmax().item()
163
+ logit_crop = logits[0, best_idx, top:top+nh, left:left+nw].unsqueeze(0).unsqueeze(0)
164
  logit_full = F.interpolate(logit_crop, size=(H, W), mode="bilinear", align_corners=False)[0, 0]
165
+
166
+ prob = torch.sigmoid(logit_full).float().cpu().numpy()
167
 
168
+ # Generate Heatmap
169
+ heatmap_cv = cv2.applyColorMap((prob * 255).astype(np.uint8), cv2.COLORMAP_JET)
170
+ heatmap_rgb = cv2.cvtColor(heatmap_cv, cv2.COLOR_BGR2RGB)
171
 
172
+ # Initial Overlay
173
+ mask = (prob > threshold).astype(np.uint8) * 255
174
+ overlay = make_overlay(rgb_orig, mask, key=text_prompt)
175
+
176
+ return overlay, Image.fromarray(heatmap_rgb), prob
177
 
178
+ except Exception:
179
+ traceback.print_exc()
180
+ return None, None, None
181
  finally:
182
+ sam_model.to("cpu")
183
+ plm_model.to("cpu")
184
  torch.cuda.empty_cache()
185
 
186
  def update_threshold_ui(image_pil, text_prompt, threshold, cached_prob):
187
+ """Instant update using CPU only."""
188
  if image_pil is None or cached_prob is None:
189
  return None
190
  rgb_orig = np.array(image_pil.convert("RGB"))
 
193
 
194
  # ----------------- Gradio UI -----------------
195
 
196
+ with gr.Blocks(title="SAM2 + PLM Segmentation") as demo:
197
+ prob_state = gr.State()
198
 
199
+ gr.Markdown("# SAM2 + PLM Interactive Segmentation")
200
+ gr.Markdown("Select a stage, enter a prompt, and run. Adjust the slider for **instant** mask updates.")
201
 
202
  with gr.Row():
203
  with gr.Column():
204
  input_image = gr.Image(type="pil", label="Input Image")
205
+ text_prompt = gr.Textbox(label="Text Prompt", placeholder="e.g., 'the surgical forceps'")
206
+
207
  with gr.Row():
208
+ stage_select = gr.Radio(choices=["Stage 1", "Stage 2"], value="Stage 1", label="Model Stage")
209
  threshold_slider = gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Threshold")
210
+
211
  run_btn = gr.Button("Run Inference", variant="primary")
212
 
213
  with gr.Column():
214
  out_overlay = gr.Image(label="Segmentation Overlay", type="pil")
215
  out_heatmap = gr.Image(label="Probability Heatmap", type="pil")
216
 
217
+ # Full Pipeline
218
  run_btn.click(
219
  fn=run_prediction,
220
  inputs=[input_image, text_prompt, threshold_slider, stage_select],
221
  outputs=[out_overlay, out_heatmap, prob_state]
222
  )
223
 
224
+ # Lightweight update on slider move
225
  threshold_slider.change(
226
  fn=update_threshold_ui,
227
  inputs=[input_image, text_prompt, threshold_slider, prob_state],