prithivMLmods committed on
Commit bc27759 · verified · 1 Parent(s): 4b1c88a

Update app.py

Files changed (1):
  1. app.py +193 -174
app.py CHANGED
@@ -1,18 +1,23 @@
  import os
  import spaces
  import gradio as gr
  import numpy as np
  import torch
- import random
- from PIL import Image, ImageDraw
  from typing import Iterable
  from gradio.themes import Soft
  from gradio.themes.utils import colors, fonts, sizes
- from transformers import Sam3Processor, Sam3Model, Sam3VideoModel, Sam3VideoProcessor
- import cv2
- import tempfile

- # --- Theme Definition ---
  colors.steel_blue = colors.Color(
      name="steel_blue",
      c50="#EBF3F8",
@@ -28,7 +33,7 @@ colors.steel_blue = colors.Color(
      c950="#1E3450",
  )

- class SteelBlueTheme(Soft):
      def __init__(
          self,
          *,
@@ -73,215 +78,229 @@ class SteelBlueTheme(Soft):
              block_label_background_fill="*primary_200",
          )

- steel_blue_theme = SteelBlueTheme()

- # --- Model Loading ---
  device = "cuda" if torch.cuda.is_available() else "cpu"
- print(f"Using device: {device}")
-
- MODELS = {}

- def get_model(model_type):
-     if model_type not in MODELS:
-         if model_type == "sam3_image":
-             print("Loading SAM3 Image Model and Processor...")
-             model = Sam3Model.from_pretrained("facebook/sam3").to(device)
-             processor = Sam3Processor.from_pretrained("facebook/sam3")
-             MODELS[model_type] = (model, processor)
-         elif model_type == "sam3_video_text":
-             print("Loading SAM3 Video Model and Processor...")
-             model = Sam3VideoModel.from_pretrained("facebook/sam3").to(device, dtype=torch.bfloat16)
-             processor = Sam3VideoProcessor.from_pretrained("facebook/sam3")
-             MODELS[model_type] = (model, processor)
-     return MODELS[model_type]

- try:
-     get_model("sam3_image")
-     print("Image model loaded successfully.")
- except Exception as e:
-     print(f"Error loading image model: {e}")
-     print("Ensure you have the correct libraries installed and access to the model.")

- # --- Helper Functions ---
- def overlay_masks(image, masks, alpha=0.5):
-     """ Overlays masks on the image with random colors. """
-     image = image.convert("RGBA")
-     overlay = Image.new("RGBA", image.size, (0, 0, 0, 0))
-     draw = ImageDraw.Draw(overlay)
-
-     for mask in masks:
-         # Generate a random color for each mask
-         color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), int(255 * alpha))
-
-         # Convert boolean mask to an image that can be pasted
-         mask_pil = Image.fromarray((mask * 255).astype(np.uint8), mode='L')
-
-         # Draw the colored mask
-         draw.bitmap((0, 0), mask_pil, fill=color)
-
-     # Combine the original image with the overlay
-     combined = Image.alpha_composite(image, overlay)
-     return combined.convert("RGB")

- # --- Core Functions ---
  @spaces.GPU
- def segment_image(input_image, text_prompt, threshold=0.5):
-     if input_image is None:
-         raise gr.Error("Please upload an image.")
-     if not text_prompt:
-         raise gr.Error("Please enter a text prompt (e.g., 'cat', 'face').")

      try:
-         model, processor = get_model("sam3_image")
-     except Exception as e:
-         raise gr.Error(f"Model not loaded correctly: {e}")

-     image_pil = input_image.convert("RGB")
-     inputs = processor(images=image_pil, text=text_prompt, return_tensors="pt").to(device)

-     with torch.no_grad():
-         outputs = model(**inputs)

-     results = processor.post_process_instance_segmentation(
-         outputs,
-         threshold=threshold,
-         mask_threshold=0.5,
-         target_sizes=inputs.get("original_sizes").tolist()
-     )[0]

-     masks = results['masks']
-     scores = results['scores']
-
-     annotations = []
-     masks_np = masks.cpu().numpy()
-     scores_np = scores.cpu().numpy()
-
-     for i, mask in enumerate(masks_np):
-         score_val = scores_np[i]
-         label = f"{text_prompt} ({score_val:.2f})"
-         annotations.append((mask, label))
-
-     return (image_pil, annotations)

- @spaces.GPU
- def process_video_text(video_path, text_prompt, max_frames, timeout_seconds):
-     if not video_path or not text_prompt:
-         return None, "Missing video or prompt."
      try:
-         model, processor = get_model("sam3_video_text")
-         cap = cv2.VideoCapture(video_path)
-         fps = cap.get(cv2.CAP_PROP_FPS)
-         width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-         height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-         frames = []
-         frame_count = 0
-         while cap.isOpened():
-             ret, frame = cap.read()
-             if not ret or (max_frames > 0 and frame_count >= max_frames):
-                 break
-             frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
-             frame_count += 1
-         cap.release()
-
-         inference_session = processor.init_video_session(video=frames, inference_device=device, dtype=torch.bfloat16)
-         inference_session = processor.add_text_prompt(inference_session=inference_session, text=text_prompt)

-         output_path = tempfile.mktemp(suffix=".mp4")
-         fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-         out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
-
-         for model_outputs in model.propagate_in_video_iterator(inference_session=inference_session, max_frame_num_to_track=len(frames)):
-             processed_outputs = processor.postprocess_outputs(inference_session, model_outputs)
-             frame_idx = model_outputs.frame_idx
-             orig_frame = Image.fromarray(frames[frame_idx])

-             if 'masks' in processed_outputs:
-                 masks = processed_outputs['masks']
-                 if masks.ndim == 4:
-                     masks = masks.squeeze(1)
-                 res_frame = overlay_masks(orig_frame, masks)
-             else:
-                 res_frame = orig_frame

-             out.write(cv2.cvtColor(np.array(res_frame), cv2.COLOR_RGB2BGR))

-         out.release()
-         return output_path, "Done!"
      except Exception as e:
-         return None, f"Error: {str(e)}"

- # --- Gradio UI ---
- css="""
- #col-container {
-     margin: 0 auto;
-     max-width: 980px;
- }
- #main-title h1 {font-size: 2.1em !important;}
  """

- with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
      with gr.Column(elem_id="col-container"):
-         gr.Markdown(
-             "# **SAM3 Image & Video Segmentation**",
-             elem_id="main-title"
-         )
-
-         gr.Markdown("Segment objects in images or videos using **SAM3** (Segment Anything Model 3) with text prompts.")

          with gr.Tabs():
-             with gr.TabItem("Image Segmentation"):
                  with gr.Row():
                      with gr.Column(scale=1):
-                         input_image = gr.Image(label="Input Image", type="pil", height=300)
-                         text_prompt = gr.Textbox(
-                             label="Text Prompt",
-                             placeholder="e.g., cat, ear, car wheel...",
-                         )

-                         run_button = gr.Button("Segment Image", variant="primary")

                      with gr.Column(scale=1.5):
-                         output_image = gr.AnnotatedImage(label="Segmented Output", height=380)
-
-                 with gr.Row():
-                     threshold = gr.Slider(label="Confidence Threshold", minimum=0.0, maximum=1.0, value=0.4, step=0.05)

-                 gr.Examples(
-                     examples=[
-                         ["examples/player.jpg", "player in white", 0.5],
-                         ["examples/goldencat.webp", "black cat", 0.4],
-                         ["examples/taxi.jpg", "blue taxi", 0.5],
-                     ],
-                     inputs=[input_image, text_prompt, threshold],
-                     outputs=[output_image],
-                     fn=segment_image,
-                     cache_examples="lazy",
-                     label="Image Examples"
                  )

-             with gr.TabItem("Video Segmentation"):
                  with gr.Row():
                      with gr.Column():
-                         input_video = gr.Video(label="Input Video", format="mp4")
-                         video_text_prompt = gr.Textbox(label="Text Prompt", placeholder="e.g.: person, car")
-                         max_frames_slider = gr.Slider(10, 1000, value=50, step=10, label="Max Frames to Process")
-                         processing_duration = gr.Radio([60, 120], value=60, label="Max Processing Time (seconds)", info="Choose 60s for short clips, 120s for complex tasks")
-                         start_video_segmentation_button = gr.Button("Start Video Segmentation", variant="primary")
                      with gr.Column():
-                         output_video = gr.Video(label="Result Video")
-                         status_textbox = gr.Textbox(label="Status")

-                 start_video_segmentation_button.click(
-                     process_video_text,
-                     [input_video, video_text_prompt, max_frames_slider, processing_duration],
-                     [output_video, status_textbox]
                  )

-         run_button.click(
-             fn=segment_image,
-             inputs=[input_image, text_prompt, threshold],
-             outputs=[output_image]
-         )
-
  if __name__ == "__main__":
-     demo.launch(debug=True, show_error=True)
 
@@ -1,18 +1,23 @@
  import os
+ import gc
+ import cv2
+ import tempfile
  import spaces
  import gradio as gr
  import numpy as np
  import torch
+ import matplotlib
+ import matplotlib.pyplot as plt
+ from PIL import Image
  from typing import Iterable
  from gradio.themes import Soft
  from gradio.themes.utils import colors, fonts, sizes
+ from transformers import (
+     Sam3Model, Sam3Processor,
+     Sam3VideoModel, Sam3VideoProcessor
+ )

+ # --- THEME CONFIGURATION ---
  colors.steel_blue = colors.Color(
      name="steel_blue",
      c50="#EBF3F8",
@@ -28,7 +33,7 @@ colors.steel_blue = colors.Color(
      c950="#1E3450",
  )

+ class CustomBlueTheme(Soft):
      def __init__(
          self,
          *,
@@ -73,215 +78,229 @@ class SteelBlueTheme(Soft):
              block_label_background_fill="*primary_200",
          )

+ app_theme = CustomBlueTheme()

+ # --- MODEL MANAGEMENT & UTILS ---
+ MODEL_CACHE = {}
  device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Using compute device: {device}")

+ def clear_vram():
+     """Forces RAM/VRAM cleanup."""
+     if MODEL_CACHE:
+         print("🧹 Cleaning up memory...")
+         MODEL_CACHE.clear()
+     gc.collect()
+     torch.cuda.empty_cache()

+ def load_segmentation_model(model_key):
+     """Lazy loads the specific SAM3 model required."""
+     if model_key in MODEL_CACHE:
+         return MODEL_CACHE[model_key]

+     clear_vram()
+     print(f"⏳ Loading {model_key}...")

+     try:
+         if model_key == "img_seg_model":
+             # Using generic internal names
+             seg_model = Sam3Model.from_pretrained("facebook/sam3").to(device)
+             seg_processor = Sam3Processor.from_pretrained("facebook/sam3")
+             MODEL_CACHE[model_key] = (seg_model, seg_processor)
+
+         elif model_key == "vid_seg_model":
+             vid_model = Sam3VideoModel.from_pretrained("facebook/sam3").to(device, dtype=torch.bfloat16)
+             vid_processor = Sam3VideoProcessor.from_pretrained("facebook/sam3")
+             MODEL_CACHE[model_key] = (vid_model, vid_processor)
+
+         print(f"✅ {model_key} loaded.")
+         return MODEL_CACHE[model_key]

+     except Exception as e:
+         print(f"❌ Error loading model: {e}")
+         clear_vram()
+         raise e
+
+ def apply_mask_overlay(base_image, mask_data, opacity=0.5):
+     """Draws segmentation masks on top of an image."""
+     if isinstance(base_image, np.ndarray):
+         base_image = Image.fromarray(base_image)
+     base_image = base_image.convert("RGBA")
+
+     if mask_data is None or len(mask_data) == 0:
+         return base_image.convert("RGB")

+     if isinstance(mask_data, torch.Tensor):
+         mask_data = mask_data.cpu().numpy()
+     mask_data = mask_data.astype(np.uint8)
+
+     # Handle dimensions
+     if mask_data.ndim == 4: mask_data = mask_data[0]
+     if mask_data.ndim == 3 and mask_data.shape[0] == 1: mask_data = mask_data[0]
+
+     num_masks = mask_data.shape[0] if mask_data.ndim == 3 else 1
+     if mask_data.ndim == 2:
+         mask_data = [mask_data]
+         num_masks = 1

+     try:
+         color_map = matplotlib.colormaps["rainbow"].resampled(max(num_masks, 1))
+     except AttributeError:
+         import matplotlib.cm as cm
+         color_map = cm.get_cmap("rainbow").resampled(max(num_masks, 1))
+
+     rgb_colors = [tuple(int(c * 255) for c in color_map(i)[:3]) for i in range(num_masks)]
+     composite_layer = Image.new("RGBA", base_image.size, (0, 0, 0, 0))
+
+     for i, single_mask in enumerate(mask_data):
+         mask_bitmap = Image.fromarray((single_mask * 255).astype(np.uint8))
+         if mask_bitmap.size != base_image.size:
+             mask_bitmap = mask_bitmap.resize(base_image.size, resample=Image.NEAREST)
+
+         fill_color = rgb_colors[i]
+         color_fill = Image.new("RGBA", base_image.size, fill_color + (0,))
+         mask_alpha = mask_bitmap.point(lambda v: int(v * opacity) if v > 0 else 0)
+         color_fill.putalpha(mask_alpha)
+         composite_layer = Image.alpha_composite(composite_layer, color_fill)
+
+     return Image.alpha_composite(base_image, composite_layer).convert("RGB")

  @spaces.GPU
+ def run_image_segmentation(source_img, text_query, conf_thresh=0.5):
+     if source_img is None or not text_query:
+         raise gr.Error("Please provide an image and a text prompt.")

      try:
+         active_model, active_processor = load_segmentation_model("img_seg_model")
+         pil_image = source_img.convert("RGB")
+
+         model_inputs = active_processor(images=pil_image, text=text_query, return_tensors="pt").to(device)

+         with torch.no_grad():
+             inference_output = active_model(**model_inputs)

+         processed_results = active_processor.post_process_instance_segmentation(
+             inference_output,
+             threshold=conf_thresh,
+             mask_threshold=0.5,
+             target_sizes=model_inputs.get("original_sizes").tolist()
+         )[0]
+
+         annotation_list = []
+         raw_masks = processed_results['masks'].cpu().numpy()
+         raw_scores = processed_results['scores'].cpu().numpy()
+
+         for idx, mask_array in enumerate(raw_masks):
+             label_str = f"{text_query} ({raw_scores[idx]:.2f})"
+             annotation_list.append((mask_array, label_str))
+
+         return (pil_image, annotation_list)

+     except Exception as e:
+         raise gr.Error(f"Error during image processing: {e}")

+ def calc_timeout_duration(vid_file, *args):
+     return args[-1] if args else 60

+ @spaces.GPU(duration=calc_timeout_duration)
+ def run_video_segmentation(source_vid, text_query, frame_limit, time_limit):
+     if not source_vid or not text_query:
+         raise gr.Error("Missing video or prompt.")
+
      try:
+         active_model, active_processor = load_segmentation_model("vid_seg_model")

+         video_cap = cv2.VideoCapture(source_vid)
+         vid_fps = video_cap.get(cv2.CAP_PROP_FPS)
+         vid_w = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+         vid_h = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+         video_frames = []
+         counter = 0
+         while video_cap.isOpened():
+             ret, frame = video_cap.read()
+             if not ret or (frame_limit > 0 and counter >= frame_limit): break
+             video_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+             counter += 1
+         video_cap.release()
+
+         session = active_processor.init_video_session(video=video_frames, inference_device=device, dtype=torch.bfloat16)
+         session = active_processor.add_text_prompt(inference_session=session, text=text_query)
+
+         temp_out_path = tempfile.mktemp(suffix=".mp4")
+         video_writer = cv2.VideoWriter(temp_out_path, cv2.VideoWriter_fourcc(*'mp4v'), vid_fps, (vid_w, vid_h))
+
+         for model_out in active_model.propagate_in_video_iterator(inference_session=session, max_frame_num_to_track=len(video_frames)):
+             post_processed = active_processor.postprocess_outputs(session, model_out)
+             f_idx = model_out.frame_idx
+             original_pil = Image.fromarray(video_frames[f_idx])

+             if 'masks' in post_processed:
+                 detected_masks = post_processed['masks']
+                 if detected_masks.ndim == 4: detected_masks = detected_masks.squeeze(1)
+                 final_frame = apply_mask_overlay(original_pil, detected_masks)
+             else:
+                 final_frame = original_pil
+
+             video_writer.write(cv2.cvtColor(np.array(final_frame), cv2.COLOR_RGB2BGR))

+         video_writer.release()
+         return temp_out_path, "Video processing completed successfully."

      except Exception as e:
+         return None, f"Error during video processing: {str(e)}"

+ # --- GUI ---
+ custom_css="""
+ #col-container { margin: 0 auto; max-width: 1100px; }
+ #main-title h1 { font-size: 2.1em !important; }
  """

+ with gr.Blocks(css=custom_css, theme=app_theme) as main_interface:
      with gr.Column(elem_id="col-container"):
+         gr.Markdown("# **SAM3**", elem_id="main-title")

          with gr.Tabs():
+             with gr.Tab("Image Segmentation"):
                  with gr.Row():
                      with gr.Column(scale=1):
+                         image_input = gr.Image(label="Source Image", type="pil", height=350)
+                         txt_prompt_img = gr.Textbox(label="Text Description", placeholder="e.g., cat, face, car wheel")
+                         with gr.Accordion("Advanced Settings", open=False):
+                             conf_slider = gr.Slider(0.0, 1.0, value=0.45, step=0.05, label="Confidence Threshold")

+                         btn_process_img = gr.Button("Segment Image", variant="primary")

                      with gr.Column(scale=1.5):
+                         image_result = gr.AnnotatedImage(label="Segmented Result", height=450)

+                 btn_process_img.click(
+                     fn=run_image_segmentation,
+                     inputs=[image_input, txt_prompt_img, conf_slider],
+                     outputs=[image_result]
                  )

+             with gr.Tab("Video Segmentation"):
                  with gr.Row():
                      with gr.Column():
+                         video_input = gr.Video(label="Source Video", format="mp4")
+                         txt_prompt_vid = gr.Textbox(label="Text Description", placeholder="e.g., person running, red car")
+
+                         with gr.Row():
+                             frame_limiter = gr.Slider(10, 500, value=60, step=10, label="Max Frames")
+                             time_limiter = gr.Radio([60, 120, 180], value=60, label="Timeout (seconds)")
+
+                         btn_process_vid = gr.Button("Segment Video", variant="primary")
+
                      with gr.Column():
+                         video_result = gr.Video(label="Processed Video")
+                         process_status = gr.Textbox(label="System Status", interactive=False)

+                 btn_process_vid.click(
+                     run_video_segmentation,
+                     inputs=[video_input, txt_prompt_vid, frame_limiter, time_limiter],
+                     outputs=[video_result, process_status]
                  )

  if __name__ == "__main__":
+     main_interface.launch(ssr_mode=False, show_error=True)
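
For reference outside the Space, here is a minimal standalone sketch of the text-prompted image path that the updated app.py wires into Gradio. The facebook/sam3 checkpoint, the processor call, and the post-processing arguments are taken from the diff above; the photo.jpg input and the "cat" prompt are placeholders, and the sketch assumes a transformers build that ships the SAM3 classes.

    import torch
    from PIL import Image
    from transformers import Sam3Model, Sam3Processor

    # Same checkpoint and device selection as in app.py above.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Sam3Model.from_pretrained("facebook/sam3").to(device)
    processor = Sam3Processor.from_pretrained("facebook/sam3")

    # Placeholder inputs: any RGB image plus a short noun-phrase prompt.
    image = Image.open("photo.jpg").convert("RGB")
    inputs = processor(images=image, text="cat", return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Same post-processing call as run_image_segmentation: one dict per image,
    # holding "masks" and "scores" resized to the original image resolution.
    results = processor.post_process_instance_segmentation(
        outputs,
        threshold=0.5,
        mask_threshold=0.5,
        target_sizes=inputs.get("original_sizes").tolist(),
    )[0]
    print(f"found {len(results['masks'])} instance(s)")

As in run_image_segmentation, pairing each mask with a label built from the prompt and score yields the (image, [(mask, label), ...]) tuple that gr.AnnotatedImage expects.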