aadarsh99 commited on
Commit
96c10ec
·
1 Parent(s): 9d1694a

update app

Browse files
Files changed (1) hide show
  1. app.py +64 -171
app.py CHANGED
@@ -23,7 +23,6 @@ from plm_adapter_lora_with_image_input_only_text_positions import PLMLanguageAda
23
  HF_REPO_ID = "aadarsh99/ConvSeg-Stage1"
24
  SAM2_CONFIG = "sam2_hiera_l.yaml"
25
 
26
- # Filenames
27
  BASE_CKPT_NAME = "sam2_hiera_large.pt"
28
  FINAL_CKPT_NAME = "fine_tuned_sam2_batched_100000.torch"
29
  PLM_CKPT_NAME = "fine_tuned_sam2_batched_plm_100000.torch"
@@ -32,38 +31,24 @@ LORA_CKPT_NAME = None
32
  SQUARE_DIM = 1024
33
  logging.basicConfig(level=logging.INFO)
34
 
35
- # ----------------- Globals (Ram Cache) -----------------
36
- # We keep these on CPU globally so they persist between runs
37
- # without taking up GPU memory (which gets reset).
38
  MODEL_SAM_CPU = None
39
  PLM_CPU = None
40
 
41
- # ----------------- Helper: Download Logic -----------------
 
42
  def download_if_needed(filename):
43
- """
44
- Checks if file exists locally. If not, downloads from HF Repo.
45
- Returns the valid path to the file.
46
- """
47
  if os.path.exists(filename):
48
- logging.info(f"Found local file: {filename}")
49
  return filename
50
-
51
- # hf_hub_download checks the cache automatically.
52
- # It won't re-download if the file is already in the HF cache.
53
- logging.info(f"Checking HF Cache for {filename}...")
54
  try:
55
- path = hf_hub_download(repo_id=HF_REPO_ID, filename=filename)
56
- return path
57
  except Exception as e:
58
- raise FileNotFoundError(f"Could not find {filename} locally or in HF repo {HF_REPO_ID}. Error: {e}")
59
-
60
- # ----------------- Overlay Style Helpers -----------------
61
- EDGE_COLORS_HEX = ["#3A86FF", "#FF006E", "#43AA8B", "#F3722C", "#8338EC", "#90BE6D"]
62
 
63
  def _hex_to_rgb(h: str):
64
  h = h.lstrip("#")
65
  return tuple(int(h[i : i + 2], 16) for i in (0, 2, 4))
66
 
 
67
  EDGE_COLORS = [_hex_to_rgb(h) for h in EDGE_COLORS_HEX]
68
 
69
  def stable_color(key: str):
@@ -75,9 +60,7 @@ def tint(rgb, amt: float = 0.1):
75
 
76
  def edge_map(mask_bool: np.ndarray, width_px: int = 2) -> Image.Image:
77
  m = Image.fromarray((mask_bool.astype(np.uint8) * 255), "L")
78
- edges = ImageChops.difference(
79
- m.filter(ImageFilter.MaxFilter(3)), m.filter(ImageFilter.MinFilter(3))
80
- )
81
  for _ in range(max(0, width_px - 1)):
82
  edges = edges.filter(ImageFilter.MaxFilter(3))
83
  return edges.point(lambda p: 255 if p > 0 else 0)
@@ -97,248 +80,158 @@ def make_overlay(rgb: np.ndarray, mask: np.ndarray, key: str = "mask") -> Image.
97
  H, W = mask.shape[:2]
98
  if base.size != (W, H):
99
  base = base.resize((W, H), Image.BICUBIC)
100
-
101
  base_rgba = base.convert("RGBA")
102
  mask_bool = mask > 0
103
-
104
  color = stable_color(key)
105
  fill_rgb = tint(color, 0.1)
106
- alpha_fill = 0.7
107
- edge_width = 2
108
-
109
- a = int(round(alpha_fill * 255))
110
- tgt_w, tgt_h = base_rgba.size
111
-
112
- fill_layer = Image.new("RGBA", (tgt_w, tgt_h), fill_rgb + (0,))
113
- fill_alpha = Image.fromarray((mask_bool.astype(np.uint8) * a), "L")
114
  fill_layer.putalpha(fill_alpha)
115
-
116
- edgesL = edge_map(mask_bool, width_px=edge_width)
117
- stroke = Image.new("RGBA", (tgt_w, tgt_h), color + (0,))
118
  stroke.putalpha(edgesL)
