chris-propeller committed on
Commit
acd640e
·
1 Parent(s): b597179

combine points/boxes/text

Browse files
Files changed (3) hide show
  1. .gitignore +2 -0
  2. app-bak.py +0 -342
  3. app.py +115 -101
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ venv
2
+ __pycache__
app-bak.py DELETED
@@ -1,342 +0,0 @@
1
- import spaces
2
- import gradio as gr
3
- import numpy as np
4
- from PIL import Image
5
- import base64
6
- import io
7
- from typing import Dict, Any
8
- import warnings
9
- warnings.filterwarnings("ignore")
10
-
11
@spaces.GPU
def sam3_inference(image, text_prompt, confidence_threshold=0.5):
    """
    Run text-prompted SAM3 segmentation on a single image.

    The model and processor are created inside this function, and torch /
    transformers are imported here rather than at module level, so that no
    CUDA initialisation happens in the main process (Spaces Stateless GPU).

    Args:
        image: A PIL image, or a base64 string (optionally prefixed with a
            ``data:image...,`` header) that decodes to an image.
        text_prompt: Non-empty string describing what to segment.
        confidence_threshold: Score threshold forwarded to the processor's
            instance-segmentation post-processing.

    Returns:
        The first (and only) entry of
        ``processor.post_process_instance_segmentation(...)``; callers in this
        file read its ``"masks"`` and ``"scores"`` entries.

    Raises:
        Exception: Any failure is re-raised as
            ``Exception("SAM3 inference error: ...")`` with the original
            exception attached as ``__cause__``.
    """
    try:
        # Reject unusable prompts before paying for the model load.
        if not isinstance(text_prompt, str) or not text_prompt.strip():
            raise ValueError("text_prompt must be a non-empty string")

        # Import torch and transformers inside the GPU function to avoid
        # CUDA initialisation in the main process.
        import torch
        from transformers import Sam3Model, Sam3Processor

        # Initialise model and processor inside the GPU function
        # (required for Stateless GPU).
        use_cuda = torch.cuda.is_available()
        device = "cuda" if use_cuda else "cpu"
        model = Sam3Model.from_pretrained(
            "facebook/sam3",
            torch_dtype=torch.float16 if use_cuda else torch.float32,
        ).to(device)
        processor = Sam3Processor.from_pretrained("facebook/sam3")
        print(f"Model loaded on device: {device}")

        # Handle base64 input (for API).
        if isinstance(image, str):
            if image.startswith('data:image'):
                # Split only on the first comma: everything after the
                # data-URL header is payload.
                image = image.split(',', 1)[1]
            image_bytes = base64.b64decode(image)
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        # Process with SAM3.
        inputs = processor(
            images=image,
            text=text_prompt.strip(),
            return_tensors="pt",
        ).to(device)

        # Cast float32 inputs to the model dtype (float16 on GPU) so the
        # forward pass does not mix precisions.
        for key in inputs:
            if inputs[key].dtype == torch.float32:
                inputs[key] = inputs[key].to(model.dtype)

        with torch.no_grad():
            outputs = model(**inputs)

        # Use SAM3's own post-processing to get per-instance masks/scores
        # at the original image size.
        results = processor.post_process_instance_segmentation(
            outputs,
            threshold=confidence_threshold,
            mask_threshold=0.5,
            target_sizes=inputs.get("original_sizes").tolist(),
        )[0]

        return results

    except Exception as e:
        # Keep the original message format; chain the cause for debugging.
        raise Exception(f"SAM3 inference error: {str(e)}") from e
