aadarsh99 committed on
Commit
17751a0
·
1 Parent(s): 9f6e3d0
Files changed (1) hide show
  1. app.py +45 -16
app.py CHANGED
@@ -105,10 +105,14 @@ def ensure_models_loaded():
105
  # ----------------- GPU Inference -----------------
106
 
107
  @spaces.GPU(duration=120)
108
- def run_prediction(image_pil, text_prompt, threshold=0.5):
109
- if image_pil is None or not text_prompt:
110
  return None, None, None
111
 
 
 
 
 
112
  ensure_models_loaded()
113
  sam_model = MODEL_CACHE["sam"]
114
  plm_model = MODEL_CACHE["plm"]
@@ -140,7 +144,7 @@ def run_prediction(image_pil, text_prompt, threshold=0.5):
140
  with tempfile.NamedTemporaryFile(suffix=".jpg") as tmp:
141
  image_pil.save(tmp.name)
142
  # Qwen/PLM processes the text prompt here
143
- sp, dp = plm_model([text_prompt], image_emb.shape[2], image_emb.shape[3], [tmp.name])
144
 
145
  # SAM2 Mask Decoder
146
  dec = sam_model.sam_mask_decoder
@@ -169,7 +173,8 @@ def run_prediction(image_pil, text_prompt, threshold=0.5):
169
  heatmap_rgb = cv2.cvtColor(heatmap_cv, cv2.COLOR_BGR2RGB)
170
 
171
  mask = (prob > threshold).astype(np.uint8) * 255
172
- overlay = make_overlay(rgb_orig, mask, key=text_prompt)
 
173
 
174
  return overlay, Image.fromarray(heatmap_rgb), prob
175
 
@@ -182,13 +187,15 @@ def run_prediction(image_pil, text_prompt, threshold=0.5):
182
  plm_model.to("cpu")
183
  torch.cuda.empty_cache()
184
 
185
def update_threshold_ui(image_pil, text_prompt, threshold, cached_prob):
    """Re-render the mask overlay on the CPU only (no GPU quota usage).

    Reuses the probability map cached by the last GPU prediction, so moving
    the threshold slider never triggers a new model run.

    Args:
        image_pil: PIL input image, or None if nothing is loaded.
        text_prompt: prompt string; forwarded to make_overlay as the color key.
        threshold: probability cutoff for the binary mask.
        cached_prob: cached per-pixel probability array from the last run.

    Returns:
        The overlay image, or None when there is no image / cached result yet.
    """
    # Guard clauses: nothing to redraw without both an image and a cached map.
    if image_pil is None:
        return None
    if cached_prob is None:
        return None

    frame = np.array(image_pil.convert("RGB"))
    binary_mask = np.asarray(cached_prob > threshold, dtype=np.uint8) * 255
    return make_overlay(frame, binary_mask, key=text_prompt)
 
 
192
 
193
  # ----------------- UI Styling & Layout -----------------
194
 
@@ -202,6 +209,15 @@ h1 {
202
  font-size: 1.1em;
203
  margin-bottom: 20px;
204
  }
 
 
 
 
 
 
 
 
 
