peterproofpath commited on
Commit
0bd64d1
·
verified ·
1 Parent(s): ab27c3a

Upload 2 files

Browse files
Files changed (2) hide show
  1. handler.py +599 -0
  2. requirements.txt +8 -0
handler.py ADDED
@@ -0,0 +1,599 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SAM 3 Custom Inference Handler for Hugging Face Inference Endpoints
3
+ Model: facebook/sam3
4
+
5
+ For ProofPath video assessment - text-prompted segmentation to find UI elements.
6
+ Supports text prompts like "Save button", "dropdown menu", "text input field".
7
+
8
+ KEY CAPABILITIES:
9
+ - Text-to-segment: Find ALL instances of a concept (e.g., "button" → all buttons)
10
+ - Promptable Concept Segmentation (PCS): 270K unique concepts
11
+ - Video tracking: Consistent object IDs across frames
12
+ - Presence token: Discriminates similar elements ("player in white" vs "player in red")
13
+
14
+ REQUIREMENTS:
15
+ 1. Set HF_TOKEN environment variable (model is gated)
16
+ 2. Accept license at https://huggingface.co/facebook/sam3
17
+ """
18
+
19
+ from typing import Dict, List, Any, Optional, Union
20
+ import torch
21
+ import numpy as np
22
+ import base64
23
+ import io
24
+ import os
25
+
26
+
27
+ class EndpointHandler:
28
+ def __init__(self, path: str = ""):
29
+ """
30
+ Initialize SAM 3 model for text-prompted segmentation.
31
+
32
+ Args:
33
+ path: Path to the model directory (ignored - we load from HF hub)
34
+ """
35
+ model_id = "facebook/sam3"
36
+
37
+ # Get HF token for gated model access
38
+ hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
39
+
40
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41
+
42
+ # Import SAM3 components from transformers
43
+ from transformers import Sam3Processor, Sam3Model
44
+
45
+ self.processor = Sam3Processor.from_pretrained(
46
+ model_id,
47
+ token=hf_token,
48
+ )
49
+
50
+ self.model = Sam3Model.from_pretrained(
51
+ model_id,
52
+ torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
53
+ token=hf_token,
54
+ ).to(self.device)
55
+
56
+ self.model.eval()
57
+
58
+ # Also load video model for video segmentation
59
+ self._video_model = None
60
+ self._video_processor = None
61
+
62
+ def _get_video_model(self):
63
+ """Lazy load video model only when needed."""
64
+ if self._video_model is None:
65
+ from transformers import Sam3VideoModel, Sam3VideoProcessor
66
+
67
+ model_id = "facebook/sam3"
68
+ hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
69
+
70
+ self._video_processor = Sam3VideoProcessor.from_pretrained(
71
+ model_id,
72
+ token=hf_token,
73
+ )
74
+
75
+ self._video_model = Sam3VideoModel.from_pretrained(
76
+ model_id,
77
+ torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
78
+ token=hf_token,
79
+ ).to(self.device)
80
+
81
+ self._video_model.eval()
82
+
83
+ return self._video_model, self._video_processor
84
+
85
+ def _load_image(self, image_data: Any):
86
+ """Load image from various formats."""
87
+ from PIL import Image
88
+ import requests
89
+
90
+ if isinstance(image_data, Image.Image):
91
+ return image_data.convert('RGB')
92
+ elif isinstance(image_data, str):
93
+ if image_data.startswith(('http://', 'https://')):
94
+ response = requests.get(image_data, stream=True)
95
+ return Image.open(response.raw).convert('RGB')
96
+ elif image_data.startswith('data:'):
97
+ header, encoded = image_data.split(',', 1)
98
+ image_bytes = base64.b64decode(encoded)
99
+ return Image.open(io.BytesIO(image_bytes)).convert('RGB')
100
+ else:
101
+ # Assume base64 encoded
102
+ image_bytes = base64.b64decode(image_data)
103
+ return Image.open(io.BytesIO(image_bytes)).convert('RGB')
104
+ elif isinstance(image_data, bytes):
105
+ return Image.open(io.BytesIO(image_data)).convert('RGB')
106
+ else:
107
+ raise ValueError(f"Unsupported image input type: {type(image_data)}")
108
+
109
+ def _load_video_frames(self, video_data: Any, max_frames: int = 100, fps: float = 2.0) -> List:
110
+ """Load video frames from various formats."""
111
+ import cv2
112
+ from PIL import Image
113
+ import tempfile
114
+
115
+ # Decode to temp file if needed
116
+ if isinstance(video_data, str):
117
+ if video_data.startswith(('http://', 'https://')):
118
+ import requests
119
+ response = requests.get(video_data, stream=True)
120
+ with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
121
+ for chunk in response.iter_content(chunk_size=8192):
122
+ f.write(chunk)
123
+ video_path = f.name
124
+ elif video_data.startswith('data:'):
125
+ header, encoded = video_data.split(',', 1)
126
+ video_bytes = base64.b64decode(encoded)
127
+ with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
128
+ f.write(video_bytes)
129
+ video_path = f.name
130
+ else:
131
+ video_bytes = base64.b64decode(video_data)
132
+ with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
133
+ f.write(video_bytes)
134
+ video_path = f.name
135
+ elif isinstance(video_data, bytes):
136
+ with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
137
+ f.write(video_data)
138
+ video_path = f.name
139
+ else:
140
+ raise ValueError(f"Unsupported video input type: {type(video_data)}")
141
+
142
+ try:
143
+ cap = cv2.VideoCapture(video_path)
144
+ video_fps = cap.get(cv2.CAP_PROP_FPS)
145
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
146
+ duration = total_frames / video_fps if video_fps > 0 else 0
147
+
148
+ # Calculate frames to sample
149
+ target_frames = min(max_frames, int(duration * fps), total_frames)
150
+ if target_frames <= 0:
151
+ target_frames = min(max_frames, total_frames)
152
+
153
+ frame_indices = np.linspace(0, total_frames - 1, target_frames, dtype=int)
154
+
155
+ frames = []
156
+ for idx in frame_indices:
157
+ cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
158
+ ret, frame = cap.read()
159
+ if ret:
160
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
161
+ pil_image = Image.fromarray(frame_rgb)
162
+ frames.append(pil_image)
163
+
164
+ cap.release()
165
+
166
+ metadata = {
167
+ "duration": duration,
168
+ "total_frames": total_frames,
169
+ "sampled_frames": len(frames),
170
+ "video_fps": video_fps
171
+ }
172
+
173
+ return frames, metadata
174
+
175
+ finally:
176
+ if os.path.exists(video_path):
177
+ os.unlink(video_path)
178
+
179
+ def _masks_to_serializable(self, masks: torch.Tensor) -> List[List[List[int]]]:
180
+ """Convert binary masks to RLE or simplified format for JSON serialization."""
181
+ # For efficiency, we'll return bounding box info and optionally compressed masks
182
+ # Full masks can be very large - return as base64 encoded numpy if needed
183
+ masks_np = masks.cpu().numpy().astype(np.uint8)
184
+
185
+ # Return as list of base64-encoded masks
186
+ encoded_masks = []
187
+ for mask in masks_np:
188
+ # Encode each mask as PNG for compression
189
+ from PIL import Image
190
+ img = Image.fromarray(mask * 255)
191
+ buffer = io.BytesIO()
192
+ img.save(buffer, format='PNG')
193
+ encoded = base64.b64encode(buffer.getvalue()).decode('utf-8')
194
+ encoded_masks.append(encoded)
195
+
196
+ return encoded_masks
197
+
198
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
199
+ """
200
+ Process image or video with SAM 3 for text-prompted segmentation.
201
+
202
+ INPUT FORMATS:
203
+
204
+ 1. Single image with text prompt (find all instances):
205
+ {
206
+ "inputs": <image_url_or_base64>,
207
+ "parameters": {
208
+ "prompt": "Save button",
209
+ "threshold": 0.5,
210
+ "mask_threshold": 0.5,
211
+ "return_masks": true
212
+ }
213
+ }
214
+
215
+ 2. Single image with multiple text prompts:
216
+ {
217
+ "inputs": <image_url_or_base64>,
218
+ "parameters": {
219
+ "prompts": ["button", "text field", "dropdown"],
220
+ "threshold": 0.5
221
+ }
222
+ }
223
+
224
+ 3. Single image with box prompts (positive/negative):
225
+ {
226
+ "inputs": <image_url_or_base64>,
227
+ "parameters": {
228
+ "prompt": "handle",
229
+ "boxes": [[40, 183, 318, 204]],
230
+ "box_labels": [0], // 0=negative, 1=positive
231
+ "threshold": 0.5
232
+ }
233
+ }
234
+
235
+ 4. Video with text prompt (track all instances):
236
+ {
237
+ "inputs": <video_url_or_base64>,
238
+ "parameters": {
239
+ "mode": "video",
240
+ "prompt": "Submit button",
241
+ "max_frames": 100,
242
+ "fps": 2.0
243
+ }
244
+ }
245
+
246
+ 5. Batch images:
247
+ {
248
+ "inputs": [<image1>, <image2>, ...],
249
+ "parameters": {
250
+ "prompts": ["ear", "dial"], // One per image
251
+ "threshold": 0.5
252
+ }
253
+ }
254
+
255
+ 6. ProofPath UI element detection:
256
+ {
257
+ "inputs": <screenshot_base64>,
258
+ "parameters": {
259
+ "mode": "ui_elements",
260
+ "elements": ["Save button", "Cancel button", "text input"],
261
+ "threshold": 0.5
262
+ }
263
+ }
264
+
265
+ OUTPUT FORMAT:
266
+ {
267
+ "results": [
268
+ {
269
+ "prompt": "Save button",
270
+ "instances": [
271
+ {
272
+ "box": [x1, y1, x2, y2],
273
+ "score": 0.95,
274
+ "mask": "<base64_png>" // if return_masks=true
275
+ }
276
+ ]
277
+ }
278
+ ],
279
+ "image_size": {"width": 1920, "height": 1080}
280
+ }
281
+ """
282
+ inputs = data.get("inputs")
283
+ params = data.get("parameters", {})
284
+
285
+ if inputs is None:
286
+ raise ValueError("No inputs provided")
287
+
288
+ mode = params.get("mode", "image")
289
+
290
+ if mode == "video":
291
+ return self._process_video(inputs, params)
292
+ elif mode == "ui_elements":
293
+ return self._process_ui_elements(inputs, params)
294
+ elif isinstance(inputs, list):
295
+ return self._process_batch(inputs, params)
296
+ else:
297
+ return self._process_single_image(inputs, params)
298
+
299
+ def _process_single_image(self, image_data: Any, params: Dict) -> Dict[str, Any]:
300
+ """Process a single image with text and/or box prompts."""
301
+ image = self._load_image(image_data)
302
+
303
+ threshold = params.get("threshold", 0.5)
304
+ mask_threshold = params.get("mask_threshold", 0.5)
305
+ return_masks = params.get("return_masks", True)
306
+
307
+ # Get prompts
308
+ prompt = params.get("prompt")
309
+ prompts = params.get("prompts", [prompt] if prompt else [])
310
+
311
+ if not prompts:
312
+ raise ValueError("No text prompt(s) provided")
313
+
314
+ # Get optional box prompts
315
+ boxes = params.get("boxes")
316
+ box_labels = params.get("box_labels")
317
+
318
+ results = []
319
+
320
+ for text_prompt in prompts:
321
+ # Prepare inputs
322
+ if boxes is not None:
323
+ input_boxes = [boxes]
324
+ input_boxes_labels = [box_labels] if box_labels else [[1] * len(boxes)]
325
+
326
+ processor_inputs = self.processor(
327
+ images=image,
328
+ text=text_prompt,
329
+ input_boxes=input_boxes,
330
+ input_boxes_labels=input_boxes_labels,
331
+ return_tensors="pt"
332
+ ).to(self.device)
333
+ else:
334
+ processor_inputs = self.processor(
335
+ images=image,
336
+ text=text_prompt,
337
+ return_tensors="pt"
338
+ ).to(self.device)
339
+
340
+ # Run inference
341
+ with torch.no_grad():
342
+ outputs = self.model(**processor_inputs)
343
+
344
+ # Post-process
345
+ post_results = self.processor.post_process_instance_segmentation(
346
+ outputs,
347
+ threshold=threshold,
348
+ mask_threshold=mask_threshold,
349
+ target_sizes=processor_inputs.get("original_sizes").tolist()
350
+ )[0]
351
+
352
+ instances = []
353
+ for i in range(len(post_results.get("boxes", []))):
354
+ instance = {
355
+ "box": post_results["boxes"][i].tolist(),
356
+ "score": float(post_results["scores"][i])
357
+ }
358
+
359
+ if return_masks and "masks" in post_results:
360
+ # Encode mask as base64 PNG
361
+ mask = post_results["masks"][i].cpu().numpy().astype(np.uint8) * 255
362
+ from PIL import Image as PILImage
363
+ mask_img = PILImage.fromarray(mask)
364
+ buffer = io.BytesIO()
365
+ mask_img.save(buffer, format='PNG')
366
+ instance["mask"] = base64.b64encode(buffer.getvalue()).decode('utf-8')
367
+
368
+ instances.append(instance)
369
+
370
+ results.append({
371
+ "prompt": text_prompt,
372
+ "instances": instances,
373
+ "count": len(instances)
374
+ })
375
+
376
+ return {
377
+ "results": results,
378
+ "image_size": {"width": image.width, "height": image.height}
379
+ }
380
+
381
+ def _process_batch(self, images_data: List, params: Dict) -> Dict[str, Any]:
382
+ """Process multiple images with text prompts."""
383
+ images = [self._load_image(img) for img in images_data]
384
+
385
+ prompts = params.get("prompts", [])
386
+ prompt = params.get("prompt")
387
+
388
+ # Handle single prompt for all images
389
+ if prompt and not prompts:
390
+ prompts = [prompt] * len(images)
391
+
392
+ if len(prompts) != len(images):
393
+ raise ValueError(f"Number of prompts ({len(prompts)}) must match number of images ({len(images)})")
394
+
395
+ threshold = params.get("threshold", 0.5)
396
+ mask_threshold = params.get("mask_threshold", 0.5)
397
+ return_masks = params.get("return_masks", False) # Default false for batch
398
+
399
+ # Process batch
400
+ processor_inputs = self.processor(
401
+ images=images,
402
+ text=prompts,
403
+ return_tensors="pt"
404
+ ).to(self.device)
405
+
406
+ with torch.no_grad():
407
+ outputs = self.model(**processor_inputs)
408
+
409
+ # Post-process all results
410
+ all_results = self.processor.post_process_instance_segmentation(
411
+ outputs,
412
+ threshold=threshold,
413
+ mask_threshold=mask_threshold,
414
+ target_sizes=processor_inputs.get("original_sizes").tolist()
415
+ )
416
+
417
+ results = []
418
+ for idx, (post_results, text_prompt, image) in enumerate(zip(all_results, prompts, images)):
419
+ instances = []
420
+ for i in range(len(post_results.get("boxes", []))):
421
+ instance = {
422
+ "box": post_results["boxes"][i].tolist(),
423
+ "score": float(post_results["scores"][i])
424
+ }
425
+
426
+ if return_masks and "masks" in post_results:
427
+ mask = post_results["masks"][i].cpu().numpy().astype(np.uint8) * 255
428
+ from PIL import Image as PILImage
429
+ mask_img = PILImage.fromarray(mask)
430
+ buffer = io.BytesIO()
431
+ mask_img.save(buffer, format='PNG')
432
+ instance["mask"] = base64.b64encode(buffer.getvalue()).decode('utf-8')
433
+
434
+ instances.append(instance)
435
+
436
+ results.append({
437
+ "image_index": idx,
438
+ "prompt": text_prompt,
439
+ "instances": instances,
440
+ "count": len(instances),
441
+ "image_size": {"width": image.width, "height": image.height}
442
+ })
443
+
444
+ return {"results": results}
445
+
446
+ def _process_ui_elements(self, image_data: Any, params: Dict) -> Dict[str, Any]:
447
+ """
448
+ ProofPath-specific mode: Detect multiple UI element types in a screenshot.
449
+ Returns structured data for each element type with bounding boxes.
450
+ """
451
+ image = self._load_image(image_data)
452
+
453
+ elements = params.get("elements", [])
454
+ if not elements:
455
+ # Default UI elements to look for
456
+ elements = ["button", "text input", "dropdown", "checkbox", "link"]
457
+
458
+ threshold = params.get("threshold", 0.5)
459
+ mask_threshold = params.get("mask_threshold", 0.5)
460
+
461
+ all_detections = {}
462
+
463
+ for element_type in elements:
464
+ processor_inputs = self.processor(
465
+ images=image,
466
+ text=element_type,
467
+ return_tensors="pt"
468
+ ).to(self.device)
469
+
470
+ with torch.no_grad():
471
+ outputs = self.model(**processor_inputs)
472
+
473
+ post_results = self.processor.post_process_instance_segmentation(
474
+ outputs,
475
+ threshold=threshold,
476
+ mask_threshold=mask_threshold,
477
+ target_sizes=processor_inputs.get("original_sizes").tolist()
478
+ )[0]
479
+
480
+ detections = []
481
+ for i in range(len(post_results.get("boxes", []))):
482
+ box = post_results["boxes"][i].tolist()
483
+ detections.append({
484
+ "box": box,
485
+ "score": float(post_results["scores"][i]),
486
+ "center": [
487
+ (box[0] + box[2]) / 2,
488
+ (box[1] + box[3]) / 2
489
+ ]
490
+ })
491
+
492
+ all_detections[element_type] = {
493
+ "count": len(detections),
494
+ "instances": detections
495
+ }
496
+
497
+ return {
498
+ "ui_elements": all_detections,
499
+ "image_size": {"width": image.width, "height": image.height},
500
+ "total_elements": sum(d["count"] for d in all_detections.values())
501
+ }
502
+
503
+ def _process_video(self, video_data: Any, params: Dict) -> Dict[str, Any]:
504
+ """
505
+ Process video with SAM3 Video for text-prompted tracking.
506
+ Tracks all instances of the prompted concept across frames.
507
+ """
508
+ video_model, video_processor = self._get_video_model()
509
+
510
+ prompt = params.get("prompt")
511
+ if not prompt:
512
+ raise ValueError("Text prompt required for video mode")
513
+
514
+ max_frames = params.get("max_frames", 100)
515
+ fps = params.get("fps", 2.0)
516
+
517
+ # Load video frames
518
+ frames, video_metadata = self._load_video_frames(video_data, max_frames, fps)
519
+
520
+ if not frames:
521
+ raise ValueError("No frames could be extracted from video")
522
+
523
+ # Initialize video session
524
+ inference_session = video_processor.init_video_session(
525
+ video=frames,
526
+ inference_device=self.device,
527
+ processing_device="cpu",
528
+ video_storage_device="cpu",
529
+ dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
530
+ )
531
+
532
+ # Add text prompt
533
+ inference_session = video_processor.add_text_prompt(
534
+ inference_session=inference_session,
535
+ text=prompt,
536
+ )
537
+
538
+ # Process all frames
539
+ outputs_per_frame = {}
540
+ for model_outputs in video_model.propagate_in_video_iterator(
541
+ inference_session=inference_session,
542
+ max_frame_num_to_track=max_frames
543
+ ):
544
+ processed = video_processor.postprocess_outputs(inference_session, model_outputs)
545
+
546
+ frame_data = {
547
+ "frame_idx": model_outputs.frame_idx,
548
+ "object_ids": processed["object_ids"].tolist() if hasattr(processed["object_ids"], "tolist") else processed["object_ids"],
549
+ "scores": processed["scores"].tolist() if hasattr(processed["scores"], "tolist") else processed["scores"],
550
+ "boxes": processed["boxes"].tolist() if hasattr(processed["boxes"], "tolist") else processed["boxes"],
551
+ }
552
+
553
+ outputs_per_frame[model_outputs.frame_idx] = frame_data
554
+
555
+ # Compile tracking results
556
+ # Group by object_id to show trajectory
557
+ object_tracks = {}
558
+ for frame_idx, frame_data in outputs_per_frame.items():
559
+ for i, obj_id in enumerate(frame_data["object_ids"]):
560
+ obj_id_str = str(obj_id)
561
+ if obj_id_str not in object_tracks:
562
+ object_tracks[obj_id_str] = {
563
+ "object_id": obj_id,
564
+ "frames": []
565
+ }
566
+ object_tracks[obj_id_str]["frames"].append({
567
+ "frame_idx": frame_idx,
568
+ "box": frame_data["boxes"][i] if i < len(frame_data["boxes"]) else None,
569
+ "score": frame_data["scores"][i] if i < len(frame_data["scores"]) else None
570
+ })
571
+
572
+ return {
573
+ "prompt": prompt,
574
+ "video_metadata": video_metadata,
575
+ "frames_processed": len(outputs_per_frame),
576
+ "objects_tracked": len(object_tracks),
577
+ "tracks": list(object_tracks.values()),
578
+ "per_frame_detections": outputs_per_frame
579
+ }
580
+
581
+
582
+ # For testing locally
583
+ if __name__ == "__main__":
584
+ handler = EndpointHandler()
585
+
586
+ # Test with a sample image URL
587
+ test_data = {
588
+ "inputs": "http://images.cocodataset.org/val2017/000000077595.jpg",
589
+ "parameters": {
590
+ "prompt": "ear",
591
+ "threshold": 0.5,
592
+ "return_masks": False
593
+ }
594
+ }
595
+
596
+ result = handler(test_data)
597
+ print(f"Found {result['results'][0]['count']} instances of '{result['results'][0]['prompt']}'")
598
+ for inst in result['results'][0]['instances']:
599
+ print(f" Box: {inst['box']}, Score: {inst['score']:.3f}")
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # SAM 3 Inference Endpoint Requirements
2
+ transformers>=4.48.0
3
+ torch>=2.7.0
4
+ accelerate>=0.25.0
5
+ Pillow>=9.0.0
6
+ requests>=2.28.0
7
+ numpy>=1.24.0,<2.0.0
8
+ opencv-python-headless>=4.8.0