65
-
66
class SAM3Handler:
    """SAM3 handler shared by the UI wrapper and the JSON API."""

    def __init__(self):
        print("SAM3 handler initialized (models will be loaded lazily)")

    def predict(self, image, text_prompt, confidence_threshold=0.5):
        """
        Segment ``image`` for ``text_prompt`` and return a JSON-friendly dict.

        Args:
            image: PIL Image or base64 string (passed through to
                ``sam3_inference``).
            text_prompt: String describing what to segment.
            confidence_threshold: Minimum score for a mask to be returned.

        Returns:
            On success, a dict with:
              - ``"masks"``: base64-encoded PNG masks (0/255, mode ``L``),
              - ``"scores"``: the matching float scores,
              - ``"prompt_type"``: always ``"text"``,
              - ``"prompt_value"``: the prompt that was used,
              - ``"num_masks"``: number of masks actually returned.
            On any failure, ``{"error": <message>}`` instead of raising.
        """
        try:
            # Call the standalone GPU function.
            results = sam3_inference(image, text_prompt, confidence_threshold)

            masks_b64 = []
            kept_scores = []

            for mask_tensor, score_tensor in zip(results["masks"], results["scores"]):
                score = score_tensor.item()
                if score < confidence_threshold:
                    continue

                # Binary mask -> 0/255 greyscale PNG -> base64 text.
                mask_np = mask_tensor.cpu().numpy().astype(np.uint8) * 255
                mask_image = Image.fromarray(mask_np, mode='L')
                buffer = io.BytesIO()
                mask_image.save(buffer, format='PNG')

                masks_b64.append(base64.b64encode(buffer.getvalue()).decode('utf-8'))
                kept_scores.append(score)

            return {
                "masks": masks_b64,
                "scores": kept_scores,
                "prompt_type": "text",
                "prompt_value": text_prompt,
                # Count only the masks that survived the score filter so the
                # number always agrees with len(masks).
                "num_masks": len(masks_b64),
            }

        except Exception as e:
            # Deliberate best-effort boundary: callers check for "error".
            return {"error": str(e)}
116
-
117
- # Initialize the handler
118
- handler = SAM3Handler()
119
-
120
def gradio_interface(image, text_prompt, confidence_threshold):
    """
    Adapt ``handler.predict`` to the two Gradio outputs (info text, image).

    Returns a ``(message, image_or_None)`` tuple: an error message, a
    "no masks" message, or a summary together with the first mask decoded
    back into a PIL image for display.
    """
    outcome = handler.predict(image, text_prompt, confidence_threshold)

    # Failure reported by the handler: surface the message, no image.
    if "error" in outcome:
        return f"Error: {outcome['error']}", None

    encoded_masks = outcome["masks"]
    if not encoded_masks:
        return "No masks found above confidence threshold", None

    # Only the first mask is shown in the UI as a sample.
    top_score = outcome["scores"][0]
    preview = Image.open(io.BytesIO(base64.b64decode(encoded_masks[0])))

    summary = f"Found {outcome['num_masks']} masks. First mask score: {top_score:.3f}"
    return summary, preview