205
  """
206
 
207
  theme = gr.themes.Soft(
@@ -227,13 +243,25 @@ with gr.Blocks(theme=theme, css=custom_css, title="ConvSeg-Net Demo") as demo:
227
  with gr.Column(scale=1):
228
  input_image = gr.Image(type="pil", label="Input Image", height=400)
229
 
 
 
230
  with gr.Group():
231
- text_prompt = gr.Textbox(
232
- label="Conversational Prompt",
233
- placeholder="e.g., Segment the object that is prone to rolling...",
234
- lines=2
235
- )
236
- gr.Markdown("💡 **Tip:** The model works best when prompts start with **'Segment the...'**")
 
 
 
 
 
 
 
 
 
 
237
 
238
  with gr.Accordion("⚙️ Advanced Options", open=False):
239
  threshold_slider = gr.Slider(
@@ -250,15 +278,16 @@ with gr.Blocks(theme=theme, css=custom_css, title="ConvSeg-Net Demo") as demo:
250
  out_heatmap = gr.Image(label="Confidence Heatmap", type="pil")
251
 
252
  # --- Examples Section ---
 
253
  gr.Markdown("### 📝 Try Examples")
254
  gr.Examples(
255
  examples=[
256
- ["./examples/elephants.png", "Segment the elephant acting as the vanguard of the herd."],
257
- ["./examples/luggage.png", "Segment luggage resting precariously."],
258
- ["./examples/veggies.png", "Segment the produce harvested from underground."],
259
  ],
260
  inputs=[input_image, text_prompt],
261
- # cache_examples=True # Uncomment if you want to pre-compute these on startup
262
  )
263
 
264
  # --- Event Handling ---
 
105
  # ----------------- GPU Inference -----------------
106
 
107
  @spaces.GPU(duration=120)
108
+ def run_prediction(image_pil, user_text, threshold=0.5):
109
+ if image_pil is None or not user_text:
110
  return None, None, None
111
 
112
+ # --- Prepend the required prefix ---
113
+ full_prompt = f"Segment the {user_text.strip()}"
114
+ logging.info(f"Processing prompt: {full_prompt}")
115
+
116
  ensure_models_loaded()
117
  sam_model = MODEL_CACHE["sam"]
118
  plm_model = MODEL_CACHE["plm"]
 
144
  with tempfile.NamedTemporaryFile(suffix=".jpg") as tmp:
145
  image_pil.save(tmp.name)
146
  # Qwen/PLM processes the text prompt here
147
+ sp, dp = plm_model([full_prompt], image_emb.shape[2], image_emb.shape[3], [tmp.name])
148
 
149
  # SAM2 Mask Decoder
150
  dec = sam_model.sam_mask_decoder
 
173
  heatmap_rgb = cv2.cvtColor(heatmap_cv, cv2.COLOR_BGR2RGB)
174
 
175
  mask = (prob > threshold).astype(np.uint8) * 255
176
+ # Use full_prompt for key to ensure consistent colors
177
+ overlay = make_overlay(rgb_orig, mask, key=full_prompt)
178
 
179
  return overlay, Image.fromarray(heatmap_rgb), prob
180
 
 
187
  plm_model.to("cpu")
188
  torch.cuda.empty_cache()
189
 
190
def update_threshold_ui(image_pil, user_text, threshold, cached_prob):
    """Re-render the mask overlay on the CPU only (no GPU quota usage).

    Reuses the probability map cached by the last GPU prediction, so moving
    the threshold slider never triggers a new model run.

    Args:
        image_pil: PIL input image, or None if nothing is loaded.
        user_text: user-typed suffix; "Segment the " is prepended so the
            overlay color hash matches the one used during prediction.
        threshold: probability cutoff for the binary mask.
        cached_prob: cached per-pixel probability array from the last run.

    Returns:
        The overlay image, or None when there is no image / cached result yet.
    """
    # Guard clauses: nothing to redraw without both an image and a cached map.
    if image_pil is None:
        return None
    if cached_prob is None:
        return None

    frame = np.array(image_pil.convert("RGB"))
    binary_mask = np.asarray(cached_prob > threshold, dtype=np.uint8) * 255

    # Rebuild the exact prompt used at prediction time so make_overlay hashes
    # to the same color; fall back to a fixed key when the textbox is empty.
    if user_text:
        color_key = f"Segment the {user_text.strip()}"
    else:
        color_key = "mask"
    return make_overlay(frame, binary_mask, key=color_key)
199
 
200
  # ----------------- UI Styling & Layout -----------------
201
 
 
209
  font-size: 1.1em;
210
  margin-bottom: 20px;
211
  }
212
+ .prefix-container {
213
+ display: flex;
214
+ align-items: center;
215
+ justify-content: center;
216
+ height: 100%;
217
+ font-size: 1.1em;
218
+ font-weight: 600;
219
+ color: #444;
220
+ }
221
  """
222
 
223
  theme = gr.themes.Soft(
 
243
  with gr.Column(scale=1):
244
  input_image = gr.Image(type="pil", label="Input Image", height=400)
245
 
246
+ # Custom prompt input layout
247
+ gr.Markdown("**Conversational Prompt**")
248
  with gr.Group():
249
+ with gr.Row(equal_height=True):
250
+ # Fixed Prefix
251
+ gr.HTML(
252
+ "<div class='prefix-container'>Segment the</div>",
253
+ elem_classes="prefix-box",
254
+ min_width=110,
255
+ max_width=110
256
+ )
257
+ # User Input
258
+ text_prompt = gr.Textbox(
259
+ show_label=False,
260
+ container=False,
261
+ placeholder="object that is prone to rolling...",
262
+ lines=1,
263
+ scale=5
264
+ )
265
 
266
  with gr.Accordion("⚙️ Advanced Options", open=False):
267
  threshold_slider = gr.Slider(
 
278
  out_heatmap = gr.Image(label="Confidence Heatmap", type="pil")
279
 
280
  # --- Examples Section ---
281
+ # Note: removed "Segment the " from examples as it is automatically prepended now
282
  gr.Markdown("### 📝 Try Examples")
283
  gr.Examples(
284
  examples=[
285
+ ["./examples/elephants.png", "elephant acting as the vanguard of the herd."],
286
+ ["./examples/luggage.png", "luggage resting precariously."],
287
+ ["./examples/veggies.png", "produce harvested from underground."],
288
  ],
289
  inputs=[input_image, text_prompt],
290
+ # cache_examples=True
291
  )
292
 
293
  # --- Event Handling ---