119
-
120
  out = Image.alpha_composite(base_rgba, fill_layer)
121
  out = Image.alpha_composite(out, stroke)
122
- out = out.convert("RGB")
123
- return _apply_rounded_corners(out, max(12, int(0.06 * min(out.size))))
124
 
125
- # ----------------- Image Processing Helpers -----------------
126
 
127
  def _resize_pad_square(arr: np.ndarray, max_dim: int, *, is_mask: bool) -> np.ndarray:
128
  h, w = arr.shape[:2]
129
  scale = float(max_dim) / float(max(h, w))
130
- new_w = max(1, int(round(w * scale)))
131
- new_h = max(1, int(round(h * scale)))
132
-
133
- if is_mask:
134
- interp = cv2.INTER_NEAREST
135
- else:
136
- interp = cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LINEAR
137
-
138
  arr = cv2.resize(arr, (new_w, new_h), interpolation=interp)
139
-
140
- pad_w = max_dim - new_w
141
- pad_h = max_dim - new_h
142
- left = pad_w // 2
143
- right = pad_w - left
144
- top = pad_h // 2
145
- bottom = pad_h - top
146
-
147
- border_val = 0 if is_mask else (0, 0, 0)
148
- arr = cv2.copyMakeBorder(
149
- arr, top, bottom, left, right, borderType=cv2.BORDER_CONSTANT, value=border_val
150
- )
151
- return np.ascontiguousarray(arr)
152
 
153
  def _resize_pad_square_meta(h: int, w: int, max_dim: int):
154
  scale = float(max_dim) / float(max(h, w))
155
- new_w = max(1, int(round(w * scale)))
156
- new_h = max(1, int(round(h * scale)))
157
- pad_w = max_dim - new_w
158
- pad_h = max_dim - new_h
159
- left = pad_w // 2
160
- right = pad_w - left
161
- top = pad_h // 2
162
- bottom = pad_h - top
163
- return {
164
- "scale": scale, "new_w": new_w, "new_h": new_h,
165
- "left": left, "right": right, "top": top, "bottom": bottom,
166
- }
167
 
168
  def _unpad_and_resize_pred_to_gt(logit_sq: torch.Tensor, meta: dict, out_hw: tuple[int, int]) -> torch.Tensor:
169
  top, left = meta["top"], meta["left"]
170
  nh, nw = meta["new_h"], meta["new_w"]
171
- crop = logit_sq[top : top + nh, left : left + nw]
172
- crop = crop.unsqueeze(0).unsqueeze(0)
173
- up = F.interpolate(crop, size=out_hw, mode="bilinear", align_corners=False)
174
- return up[0, 0]
175
 
176
- # ----------------- Model Loading (CPU Caching) -----------------
177
 
178
  def ensure_models_loaded_on_cpu():
179
- """
180
- Ensures models are loaded in Global CPU RAM.
181
- This avoids re-reading from disk/cache on every run.
182
- """
183
  global MODEL_SAM_CPU, PLM_CPU
184
-
185
  if MODEL_SAM_CPU is not None and PLM_CPU is not None:
186
- return # Already loaded in RAM
187
-
188
- logging.info("Loading models into CPU RAM (this happens once)...")
189
-
190
- # 1. Base SAM2 Model
191
  base_path = download_if_needed(BASE_CKPT_NAME)
192
-
193
- # Build on CPU
194
  model = build_sam2(SAM2_CONFIG, base_path, device="cpu")
195
-
196
- # 2. Fine-tuned Weights
197
  final_path = download_if_needed(FINAL_CKPT_NAME)
198
  sd = torch.load(final_path, map_location="cpu")
199
  model.load_state_dict(sd.get("model", sd), strict=True)
200
-
201
- # Save to Global (CPU)
202
  MODEL_SAM_CPU = model
203
 
204
- # 3. PLM Adapter
205
- C = model.sam_mask_decoder.transformer_dim
206
  plm = PLMLanguageAdapter(
207
  model_name="Qwen/Qwen2.5-VL-3B-Instruct",
208
- transformer_dim=C,
209
- n_sparse_tokens=0,
210
- use_dense_bias=True,
211
- use_lora=True,
212
- lora_r=16,
213
- lora_alpha=32,
214
- lora_dropout=0.05,
215
- dtype=torch.bfloat16,
216
- device="cpu",
217
  )