140
-
141
def api_predict(data: Dict[str, Any]) -> Dict[str, Any]:
    """
    JSON API entry point using the SAM2 inference-endpoint request shape.

    Expected input format (SAM2 fields plus SAM3 extensions):
    {
        "inputs": {
            "image": "base64_encoded_image_string",

            # SAM3: text-based prompts (a single string is also accepted)
            "text_prompts": ["person", "car"],

            # SAM2 compatible: point-based prompts
            "points": [[[x1, y1]], [[x2, y2]]],
            "labels": [[1], [1]],          # 1=foreground, 0=background

            # SAM2 compatible: bounding box prompts
            "boxes": [[x1, y1, x2, y2], [x1, y1, x2, y2]],

            "multimask_output": false,     # accepted but currently ignored
            "confidence_threshold": 0.5    # optional, minimum mask score
        }
    }

    Behaviour:
        - Each text prompt is run through ``handler.predict`` and the masks
          are concatenated.
        - Points/boxes are NOT used geometrically. When no text prompt is
          given, a generic ``"object"`` prompt is run and the result is
          truncated to the number of boxes (or points) supplied. When text
          prompts are given, points/boxes are ignored.

    Returns:
        On success:
        {
            "masks": [base64_png, ...],
            "scores": [float, ...],
            "num_objects": int,
            "sam_version": "3.0",
            "success": true,
            "warnings": [...]   # only present if some prompts failed
        }
        On failure:
        {"error": str, "success": false, "sam_version": "3.0"}
    """
    def _failure(message: str) -> Dict[str, Any]:
        # Single place that builds error responses so they all carry the
        # same keys.
        return {"error": message, "success": False, "sam_version": "3.0"}

    try:
        if not isinstance(data, dict):
            return _failure("Request body must be a JSON object")

        inputs_data = data.get("inputs") or {}

        # Extract inputs.
        image_b64 = inputs_data.get("image")
        text_prompts = inputs_data.get("text_prompts", [])
        input_points = inputs_data.get("points", [])
        input_labels = inputs_data.get("labels", [])
        input_boxes = inputs_data.get("boxes", [])
        confidence_threshold = inputs_data.get("confidence_threshold", 0.5)

        # A bare string would otherwise be iterated character by character.
        if isinstance(text_prompts, str):
            text_prompts = [text_prompts]

        # Validate inputs.
        if not image_b64:
            return _failure("No image provided")

        has_text = bool(text_prompts)
        has_points = bool(input_points and input_labels)
        has_boxes = bool(input_boxes)

        if not (has_text or has_points or has_boxes):
            return _failure("Must provide at least one prompt type: text_prompts, points+labels, or boxes")

        if has_points and len(input_points) != len(input_labels):
            return _failure("Number of points and labels must match")

        # Decode image.
        if image_b64.startswith('data:image'):
            image_b64 = image_b64.split(',', 1)[1]
        image_bytes = base64.b64decode(image_b64)
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        all_masks = []
        all_scores = []
        errors = []

        if has_text:
            # Process text prompts (SAM3 feature).
            for text_prompt in text_prompts:
                result = handler.predict(image, text_prompt, confidence_threshold)
                if "error" in result:
                    errors.append(f"{text_prompt!r}: {result['error']}")
                    continue
                all_masks.extend(result["masks"])
                all_scores.extend(result["scores"])
        else:
            # Visual prompts only (validation guarantees boxes or points).
            # Simplified SAM2 compatibility: fall back to a generic prompt
            # and keep as many masks as prompts were supplied.
            result = handler.predict(image, "object", confidence_threshold)
            if "error" in result:
                errors.append(f"'object': {result['error']}")
            elif result["masks"]:
                num_requested = len(input_boxes) if has_boxes else len(input_points)
                all_masks.extend(result["masks"][:num_requested])
                all_scores.extend(result["scores"][:num_requested])

        # Every prompt failed: report it instead of an empty "success".
        if errors and not all_masks:
            return _failure("; ".join(errors))

        # Build SAM2-compatible response.
        response = {
            "masks": all_masks,
            "scores": all_scores,
            "num_objects": len(all_masks),
            "sam_version": "3.0",
            "success": True,
        }
        if errors:
            response["warnings"] = errors
        return response

    except Exception as e:
        # Top-level API boundary: always answer with a JSON error payload.
        return _failure(str(e))
