prithivMLmods committed (verified)
Commit a3f9b81 · Parent(s): 79e7003

Update app.py

Files changed (1)
  1. app.py +147 -54
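
In brief, this commit adds a "Video Segmentation" tab alongside the existing image mode: model loading becomes lazy and per-variant through a get_model cache (Sam3Model for images, Sam3VideoModel in bfloat16 for video), a new overlay_masks helper composites per-instance masks onto frames with random colors, and process_video_text decodes frames with cv2.VideoCapture, tracks objects from a text prompt through the SAM3 video session API, and writes the annotated result to an MP4. The launch call drops the mcp_server/ssr_mode flags in favor of debug=True.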
app.py CHANGED
@@ -8,8 +8,11 @@ from PIL import Image, ImageDraw
 from typing import Iterable
 from gradio.themes import Soft
 from gradio.themes.utils import colors, fonts, sizes
-from transformers import Sam3Processor, Sam3Model
+from transformers import Sam3Processor, Sam3Model, Sam3VideoModel, Sam3VideoProcessor
+import cv2
+import tempfile

+# --- Theme Definition ---
 colors.steel_blue = colors.Color(
     name="steel_blue",
     c50="#EBF3F8",
@@ -72,22 +75,55 @@ class SteelBlueTheme(Soft):

 steel_blue_theme = SteelBlueTheme()

+# --- Model Loading ---
 device = "cuda" if torch.cuda.is_available() else "cpu"
 print(f"Using device: {device}")

+MODELS = {}
+
+def get_model(model_type):
+    if model_type not in MODELS:
+        if model_type == "sam3_image":
+            print("Loading SAM3 Image Model and Processor...")
+            model = Sam3Model.from_pretrained("facebook/sam3").to(device)
+            processor = Sam3Processor.from_pretrained("facebook/sam3")
+            MODELS[model_type] = (model, processor)
+        elif model_type == "sam3_video_text":
+            print("Loading SAM3 Video Model and Processor...")
+            model = Sam3VideoModel.from_pretrained("facebook/sam3").to(device, dtype=torch.bfloat16)
+            processor = Sam3VideoProcessor.from_pretrained("facebook/sam3")
+            MODELS[model_type] = (model, processor)
+    return MODELS[model_type]
+
 try:
-    print("Loading SAM3 Model and Processor...")
-    model = Sam3Model.from_pretrained("facebook/sam3").to(device)
-    processor = Sam3Processor.from_pretrained("facebook/sam3")
-    print("Model loaded successfully.")
-
+    get_model("sam3_image")
+    print("Image model loaded successfully.")
 except Exception as e:
-    print(f"Error loading model: {e}")
+    print(f"Error loading image model: {e}")
     print("Ensure you have the correct libraries installed and access to the model.")
-    # Fallback/Placeholder for demonstration if model doesn't exist in environment yet
-    model = None
-    processor = None

+# --- Helper Functions ---
+def overlay_masks(image, masks, alpha=0.5):
+    """Overlays masks on the image with random colors."""
+    image = image.convert("RGBA")
+    overlay = Image.new("RGBA", image.size, (0, 0, 0, 0))
+    draw = ImageDraw.Draw(overlay)
+
+    for mask in masks:
+        # Generate a random color for each mask
+        color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), int(255 * alpha))
+
+        # Convert boolean mask to an image that can be pasted
+        mask_pil = Image.fromarray((mask * 255).astype(np.uint8), mode='L')
+
+        # Draw the colored mask
+        draw.bitmap((0, 0), mask_pil, fill=color)
+
+    # Combine the original image with the overlay
+    combined = Image.alpha_composite(image, overlay)
+    return combined.convert("RGB")
+
+# --- Core Functions ---
 @spaces.GPU
 def segment_image(input_image, text_prompt, threshold=0.5):
     if input_image is None:
@@ -95,20 +131,17 @@ def segment_image(input_image, text_prompt, threshold=0.5):
     if not text_prompt:
         raise gr.Error("Please enter a text prompt (e.g., 'cat', 'face').")

-    if model is None or processor is None:
-        raise gr.Error("Model not loaded correctly.")
+    try:
+        model, processor = get_model("sam3_image")
+    except Exception as e:
+        raise gr.Error(f"Model not loaded correctly: {e}")

-    # Convert image to RGB
     image_pil = input_image.convert("RGB")
-
-    # Preprocess
     inputs = processor(images=image_pil, text=text_prompt, return_tensors="pt").to(device)

-    # Inference
     with torch.no_grad():
         outputs = model(**inputs)

-    # Post-process results
     results = processor.post_process_instance_segmentation(
         outputs,
         threshold=threshold,
@@ -116,27 +149,67 @@ def segment_image(input_image, text_prompt, threshold=0.5):
         target_sizes=inputs.get("original_sizes").tolist()
     )[0]

-    masks = results['masks'] # Boolean tensor [N, H, W]
+    masks = results['masks']
     scores = results['scores']

-    # Prepare for Gradio AnnotatedImage
-    # Gradio expects (image, [(mask, label), ...])
-
     annotations = []
     masks_np = masks.cpu().numpy()
     scores_np = scores.cpu().numpy()

     for i, mask in enumerate(masks_np):
-        # mask is a boolean array (True/False).
-        # AnnotatedImage handles the coloring automatically.
-        # We just pass the mask and a label.
         score_val = scores_np[i]
         label = f"{text_prompt} ({score_val:.2f})"
         annotations.append((mask, label))

-    # Return tuple format for AnnotatedImage
     return (image_pil, annotations)

+def process_video_text(video_path, text_prompt, max_frames, timeout_seconds):
+    if not video_path or not text_prompt:
+        return None, "Missing video or prompt."
+    try:
+        model, processor = get_model("sam3_video_text")
+        cap = cv2.VideoCapture(video_path)
+        fps = cap.get(cv2.CAP_PROP_FPS)
+        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+        frames = []
+        frame_count = 0
+        while cap.isOpened():
+            ret, frame = cap.read()
+            if not ret or (max_frames > 0 and frame_count >= max_frames):
+                break
+            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+            frame_count += 1
+        cap.release()
+
+        inference_session = processor.init_video_session(video=frames, inference_device=device, dtype=torch.bfloat16)
+        inference_session = processor.add_text_prompt(inference_session=inference_session, text=text_prompt)
+
+        output_path = tempfile.mktemp(suffix=".mp4")
+        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
+
+        for model_outputs in model.propagate_in_video_iterator(inference_session=inference_session, max_frame_num_to_track=len(frames)):
+            processed_outputs = processor.postprocess_outputs(inference_session, model_outputs)
+            frame_idx = model_outputs.frame_idx
+            orig_frame = Image.fromarray(frames[frame_idx])
+
+            if 'masks' in processed_outputs:
+                masks = processed_outputs['masks']
+                if masks.ndim == 4:
+                    masks = masks.squeeze(1)
+                res_frame = overlay_masks(orig_frame, masks)
+            else:
+                res_frame = orig_frame
+
+            out.write(cv2.cvtColor(np.array(res_frame), cv2.COLOR_RGB2BGR))
+
+        out.release()
+        return output_path, "Done!"
+    except Exception as e:
+        return None, f"Error: {str(e)}"
+
+# --- Gradio UI ---
 css="""
 #col-container {
     margin: 0 auto;
@@ -148,40 +221,60 @@ css="""
 with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
     with gr.Column(elem_id="col-container"):
         gr.Markdown(
-            "# **SAM3 Image Segmentation**",
+            "# **SAM3 Image & Video Segmentation**",
             elem_id="main-title"
         )

-        gr.Markdown("Segment objects in images using **SAM3** (Segment Anything Model 3) with text prompts.")
-
-        with gr.Row():
-            with gr.Column(scale=1):
-                input_image = gr.Image(label="Input Image", type="pil", height=300)
-                text_prompt = gr.Textbox(
-                    label="Text Prompt",
-                    placeholder="e.g., cat, ear, car wheel...",
-                )
-
-                run_button = gr.Button("Segment", variant="primary")
-
-            with gr.Column(scale=1.5):
-                output_image = gr.AnnotatedImage(label="Segmented Output", height=380)
-
-        with gr.Row():
-            threshold = gr.Slider(label="Confidence Threshold", minimum=0.0, maximum=1.0, value=0.4, step=0.05)
-
-        gr.Examples(
-            examples=[
-                ["examples/player.jpg", "player in white", 0.5],
-                ["examples/goldencat.webp", "black cat", 0.4],
-                ["examples/taxi.jpg", "blue taxi", 0.5],
-            ],
-            inputs=[input_image, text_prompt, threshold],
-            outputs=[output_image],
-            fn=segment_image,
-            cache_examples="lazy",
-            label="Examples"
-        )
+        gr.Markdown("Segment objects in images or videos using **SAM3** (Segment Anything Model 3) with text prompts.")
+
+        with gr.Tabs():
+            with gr.TabItem("Image Segmentation"):
+                with gr.Row():
+                    with gr.Column(scale=1):
+                        input_image = gr.Image(label="Input Image", type="pil", height=300)
+                        text_prompt = gr.Textbox(
+                            label="Text Prompt",
+                            placeholder="e.g., cat, ear, car wheel...",
+                        )
+
+                        run_button = gr.Button("Segment Image", variant="primary")
+
+                    with gr.Column(scale=1.5):
+                        output_image = gr.AnnotatedImage(label="Segmented Output", height=380)
+
+                with gr.Row():
+                    threshold = gr.Slider(label="Confidence Threshold", minimum=0.0, maximum=1.0, value=0.4, step=0.05)
+
+                gr.Examples(
+                    examples=[
+                        ["examples/player.jpg", "player in white", 0.5],
+                        ["examples/goldencat.webp", "black cat", 0.4],
+                        ["examples/taxi.jpg", "blue taxi", 0.5],
+                    ],
+                    inputs=[input_image, text_prompt, threshold],
+                    outputs=[output_image],
+                    fn=segment_image,
+                    cache_examples="lazy",
+                    label="Image Examples"
+                )
+
+            with gr.TabItem("Video Segmentation"):
+                with gr.Row():
+                    with gr.Column():
+                        input_video = gr.Video(label="Input Video", format="mp4")
+                        video_text_prompt = gr.Textbox(label="Text Prompt", placeholder="e.g.: person, car")
+                        max_frames_slider = gr.Slider(10, 1000, value=50, step=10, label="Max Frames to Process")
+                        processing_duration = gr.Radio([60, 120], value=60, label="Max Processing Time (seconds)", info="Choose 60s for short clips, 120s for complex tasks")
+                        start_video_segmentation_button = gr.Button("Start Video Segmentation", variant="primary")
+                    with gr.Column():
+                        output_video = gr.Video(label="Result Video")
+                        status_textbox = gr.Textbox(label="Status")
+
+                start_video_segmentation_button.click(
+                    process_video_text,
+                    [input_video, video_text_prompt, max_frames_slider, processing_duration],
+                    [output_video, status_textbox]
+                )

         run_button.click(
             fn=segment_image,
@@ -190,4 +283,4 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
         )

 if __name__ == "__main__":
-    demo.launch(mcp_server=True, ssr_mode=False, show_error=True)
+    demo.launch(debug=True, show_error=True)
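
For reference, the image path kept by this commit can be exercised outside Gradio. Below is a minimal sketch distilled from segment_image as committed: the "facebook/sam3" checkpoint and every Sam3Model/Sam3Processor call are taken from the diff above, while the input file and prompt are illustrative, and the post_process_instance_segmentation call in app.py may pass an extra keyword that falls between the hunks shown here.

import torch
from PIL import Image
from transformers import Sam3Processor, Sam3Model

# Same device selection and checkpoint as app.py.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Sam3Model.from_pretrained("facebook/sam3").to(device)
processor = Sam3Processor.from_pretrained("facebook/sam3")

# Illustrative input; any RGB image and a short noun-phrase prompt work.
image = Image.open("examples/taxi.jpg").convert("RGB")
inputs = processor(images=image, text="blue taxi", return_tensors="pt").to(device)

with torch.no_grad():
    outputs = model(**inputs)

# Post-processing mirrors segment_image; 0.4 is the UI's default threshold.
results = processor.post_process_instance_segmentation(
    outputs,
    threshold=0.4,
    target_sizes=inputs.get("original_sizes").tolist(),
)[0]

for mask, score in zip(results["masks"], results["scores"]):
    print(tuple(mask.shape), f"{float(score):.2f}")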
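The new video path reduces to four session calls. A minimal sketch, assuming the Sam3VideoModel/Sam3VideoProcessor entry points behave standalone as invoked in process_video_text; the synthetic frames below stand in for the cv2.VideoCapture loop, and in the app each frame's masks are rendered with overlay_masks and written out via cv2.VideoWriter.

import numpy as np
import torch
from transformers import Sam3VideoModel, Sam3VideoProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
model = Sam3VideoModel.from_pretrained("facebook/sam3").to(device, dtype=torch.bfloat16)
processor = Sam3VideoProcessor.from_pretrained("facebook/sam3")

# Stand-in clip: 8 blank RGB frames shaped like decoded video frames (H, W, 3).
frames = [np.zeros((480, 640, 3), dtype=np.uint8) for _ in range(8)]

# Build a session over the frames, then attach the text prompt to track.
session = processor.init_video_session(video=frames, inference_device=device, dtype=torch.bfloat16)
session = processor.add_text_prompt(inference_session=session, text="person")

# Stream per-frame outputs and post-process them, as process_video_text does.
for model_outputs in model.propagate_in_video_iterator(
    inference_session=session, max_frame_num_to_track=len(frames)
):
    processed = processor.postprocess_outputs(session, model_outputs)
    masks = processed["masks"] if "masks" in processed else None
    print(model_outputs.frame_idx, None if masks is None else tuple(masks.shape))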