218
-
219
  plm_path = download_if_needed(PLM_CKPT_NAME)
220
  plm_sd = torch.load(plm_path, map_location="cpu")
221
  plm.load_state_dict(plm_sd["plm"], strict=True)
222
-
223
- if LORA_CKPT_NAME:
224
- lora_path = download_if_needed(LORA_CKPT_NAME)
225
- plm.load_lora(lora_path)
226
-
227
  plm.eval()
228
  PLM_CPU = plm
229
- logging.info("Models successfully loaded into CPU RAM.")
230
-
231
 
232
  @spaces.GPU(duration=120)
233
- def run_prediction(image_pil, text_prompt):
234
  if image_pil is None or not text_prompt:
235
  return None, None
236
 
237
- # 1. Ensure models are in RAM (Fast check)
238
  ensure_models_loaded_on_cpu()
239
-
240
- # 2. Move to GPU (The only 'loading' cost per run)
241
- # We rely on the global variables
242
- logging.info("Moving models to GPU...")
243
  MODEL_SAM_CPU.to("cuda")
244
  PLM_CPU.to("cuda")
245
 
246
  predictor = None
247
-
248
  try:
249
- # Instantiate Predictor on GPU
250
  predictor = SAM2ImagePredictor(MODEL_SAM_CPU)
251
-
252
- # 3. Preprocess Image
253
  rgb_orig = np.array(image_pil.convert("RGB"))
254
  Hgt, Wgt = rgb_orig.shape[:2]
255
  meta = _resize_pad_square_meta(Hgt, Wgt, SQUARE_DIM)
256
  rgb_sq = _resize_pad_square(rgb_orig, SQUARE_DIM, is_mask=False)
257
 
258
- # 4. SAM2 Image Encoding
259
  predictor.set_image(rgb_sq)
260
  image_emb = predictor._features["image_embed"][-1].unsqueeze(0)
261
  hi = [lvl[-1].unsqueeze(0) for lvl in predictor._features["high_res_feats"]]
262
- _, _, H_feat, W_feat = image_emb.shape
263
-
264
- # 5. PLM Inference
265
  temp_path = "temp_input.jpg"
266
  image_pil.save(temp_path)
267
-
268
- sp, dp = PLM_CPU([text_prompt], H_feat, W_feat, [temp_path])
269
 
270
- # 6. Prepare SAM2 Decoder inputs
271
  dec = predictor.model.sam_mask_decoder
272
- dev = next(dec.parameters()).device
273
- dtype = next(dec.parameters()).dtype
274
-
275
- image_pe = predictor.model.sam_prompt_encoder.get_dense_pe().to(dev, dtype)
276
- image_emb = image_emb.to(dev, dtype)
277
- hi = [h.to(dev, dtype) for h in hi]
278
- sp, dp = sp.to(dev, dtype), dp.to(dev, dtype)
279
-
280
- # 7. SAM2 Decoding
281
  low, scores, _, _ = dec(
282
- image_embeddings=image_emb,
283
- image_pe=image_pe,
284
- sparse_prompt_embeddings=sp,
285
- dense_prompt_embeddings=dp,
286
- multimask_output=True,
287
- repeat_image=False,
288
- high_res_features=hi,
289
  )
290
 
291
  logits_sq = predictor._transforms.postprocess_masks(low, (SQUARE_DIM, SQUARE_DIM))
292
- best = scores.argmax(dim=1).item()
293
- logit_sq = logits_sq[0, best]
294
- logit_gt = _unpad_and_resize_pred_to_gt(logit_sq, meta, (Hgt, Wgt))
295
-
296
- prob = torch.sigmoid(logit_gt)
297
- mask = (prob > 0.5).cpu().numpy().astype(np.uint8) * 255
298
-
299
- # 8. Visualization
300
  overlay_img = make_overlay(rgb_orig, mask, key=text_prompt)
301
- mask_img = Image.fromarray(mask, mode="L")
 
 
 
 
 
 
302
 
303
- return overlay_img, mask_img
304
 
305
  except Exception as e:
306
- print("An error occurred during inference:")
307
  traceback.print_exc()
308
  raise e
309
-
310
  finally:
311
- # CRITICAL: Move models back to CPU
312
- # This preserves the Global Variable on CPU RAM for the next run.
313
- # If we leave them on CUDA, they might be lost when ZeroGPU releases the device.
314
- logging.info("Moving models back to CPU...")
315
  MODEL_SAM_CPU.to("cpu")
316
  PLM_CPU.to("cpu")
317
-
318
- if predictor:
319
- del predictor
320
  torch.cuda.empty_cache()
321
 
322
  # ----------------- Gradio UI -----------------
323
 
324
- with gr.Blocks(title="SAM2 + PLM Interactive Segmentation") as demo:
325
  gr.Markdown("# SAM2 + PLM Interactive Segmentation")
326
- gr.Markdown("Enter a text prompt to segment objects in the image.")
327
-
328
  with gr.Row():
329
  with gr.Column():
330
  input_image = gr.Image(type="pil", label="Input Image")
331
  text_prompt = gr.Textbox(label="Text Prompt", placeholder="e.g., 'the red car'")
 
 
 
 
332
  run_btn = gr.Button("Segment", variant="primary")
333
 
334
  with gr.Column():
335
  out_overlay = gr.Image(label="Segmentation Overlay", type="pil")
336
- out_mask = gr.Image(label="Binary Mask", type="pil")
337
 
338
  run_btn.click(
339
  fn=run_prediction,
340
- inputs=[input_image, text_prompt],
341
- outputs=[out_overlay, out_mask]
342
  )
343
 
344
  if __name__ == "__main__":
 
23
  HF_REPO_ID = "aadarsh99/ConvSeg-Stage1"
24
  SAM2_CONFIG = "sam2_hiera_l.yaml"
25
 
 
26
  BASE_CKPT_NAME = "sam2_hiera_large.pt"
27
  FINAL_CKPT_NAME = "fine_tuned_sam2_batched_100000.torch"
28
  PLM_CKPT_NAME = "fine_tuned_sam2_batched_plm_100000.torch"
 
31
  SQUARE_DIM = 1024
32
  logging.basicConfig(level=logging.INFO)
33
 
 
 
 
34
  MODEL_SAM_CPU = None
35
  PLM_CPU = None
36
 
37
+ # ----------------- Helper Functions -----------------
38
+
39
  def download_if_needed(filename):
 
 
 
 
40
  if os.path.exists(filename):
 
41
  return filename
 
 
 
 
42
  try:
43
+ return hf_hub_download(repo_id=HF_REPO_ID, filename=filename)
 
44
  except Exception as e:
45
+ raise FileNotFoundError(f"Could not find {filename} in HF repo {HF_REPO_ID}. Error: {e}")
 
 
 
46
 
47
  def _hex_to_rgb(h: str):
48
  h = h.lstrip("#")
49
  return tuple(int(h[i : i + 2], 16) for i in (0, 2, 4))
50
 
51
+ EDGE_COLORS_HEX = ["#3A86FF", "#FF006E", "#43AA8B", "#F3722C", "#8338EC", "#90BE6D"]
52
  EDGE_COLORS = [_hex_to_rgb(h) for h in EDGE_COLORS_HEX]
53
 
54
  def stable_color(key: str):
 
60
 
61
  def edge_map(mask_bool: np.ndarray, width_px: int = 2) -> Image.Image:
62
  m = Image.fromarray((mask_bool.astype(np.uint8) * 255), "L")
63
+ edges = ImageChops.difference(m.filter(ImageFilter.MaxFilter(3)), m.filter(ImageFilter.MinFilter(3)))
 
 
64
  for _ in range(max(0, width_px - 1)):
65
  edges = edges.filter(ImageFilter.MaxFilter(3))
66
  return edges.point(lambda p: 255 if p > 0 else 0)
 
80
  H, W = mask.shape[:2]
81
  if base.size != (W, H):
82
  base = base.resize((W, H), Image.BICUBIC)
 
83
  base_rgba = base.convert("RGBA")
84
  mask_bool = mask > 0
 
85
  color = stable_color(key)
86
  fill_rgb = tint(color, 0.1)
87
+ fill_layer = Image.new("RGBA", base_rgba.size, fill_rgb + (0,))
88
+ fill_alpha = Image.fromarray((mask_bool.astype(np.uint8) * 178), "L")
 
 
 
 
 
 
89
  fill_layer.putalpha(fill_alpha)
90
+ edgesL = edge_map(mask_bool, width_px=2)
91
+ stroke = Image.new("RGBA", base_rgba.size, color + (0,))
 