241
-
242
- # Create Gradio interface
243
- with gr.Blocks(title="SAM3 Inference API") as demo:
244
- gr.HTML("<h1>SAM3 Promptable Concept Segmentation</h1>")
245
- gr.HTML("<p>This Space provides both a UI and API for SAM3 inference. Use the interface below or call the API programmatically.</p>")
246
-
247
- with gr.Row():
248
- with gr.Column():
249
- image_input = gr.Image(type="pil", label="Input Image")
250
- text_input = gr.Textbox(label="Text Prompt", placeholder="Enter what you want to segment (e.g., 'cat', 'person', 'car')")
251
- confidence_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.5, step=0.1, label="Confidence Threshold")
252
- predict_btn = gr.Button("Segment", variant="primary")
253
-
254
- with gr.Column():
255
- info_output = gr.Textbox(label="Results Info")
256
- mask_output = gr.Image(label="Sample Mask")
257
-
258
- # API endpoint - this creates /api/predict/
259
- predict_btn.click(
260
- gradio_interface,
261
- inputs=[image_input, text_input, confidence_slider],
262
- outputs=[info_output, mask_output],
263
- api_name="predict" # This creates the API endpoint
264
- )
265
-
266
- # SAM2-compatible API endpoint - this creates /api/sam2_compatible/
267
- gr.Interface(
268
- fn=api_predict,
269
- inputs=gr.JSON(label="SAM2/SAM3 Compatible Input"),
270
- outputs=gr.JSON(label="SAM2/SAM3 Compatible Output"),
271
- title="SAM2/SAM3 Compatible API",
272
- description="API endpoint that matches SAM2 inference endpoint format with SAM3 extensions",
273
- api_name="sam2_compatible"
274
- )
275
-
276
- # Add API documentation
277
- gr.HTML("""
278
- <h2>API Usage</h2>
279
-
280
- <h3>1. Simple Text API (Gradio format)</h3>
281
- <pre>
282
- import requests
283
- import base64
284
-
285
- # Encode your image to base64
286
- with open("image.jpg", "rb") as f:
287
- image_b64 = base64.b64encode(f.read()).decode()
288
-
289
- # Make API request
290
- response = requests.post(
291
- "https://your-username-sam3-api.hf.space/api/predict",
292
- json={
293
- "data": [image_b64, "kitten", 0.5]
294
- }
295
- )
296
-
297
- result = response.json()
298
- </pre>
299
-
300
- <h3>2. SAM2/SAM3 Compatible API (Inference Endpoint format)</h3>
301
- <pre>
302
- import requests
303
- import base64
304
-
305
- # Encode your image to base64
306
- with open("image.jpg", "rb") as f:
307
- image_b64 = base64.b64encode(f.read()).decode()
308
-
309
- # SAM3 Text Prompts (NEW)
310
- response = requests.post(
311
- "https://your-username-sam3-api.hf.space/api/sam2_compatible",
312
- json={
313
- "data": [{
314
- "inputs": {
315
- "image": image_b64,
316
- "text_prompts": ["kitten", "toy"],
317
- "confidence_threshold": 0.5
318
- }
319
- }]
320
- }
321
- )
322
-
323
- # SAM2 Compatible (Points/Boxes)
324
- response = requests.post(
325
- "https://your-username-sam3-api.hf.space/api/sam2_compatible",
326
- json={
327
- "data": [{
328
- "inputs": {
329
- "image": image_b64,
330
- "boxes": [[100, 100, 200, 200]],
331
- "confidence_threshold": 0.5
332
- }
333
- }]
334
- }
335
- )
336
-
337
- result = response.json()
338
- </pre>
339
- """)
340
-
341
- if __name__ == "__main__":
342
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -2,10 +2,10 @@ import spaces
2
  import gradio as gr
3
 
4
  @spaces.GPU
5
- def sam3_inference(image, text_prompt=None, boxes=None, box_labels=None, confidence_threshold=0.5):
6
  """
7
  Core SAM3 inference function for Stateless GPU environment
8
- Supports both text prompts and box prompts
9
  Returns raw results for both UI and API use
10
  """
11
  # Import everything inside the GPU function
