akhaliq (HF Staff) committed · verified
Commit ed119eb · 1 Parent(s): 7d13f7d

Update app.py

Files changed (1): app.py (+125 / -8)

app.py CHANGED

@@ -39,6 +39,7 @@ def overlay_masks(image: Image.Image, masks: torch.Tensor) -> Image.Image:
    return image

spaces.GPU()
+
def segment(image: Image.Image, text: str, threshold: float, mask_threshold: float):
    """
    Perform promptable concept segmentation using SAM3.
@@ -73,6 +74,15 @@ def segment(image: Image.Image, text: str, threshold: float, mask_threshold: float):
    except Exception as e:
        return image, f"❌ Error during segmentation: {str(e)}"

+ def clear_all():
+     """Clear all inputs and outputs"""
+     return None, "", None, 0.5, 0.5
+
+ def segment_example(image_path: str, prompt: str):
+     """Handle example clicks"""
+     image = Image.open(image_path) if image_path else None
+     return segment(image, prompt, 0.5, 0.5)
+
# Gradio Interface
with gr.Blocks(
    theme=gr.themes.Soft(),
@@ -98,7 +108,6 @@ with gr.Blocks(
            label="Input Image",
            type="pil",
            height=400,
-             sources=["upload", "url"],
        )
        image_output = gr.Image(
            label="Output (Segmented Image)",
@@ -112,9 +121,7 @@ with gr.Blocks(
            placeholder="e.g., a person, ear, cat, bicycle...",
            scale=3
        )
-         gr.Button("🔁 Clear", size="sm", variant="secondary").click(
-             fn=lambda: (None, "", None, 0.5, 0.5), outputs=[image_output, text_input, image_input, thresh_slider, mask_thresh_slider]
-         )
+         clear_btn = gr.Button("🔁 Clear", size="sm", variant="secondary")

        with gr.Row():
            thresh_slider = gr.Slider(
@@ -141,14 +148,19 @@ with gr.Blocks(

    segment_btn = gr.Button("🎯 Segment", variant="primary", size="lg")

-     # Event
+     # Clear button handler
+     clear_btn.click(
+         fn=clear_all,
+         outputs=[image_input, text_input, image_output, thresh_slider, mask_thresh_slider]
+     )
+
+     # Segment button handler
    segment_btn.click(
        fn=segment,
        inputs=[image_input, text_input, thresh_slider, mask_thresh_slider],
        outputs=[image_output, info_output]
    ).then(
-         fn=lambda: gr.Info("Segmentation complete!"),
-         _js="() => {}"
+         fn=lambda: None,
    )

    # Examples
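The chained `.then()` step previously passed `_js="() => {}"`, a keyword from Gradio 3 that newer Gradio event listeners replaced with `js=`, so the commit swaps the callback for a no-op and drops the completion toast. If the toast is still wanted, a minimal sketch (an assumption, not part of this commit, and assuming a Gradio 4.x runtime where `gr.Info()` called inside an event function is shown as a notification) could look like:

```python
# Sketch only, not from the commit: restore the completion toast without the
# removed `_js` keyword (assumes Gradio 4.x, where gr.Info() shows a toast).
segment_btn.click(
    fn=segment,
    inputs=[image_input, text_input, thresh_slider, mask_thresh_slider],
    outputs=[image_output, info_output],
).then(
    fn=lambda: gr.Info("Segmentation complete!")  # toast only, no outputs needed
)
```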
@@ -178,7 +190,7 @@ with gr.Blocks(
    gr.Examples(
        examples=examples,
        inputs=[image_input, text_input],
-         fn=segment,
+         fn=segment_example,
        outputs=[image_output, info_output],
        cache_examples=True,
        examples_per_page=10,
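Because `cache_examples=True`, Gradio evaluates the example function over each example row at startup and serves the stored outputs afterwards; the new `segment_example` wrapper exists so rows that supply only an image path and a prompt still get the default thresholds of 0.5. Roughly, the caching step amounts to calls like the sketch below (paths and prompts are hypothetical placeholders, not taken from the repo):

```python
# Illustration only, not part of the commit: what example caching effectively
# evaluates at startup for each (image, prompt) row. Values are placeholders.
cached = [
    segment_example("examples/cat.jpg", "a cat"),         # -> (segmented image, info text)
    segment_example("examples/street.jpg", "a bicycle"),  # -> (segmented image, info text)
]
```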
 
@@ -197,3 +209,108 @@ with gr.Blocks(

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False, debug=True)
+ ```
+
+ === utils.py ===
+ ```python
+ import torch
+ import numpy as np
+ from PIL import Image
+ import matplotlib
+ import requests
+ from io import BytesIO
+
+ def load_image_from_url(url: str) -> Image.Image:
+     """
+     Load an image from a URL.
+
+     Args:
+         url: Image URL
+
+     Returns:
+         PIL Image object
+     """
+     try:
+         response = requests.get(url, timeout=10)
+         response.raise_for_status()
+         image = Image.open(BytesIO(response.content))
+         return image.convert("RGB")
+     except Exception as e:
+         raise ValueError(f"Could not load image from URL: {str(e)}")
+
+ def validate_image(image: Image.Image) -> bool:
+     """
+     Validate if the image is suitable for processing.
+
+     Args:
+         image: PIL Image object
+
+     Returns:
+         True if valid, False otherwise
+     """
+     if image is None:
+         return False
+
+     if image.size[0] <= 0 or image.size[1] <= 0:
+         return False
+
+     return True
+
+ def resize_for_processing(image: Image.Image, max_size: int = 1024) -> Image.Image:
+     """
+     Resize image for processing while maintaining aspect ratio.
+
+     Args:
+         image: Input PIL Image
+         max_size: Maximum size for the longer dimension
+
+     Returns:
+         Resized PIL Image
+     """
+     width, height = image.size
+     if max(width, height) <= max_size:
+         return image
+
+     if width > height:
+         new_width = max_size
+         new_height = int(height * max_size / width)
+     else:
+         new_height = max_size
+         new_width = int(width * max_size / height)
+
+     return image.resize((new_width, new_height), Image.Resampling.LANCZOS)
+
+ def overlay_masks_advanced(image: Image.Image, masks: torch.Tensor, alpha: float = 0.5) -> Image.Image:
+     """
+     Advanced overlay function with customizable alpha.
+
+     Args:
+         image: Input PIL Image
+         masks: Segmentation masks tensor
+         alpha: Overlay transparency (0-1)
+
+     Returns:
+         Overlaid PIL Image
+     """
+     image = image.convert("RGBA")
+     masks = 255 * masks.cpu().numpy().astype(np.uint8)
+
+     n_masks = masks.shape[0]
+     if n_masks == 0:
+         return image.convert("RGB")
+
+     # Use a good colormap
+     cmap = matplotlib.colormaps.get_cmap("tab10").resampled(n_masks)
+     colors = [
+         tuple(int(c * 255) for c in cmap(i)[:3])
+         for i in range(n_masks)
+     ]
+
+     for mask, color in zip(masks, colors):
+         mask_img = Image.fromarray(mask)
+         overlay = Image.new("RGBA", image.size, color + (0,))
+         alpha_map = mask_img.point(lambda v: int(v * alpha * 255))
+         overlay.putalpha(alpha_map)
+         image = Image.alpha_composite(image, overlay)
+
+     return image
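The helpers in the appended `utils.py` section are standalone. A minimal sketch of how they could be combined is below; the module name, URL, and wiring are assumptions for illustration and do not appear in the commit itself:

```python
# Hypothetical usage of the helpers above; the URL is a placeholder and this
# wiring is not part of the commit.
from utils import load_image_from_url, validate_image, resize_for_processing

img = load_image_from_url("https://example.com/sample.jpg")  # placeholder URL
if validate_image(img):
    img = resize_for_processing(img, max_size=1024)
    # masks would come from the SAM3 predictor in app.py, e.g.:
    # result = overlay_masks_advanced(img, masks, alpha=0.5)
```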