92
  stroke.putalpha(edgesL)
 
93
  out = Image.alpha_composite(base_rgba, fill_layer)
94
  out = Image.alpha_composite(out, stroke)
95
+ return _apply_rounded_corners(out.convert("RGB"), max(12, int(0.06 * min(out.size))))
 
96
 
97
+ # ----------------- Image Processing -----------------
98
 
99
  def _resize_pad_square(arr: np.ndarray, max_dim: int, *, is_mask: bool) -> np.ndarray:
100
  h, w = arr.shape[:2]
101
  scale = float(max_dim) / float(max(h, w))
102
+ new_w, new_h = max(1, int(round(w * scale))), max(1, int(round(h * scale)))
103
+ interp = cv2.INTER_NEAREST if is_mask else (cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LINEAR)
 
 
 
 
 
 
104
  arr = cv2.resize(arr, (new_w, new_h), interpolation=interp)
105
+ pad_w, pad_h = max_dim - new_w, max_dim - new_h
106
+ left, top = pad_w // 2, pad_h // 2
107
+ return np.ascontiguousarray(cv2.copyMakeBorder(arr, top, pad_h - top, left, pad_w - left, cv2.BORDER_CONSTANT, value=0))
 
 
 
 
 
 
 
 
 
 
108
 
109
  def _resize_pad_square_meta(h: int, w: int, max_dim: int):
110
  scale = float(max_dim) / float(max(h, w))
111
+ new_w, new_h = max(1, int(round(w * scale))), max(1, int(round(h * scale)))
112
+ return {"scale": scale, "new_w": new_w, "new_h": new_h, "left": (max_dim - new_w) // 2, "top": (max_dim - new_h) // 2}
 
 
 
 
 
 
 
 
 
 
113
 
114
  def _unpad_and_resize_pred_to_gt(logit_sq: torch.Tensor, meta: dict, out_hw: tuple[int, int]) -> torch.Tensor:
115
  top, left = meta["top"], meta["left"]
116
  nh, nw = meta["new_h"], meta["new_w"]
117
+ crop = logit_sq[top : top + nh, left : left + nw].unsqueeze(0).unsqueeze(0)
118
+ return F.interpolate(crop, size=out_hw, mode="bilinear", align_corners=False)[0, 0]
 
 
119
 
120
+ # ----------------- Prediction Logic -----------------
121
 
122
  def ensure_models_loaded_on_cpu():
 
 
 
 
123
  global MODEL_SAM_CPU, PLM_CPU
 
124
  if MODEL_SAM_CPU is not None and PLM_CPU is not None:
125
+ return
126
+ logging.info("Loading models into CPU RAM...")
 
 
 
127
  base_path = download_if_needed(BASE_CKPT_NAME)
 
 
128
  model = build_sam2(SAM2_CONFIG, base_path, device="cpu")
 
 
129
  final_path = download_if_needed(FINAL_CKPT_NAME)
130
  sd = torch.load(final_path, map_location="cpu")
131
  model.load_state_dict(sd.get("model", sd), strict=True)
 
 
132
  MODEL_SAM_CPU = model
133
 
 
 
134
  plm = PLMLanguageAdapter(
135
  model_name="Qwen/Qwen2.5-VL-3B-Instruct",
136
+ transformer_dim=model.sam_mask_decoder.transformer_dim,
137
+ n_sparse_tokens=0, use_dense_bias=True, use_lora=True,
138
+ lora_r=16, lora_alpha=32, lora_dropout=0.05,
139
+ dtype=torch.bfloat16, device="cpu",
 
 
 
 
 
140
  )
 
141
  plm_path = download_if_needed(PLM_CKPT_NAME)
142
  plm_sd = torch.load(plm_path, map_location="cpu")
143
  plm.load_state_dict(plm_sd["plm"], strict=True)
 
 
 
 
 
144
  plm.eval()
145
  PLM_CPU = plm
 
 
146
 
147
  @spaces.GPU(duration=120)
148
+ def run_prediction(image_pil, text_prompt, threshold):
149
  if image_pil is None or not text_prompt:
150
  return None, None
151
 
 
152
  ensure_models_loaded_on_cpu()
 
 
 
 
153
  MODEL_SAM_CPU.to("cuda")
154
  PLM_CPU.to("cuda")
155
 
156
  predictor = None
 
157
  try:
 
158
  predictor = SAM2ImagePredictor(MODEL_SAM_CPU)
 
 
159
  rgb_orig = np.array(image_pil.convert("RGB"))
160
  Hgt, Wgt = rgb_orig.shape[:2]
161
  meta = _resize_pad_square_meta(Hgt, Wgt, SQUARE_DIM)
162
  rgb_sq = _resize_pad_square(rgb_orig, SQUARE_DIM, is_mask=False)
163
 
 
164
  predictor.set_image(rgb_sq)
165
  image_emb = predictor._features["image_embed"][-1].unsqueeze(0)
166
  hi = [lvl[-1].unsqueeze(0) for lvl in predictor._features["high_res_feats"]]
167
+
 
 
168
  temp_path = "temp_input.jpg"
169
  image_pil.save(temp_path)
170
+ sp, dp = PLM_CPU([text_prompt], image_emb.shape[2], image_emb.shape[3], [temp_path])
 
171
 
 
172
  dec = predictor.model.sam_mask_decoder
173
+ dev, dtype = next(dec.parameters()).device, next(dec.parameters()).dtype
174
+
 
 
 
 
 
 
 
175
  low, scores, _, _ = dec(
176
+ image_embeddings=image_emb.to(dev, dtype),
177
+ image_pe=predictor.model.sam_prompt_encoder.get_dense_pe().to(dev, dtype),
178
+ sparse_prompt_embeddings=sp.to(dev, dtype),
179
+ dense_prompt_embeddings=dp.to(dev, dtype),
180
+ multimask_output=True, repeat_image=False,
181
+ high_res_features=[h.to(dev, dtype) for h in hi],
 
182
  )
183
 
184
  logits_sq = predictor._transforms.postprocess_masks(low, (SQUARE_DIM, SQUARE_DIM))
185
+ logit_gt = _unpad_and_resize_pred_to_gt(logits_sq[0, scores.argmax(dim=1).item()], meta, (Hgt, Wgt))
186
+
187
+ # 1. Calculate Probabilities (Heatmap)
188
+ prob = torch.sigmoid(logit_gt).cpu().numpy()
189
+
190
+ # 2. Apply dynamic threshold for overlay
191
+ mask = (prob > threshold).astype(np.uint8) * 255
 
192
  overlay_img = make_overlay(rgb_orig, mask, key=text_prompt)
193
+
194
+ # 3. Create Heatmap Visualization
195
+ # Scale 0.0-1.0 to 0-255
196
+ prob_uint8 = (prob * 255).astype(np.uint8)
197
+ heatmap_color = cv2.applyColorMap(prob_uint8, cv2.COLORMAP_JET)
198
+ heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)
199
+ heatmap_pil = Image.fromarray(heatmap_color)
200
 