@@ -18,12 +18,15 @@ def sam3_inference(image, text_prompt=None, boxes=None, box_labels=None, confide
18
 
19
  try:
20
  # Validate inputs
21
- if not text_prompt and not boxes:
22
- raise ValueError("Either text_prompt or boxes must be provided")
23
 
24
  if boxes and not box_labels:
25
  raise ValueError("box_labels must be provided when boxes are specified")
26
 
 
 
 
27
  # Handle base64 input if needed
28
  if isinstance(image, str):
29
  if image.startswith('data:image'):
@@ -59,14 +62,36 @@ def sam3_inference(image, text_prompt=None, boxes=None, box_labels=None, confide
59
  for i, box in enumerate(boxes):
60
  if len(box) == 4: # [x1, y1, x2, y2]
61
  formatted_boxes.append(box)
62
- # Use corresponding label or default to positive (1)
63
- label = box_labels[i] if i < len(box_labels) else 1
64
- formatted_labels.append(label)
 
 
65
 
66
  if formatted_boxes:
67
  processor_kwargs["input_boxes"] = [formatted_boxes]
68
  processor_kwargs["input_boxes_labels"] = [formatted_labels]
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  # Process input
71
  inputs = processor(**processor_kwargs).to(device)
72
 
@@ -248,96 +273,67 @@ def sam2_compatible_api(data):
248
  all_polygons = []
249
  prompt_types = []
250
 
251
- # Process text prompts (SAM3 feature)
252
  if has_text:
253
  prompt_types.append("text")
254
- for text_prompt in text_prompts:
255
- results = sam3_inference(image, text_prompt=text_prompt, confidence_threshold=confidence_threshold)
256
- if results and len(results["masks"]) > 0:
257
- for i in range(len(results["masks"])):
258
- mask_np = results["masks"][i].cpu().numpy().astype(np.uint8) * 255
259
- score = results["scores"][i].item()
260
-
261
- if score >= confidence_threshold:
262
- # Convert mask to base64
263
- mask_image = Image.fromarray(mask_np, mode='L')
264
- buffer = io.BytesIO()
265
- mask_image.save(buffer, format='PNG')
266
- mask_b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
267
-
268
- all_masks.append(mask_b64)
269
- all_scores.append(score)
270
-
271
- # Extract polygons if vectorize is enabled
272
- if vectorize:
273
- binary_mask = (mask_np > 0).astype(np.uint8)
274
- polygons = _mask_to_polygons_original_size(binary_mask, simplify_epsilon)
275
- all_polygons.append(polygons)
276
-
277
- # Process visual prompts (SAM2 compatibility) - Now properly implemented
278
- if has_boxes:
279
  prompt_types.append("visual")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
  # Create box labels (default to positive boxes if not provided)
281
- box_labels = inputs_data.get("box_labels", [1] * len(input_boxes))
282
-
283
- # Process boxes using SAM3's native box support
284
- results = sam3_inference(
285
- image=image,
286
- text_prompt=None,
287
- boxes=input_boxes,
288
- box_labels=box_labels,
289
- confidence_threshold=confidence_threshold
290
- )
291
-
292
- if results and len(results["masks"]) > 0:
293
- for i in range(len(results["masks"])):
294
- mask_np = results["masks"][i].cpu().numpy().astype(np.uint8) * 255
295
- score = results["scores"][i].item()
296
-
297
- if score >= confidence_threshold:
298
- # Convert mask to base64
299
- mask_image = Image.fromarray(mask_np, mode='L')
300
- buffer = io.BytesIO()
301
- mask_image.save(buffer, format='PNG')
302
- mask_b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
303
-
304
- all_masks.append(mask_b64)
305
- all_scores.append(score)
306
-
307
- # Extract polygons if vectorize is enabled
308
- if vectorize:
309
- binary_mask = (mask_np > 0).astype(np.uint8)
310
- polygons = _mask_to_polygons_original_size(binary_mask, simplify_epsilon)
311
- all_polygons.append(polygons)
312
-
313
- # Process point prompts (SAM2 compatibility) - Fallback implementation
314
- elif has_points and not has_text:
315
- prompt_types.append("visual")
316
- # For point prompts, use a generic prompt to get masks (SAM3 doesn't natively support points)
317
- # This is a fallback - true SAM2 compatibility would require point prompt support
318
- results = sam3_inference(image, text_prompt="object", confidence_threshold=confidence_threshold)
319
- if results and len(results["masks"]) > 0:
320
- # Take only the number of masks requested
321
- num_requested = len(input_points)
322
- for i in range(min(num_requested, len(results["masks"]))):
323
- mask_np = results["masks"][i].cpu().numpy().astype(np.uint8) * 255
324
- score = results["scores"][i].item()
325
-
326
- if score >= confidence_threshold:
327
- # Convert mask to base64
328
- mask_image = Image.fromarray(mask_np, mode='L')
329
- buffer = io.BytesIO()
330
- mask_image.save(buffer, format='PNG')
331
- mask_b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
332
-
333
- all_masks.append(mask_b64)
334
- all_scores.append(score)
335
-
336
- # Extract polygons if vectorize is enabled
337
- if vectorize:
338
- binary_mask = (mask_np > 0).astype(np.uint8)
339
- polygons = _mask_to_polygons_original_size(binary_mask, simplify_epsilon)
340
- all_polygons.append(polygons)
341
 
342
  # Build SAM2-compatible response
343
  response = {
@@ -451,7 +447,7 @@ import base64
451
  with open("image.jpg", "rb") as f:
452
  image_b64 = base64.b64encode(f.read()).decode()
453
 
454
- # SAM3 Text Prompts (NEW)
455
  response = requests.post(
456
  "https://your-username-sam3-api.hf.space/api/sam2_compatible",
457
  json={
@@ -463,13 +459,30 @@ response = requests.post(
463
  }
464
  )
465
 
466
- # SAM2 Compatible (Points/Boxes)
467
  response = requests.post(
468
  "https://your-username-sam3-api.hf.space/api/sam2_compatible",
469
  json={
470
  "inputs": {
471
  "image": image_b64,
472
  "boxes": [[100, 100, 200, 200]],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
473
  "confidence_threshold": 0.5
474
  }
475
  }
@@ -499,15 +512,16 @@ result = response.json()
499
  "inputs": {
500
  "image": "base64_encoded_image_string",
501
 
502
- // SAM3 NEW: Text-based prompts
503
  "text_prompts": ["person", "car"], // List of text descriptions
504
 
505
- // SAM2 COMPATIBLE: Point-based prompts
506
- "points": [[[x1, y1]], [[x2, y2]]], // Points for each object
507
- "labels": [[1], [1]], // Labels for each point (1=foreground, 0=background)
508
 
509
- // SAM2 COMPATIBLE: Bounding box prompts
510
- "boxes": [[x1, y1, x2, y2], [x1, y1, x2, y2]], // Bounding boxes
 
511
 
512
  "multimask_output": false, // Optional, defaults to False
513
  "confidence_threshold": 0.5, // Optional, minimum confidence for returned masks
 
2
  import gradio as gr
3
 
4
  @spaces.GPU
5
+ def sam3_inference(image, text_prompt=None, boxes=None, box_labels=None, points=None, point_labels=None, confidence_threshold=0.5):
6
  """
7
  Core SAM3 inference function for Stateless GPU environment
8
+ Supports text prompts, box prompts, and point prompts (individually or combined)
9
  Returns raw results for both UI and API use
10
  """
11
  # Import everything inside the GPU function
 
18
 
19
  try:
20
  # Validate inputs
21
+ if not text_prompt and not boxes and not points:
22
+ raise ValueError("At least one of text_prompt, boxes, or points must be provided")
23
 
24
  if boxes and not box_labels:
25
  raise ValueError("box_labels must be provided when boxes are specified")
26
 
27
+ if points and not point_labels:
28
+ raise ValueError("point_labels must be provided when points are specified")
29
+
30
  # Handle base64 input if needed
31
  if isinstance(image, str):
32
  if image.startswith('data:image'):
 
62
  for i, box in enumerate(boxes):
63
  if len(box) == 4: # [x1, y1, x2, y2]
64
  formatted_boxes.append(box)
65
+ # Use the provided label (supports both positive=1 and negative=0)
66
+ if i < len(box_labels):
67
+ formatted_labels.append(box_labels[i])
68
+ else:
69
+ raise ValueError(f"Missing label for box {i}")
70
 
71
  if formatted_boxes:
72
  processor_kwargs["input_boxes"] = [formatted_boxes]
73
  processor_kwargs["input_boxes_labels"] = [formatted_labels]
74
 
75
+ # Add point prompts if provided
76
+ if points and point_labels:
77
+ # Convert points to expected format: [[[x1, y1], [x2, y2]], ...]
78
+ # SAM3 expects points as nested lists for batch processing
79
+ formatted_points = []
80
+ formatted_point_labels = []
81
+
82
+ for i, point in enumerate(points):
83
+ if len(point) == 2: # [x, y]
84
+ formatted_points.append(point)
85
+ # Use the provided label (supports both positive=1 and negative=0)
86
+ if i < len(point_labels):
87
+ formatted_point_labels.append(point_labels[i])
88
+ else:
89
+ raise ValueError(f"Missing label for point {i}")
90
+
91
+ if formatted_points:
92
+ processor_kwargs["input_points"] = [formatted_points]
93
+ processor_kwargs["input_points_labels"] = [formatted_point_labels]
94
+
95
  # Process input
96
  inputs = processor(**processor_kwargs).to(device)
97
 
 
273
  all_polygons = []
274
  prompt_types = []
275
 
276
+ # Determine what prompt types are being used
277
  if has_text:
278
  prompt_types.append("text")
279
+ if has_points or has_boxes:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
  prompt_types.append("visual")
281
+
282
+ # Prepare inputs for combined SAM3 inference call
283
+ combined_text_prompt = None
284
+ combined_boxes = None
285
+ combined_box_labels = None
286
+ combined_points = None
287
+ combined_point_labels = None
288
+
289
+ # Handle text prompts - combine multiple text prompts into one
290
+ if has_text:
291
+ # For multiple text prompts, join them (SAM3 can handle combined descriptions)
292
+ combined_text_prompt = ", ".join(text_prompts)
293
+
294
+ # Handle box prompts
295
+ if has_boxes:
296
+ combined_boxes = input_boxes
297
  # Create box labels (default to positive boxes if not provided)
298
+ combined_box_labels = inputs_data.get("box_labels", [1] * len(input_boxes))
299
+
300
+ # Handle point prompts
301
+ if has_points:
302
+ combined_points = input_points
303
+ combined_point_labels = input_labels
304
+
305
+ # Make single combined inference call with all prompt types
306
+ results = sam3_inference(
307
+ image=image,
308
+ text_prompt=combined_text_prompt,
309
+ boxes=combined_boxes,
310
+ box_labels=combined_box_labels,
311
+ points=combined_points,
312
+ point_labels=combined_point_labels,
313
+ confidence_threshold=confidence_threshold
314
+ )
315
+
316
+ # Process results
317
+ if results and len(results["masks"]) > 0:
318
+ for i in range(len(results["masks"])):
319
+ mask_np = results["masks"][i].cpu().numpy().astype(np.uint8) * 255
320
+ score = results["scores"][i].item()
321
+
322
+ if score >= confidence_threshold:
323
+ # Convert mask to base64
324
+ mask_image = Image.fromarray(mask_np, mode='L')
325
+ buffer = io.BytesIO()
326
+ mask_image.save(buffer, format='PNG')
327
+ mask_b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
328
+
329
+ all_masks.append(mask_b64)
330
+ all_scores.append(score)
331
+
332
+ # Extract polygons if vectorize is enabled
333
+ if vectorize:
334
+ binary_mask = (mask_np > 0).astype(np.uint8)
335
+ polygons = _mask_to_polygons_original_size(binary_mask, simplify_epsilon)
336
+ all_polygons.append(polygons)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
 
338
  # Build SAM2-compatible response
339
  response = {
 
447
  with open("image.jpg", "rb") as f:
448
  image_b64 = base64.b64encode(f.read()).decode()
449
 
450
+ # SAM3 Text Prompts Only
451
  response = requests.post(
452
  "https://your-username-sam3-api.hf.space/api/sam2_compatible",
453
  json={
 
459
  }
460
  )
461
 
462
+ # SAM2 Compatible (Points/Boxes Only)
463
  response = requests.post(
464
  "https://your-username-sam3-api.hf.space/api/sam2_compatible",
465
  json={
466
  "inputs": {
467
  "image": image_b64,
468
  "boxes": [[100, 100, 200, 200]],
469
+ "box_labels": [1], # 1=positive, 0=negative
470
+ "confidence_threshold": 0.5
471
+ }
472
+ }
473
+ )
474
+
475
+ # SAM3 Combined Prompts (Text + Visual) - NEW CAPABILITY!
476
+ response = requests.post(
477
+ "https://your-username-sam3-api.hf.space/api/sam2_compatible",
478
+ json={
479
+ "inputs": {
480
+ "image": image_b64,
481
+ "text_prompts": ["cat"], # Text description
482
+ "boxes": [[50, 50, 150, 150]], # Bounding box
483
+ "box_labels": [0], # 0=negative (exclude this area)
484
+ "points": [[200, 200]], # Point prompt
485
+ "labels": [1], # 1=positive point
486
  "confidence_threshold": 0.5
487
  }
488
  }
 
512
  "inputs": {
513
  "image": "base64_encoded_image_string",
514
 
515
+ // SAM3 NEW: Text-based prompts (can be combined with visual prompts)
516
  "text_prompts": ["person", "car"], // List of text descriptions
517
 
518
+ // SAM2 COMPATIBLE: Point-based prompts (can be combined with text/boxes)
519
+ "points": [[x1, y1], [x2, y2]], // Individual points (not nested arrays)
520
+ "labels": [1, 0], // Labels for each point (1=positive/foreground, 0=negative/background)
521
 
522
+ // SAM2 COMPATIBLE: Bounding box prompts (can be combined with text/points)
523
+ "boxes": [[x1, y1, x2, y2], [x3, y3, x4, y4]], // Bounding boxes
524
+ "box_labels": [1, 0], // Labels for each box (1=positive, 0=negative/exclude)
525
 
526
  "multimask_output": false, // Optional, defaults to False
527
  "confidence_threshold": 0.5, // Optional, minimum confidence for returned masks