201
+ return overlay_img, heatmap_pil
202
 
203
  except Exception as e:
 
204
  traceback.print_exc()
205
  raise e
 
206
  finally:
 
 
 
 
207
  MODEL_SAM_CPU.to("cpu")
208
  PLM_CPU.to("cpu")
209
+ if predictor: del predictor
 
 
210
  torch.cuda.empty_cache()
211
 
212
  # ----------------- Gradio UI -----------------
213
 
214
+ with gr.Blocks(title="SAM2 + PLM Segmentation") as demo:
215
  gr.Markdown("# SAM2 + PLM Interactive Segmentation")
216
+
 
217
  with gr.Row():
218
  with gr.Column():
219
  input_image = gr.Image(type="pil", label="Input Image")
220
  text_prompt = gr.Textbox(label="Text Prompt", placeholder="e.g., 'the red car'")
221
+ threshold_slider = gr.Slider(
222
+ minimum=0.0, maximum=1.0, value=0.5, step=0.01,
223
+ label="Confidence Threshold", info="Adjust to include more/less of the object"
224
+ )
225
  run_btn = gr.Button("Segment", variant="primary")
226
 
227
  with gr.Column():
228
  out_overlay = gr.Image(label="Segmentation Overlay", type="pil")
229
+ out_heatmap = gr.Image(label="Probability Heatmap", type="pil")
230
 
231
  run_btn.click(
232
  fn=run_prediction,
233
+ inputs=[input_image, text_prompt, threshold_slider],
234
+ outputs=[out_overlay, out_heatmap]
235
  )
236
 
237
  if __name__ == "__main__":