peterproofpath committed on
Commit
b7720c4
·
verified ·
1 Parent(s): 1169aff

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +134 -289
handler.py CHANGED
@@ -2,6 +2,9 @@
2
  SAM 3 Custom Inference Handler for Hugging Face Inference Endpoints
3
  Model: facebook/sam3
4
 
 
 
 
5
  For ProofPath video assessment - text-prompted segmentation to find UI elements.
6
  Supports text prompts like "Save button", "dropdown menu", "text input field".
7
 
@@ -28,59 +31,31 @@ class EndpointHandler:
28
  def __init__(self, path: str = ""):
29
  """
30
  Initialize SAM 3 model for text-prompted segmentation.
 
31
 
32
  Args:
33
  path: Path to the model directory (ignored - we load from HF hub)
34
  """
35
- model_id = "facebook/sam3"
36
-
37
- # Get HF token for gated model access
38
- hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
39
-
40
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41
 
42
- # Import SAM3 components from transformers
43
- from transformers import Sam3Processor, Sam3Model
44
-
45
- self.processor = Sam3Processor.from_pretrained(
46
- model_id,
47
- token=hf_token,
48
- )
49
-
50
- self.model = Sam3Model.from_pretrained(
51
- model_id,
52
- torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
53
- token=hf_token,
54
- ).to(self.device)
55
 
56
- self.model.eval()
 
 
 
57
 
58
- # Also load video model for video segmentation
59
- self._video_model = None
60
- self._video_processor = None
61
 
62
- def _get_video_model(self):
63
- """Lazy load video model only when needed."""
64
- if self._video_model is None:
65
- from transformers import Sam3VideoModel, Sam3VideoProcessor
66
-
67
- model_id = "facebook/sam3"
68
- hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
69
-
70
- self._video_processor = Sam3VideoProcessor.from_pretrained(
71
- model_id,
72
- token=hf_token,
73
- )
74
-
75
- self._video_model = Sam3VideoModel.from_pretrained(
76
- model_id,
77
- torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
78
- token=hf_token,
79
- ).to(self.device)
80
-
81
- self._video_model.eval()
82
-
83
- return self._video_model, self._video_processor
84
 
85
  def _load_image(self, image_data: Any):
86
  """Load image from various formats."""
@@ -106,7 +81,7 @@ class EndpointHandler:
106
  else:
107
  raise ValueError(f"Unsupported image input type: {type(image_data)}")
108
 
109
- def _load_video_frames(self, video_data: Any, max_frames: int = 100, fps: float = 2.0) -> List:
110
  """Load video frames from various formats."""
111
  import cv2
112
  from PIL import Image
@@ -170,30 +145,12 @@ class EndpointHandler:
170
  "video_fps": video_fps
171
  }
172
 
173
- return frames, metadata
174
 
175
- finally:
176
  if os.path.exists(video_path):
177
  os.unlink(video_path)
178
-
179
- def _masks_to_serializable(self, masks: torch.Tensor) -> List[List[List[int]]]:
180
- """Convert binary masks to RLE or simplified format for JSON serialization."""
181
- # For efficiency, we'll return bounding box info and optionally compressed masks
182
- # Full masks can be very large - return as base64 encoded numpy if needed
183
- masks_np = masks.cpu().numpy().astype(np.uint8)
184
-
185
- # Return as list of base64-encoded masks
186
- encoded_masks = []
187
- for mask in masks_np:
188
- # Encode each mask as PNG for compression
189
- from PIL import Image
190
- img = Image.fromarray(mask * 255)
191
- buffer = io.BytesIO()
192
- img.save(buffer, format='PNG')
193
- encoded = base64.b64encode(buffer.getvalue()).decode('utf-8')
194
- encoded_masks.append(encoded)
195
-
196
- return encoded_masks
197
 
198
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
199
  """
@@ -206,8 +163,6 @@ class EndpointHandler:
206
  "inputs": <image_url_or_base64>,
207
  "parameters": {
208
  "prompt": "Save button",
209
- "threshold": 0.5,
210
- "mask_threshold": 0.5,
211
  "return_masks": true
212
  }
213
  }
@@ -216,49 +171,26 @@ class EndpointHandler:
216
  {
217
  "inputs": <image_url_or_base64>,
218
  "parameters": {
219
- "prompts": ["button", "text field", "dropdown"],
220
- "threshold": 0.5
221
- }
222
- }
223
-
224
- 3. Single image with box prompts (positive/negative):
225
- {
226
- "inputs": <image_url_or_base64>,
227
- "parameters": {
228
- "prompt": "handle",
229
- "boxes": [[40, 183, 318, 204]],
230
- "box_labels": [0], // 0=negative, 1=positive
231
- "threshold": 0.5
232
  }
233
  }
234
 
235
- 4. Video with text prompt (track all instances):
236
  {
237
  "inputs": <video_url_or_base64>,
238
  "parameters": {
239
  "mode": "video",
240
  "prompt": "Submit button",
241
- "max_frames": 100,
242
- "fps": 2.0
243
- }
244
- }
245
-
246
- 5. Batch images:
247
- {
248
- "inputs": [<image1>, <image2>, ...],
249
- "parameters": {
250
- "prompts": ["ear", "dial"], // One per image
251
- "threshold": 0.5
252
  }
253
  }
254
 
255
- 6. ProofPath UI element detection:
256
  {
257
  "inputs": <screenshot_base64>,
258
  "parameters": {
259
  "mode": "ui_elements",
260
- "elements": ["Save button", "Cancel button", "text input"],
261
- "threshold": 0.5
262
  }
263
  }
264
 
@@ -291,17 +223,13 @@ class EndpointHandler:
291
  return self._process_video(inputs, params)
292
  elif mode == "ui_elements":
293
  return self._process_ui_elements(inputs, params)
294
- elif isinstance(inputs, list):
295
- return self._process_batch(inputs, params)
296
  else:
297
  return self._process_single_image(inputs, params)
298
 
299
  def _process_single_image(self, image_data: Any, params: Dict) -> Dict[str, Any]:
300
- """Process a single image with text and/or box prompts."""
301
  image = self._load_image(image_data)
302
 
303
- threshold = params.get("threshold", 0.5)
304
- mask_threshold = params.get("mask_threshold", 0.5)
305
  return_masks = params.get("return_masks", True)
306
 
307
  # Get prompts
@@ -311,56 +239,44 @@ class EndpointHandler:
311
  if not prompts:
312
  raise ValueError("No text prompt(s) provided")
313
 
314
- # Get optional box prompts
315
- boxes = params.get("boxes")
316
- box_labels = params.get("box_labels")
317
 
318
  results = []
319
 
320
  for text_prompt in prompts:
321
- # Prepare inputs
322
- if boxes is not None:
323
- input_boxes = [boxes]
324
- input_boxes_labels = [box_labels] if box_labels else [[1] * len(boxes)]
325
-
326
- processor_inputs = self.processor(
327
- images=image,
328
- text=text_prompt,
329
- input_boxes=input_boxes,
330
- input_boxes_labels=input_boxes_labels,
331
- return_tensors="pt"
332
- ).to(self.device)
333
- else:
334
- processor_inputs = self.processor(
335
- images=image,
336
- text=text_prompt,
337
- return_tensors="pt"
338
- ).to(self.device)
339
-
340
- # Run inference
341
- with torch.no_grad():
342
- outputs = self.model(**processor_inputs)
343
 
344
- # Post-process
345
- post_results = self.processor.post_process_instance_segmentation(
346
- outputs,
347
- threshold=threshold,
348
- mask_threshold=mask_threshold,
349
- target_sizes=processor_inputs.get("original_sizes").tolist()
350
- )[0]
351
 
352
  instances = []
353
- for i in range(len(post_results.get("boxes", []))):
 
 
 
 
 
 
 
354
  instance = {
355
- "box": post_results["boxes"][i].tolist(),
356
- "score": float(post_results["scores"][i])
357
  }
358
 
359
- if return_masks and "masks" in post_results:
360
  # Encode mask as base64 PNG
361
- mask = post_results["masks"][i].cpu().numpy().astype(np.uint8) * 255
 
 
 
362
  from PIL import Image as PILImage
363
- mask_img = PILImage.fromarray(mask)
364
  buffer = io.BytesIO()
365
  mask_img.save(buffer, format='PNG')
366
  instance["mask"] = base64.b64encode(buffer.getvalue()).decode('utf-8')
@@ -378,71 +294,6 @@ class EndpointHandler:
378
  "image_size": {"width": image.width, "height": image.height}
379
  }
380
 
381
- def _process_batch(self, images_data: List, params: Dict) -> Dict[str, Any]:
382
- """Process multiple images with text prompts."""
383
- images = [self._load_image(img) for img in images_data]
384
-
385
- prompts = params.get("prompts", [])
386
- prompt = params.get("prompt")
387
-
388
- # Handle single prompt for all images
389
- if prompt and not prompts:
390
- prompts = [prompt] * len(images)
391
-
392
- if len(prompts) != len(images):
393
- raise ValueError(f"Number of prompts ({len(prompts)}) must match number of images ({len(images)})")
394
-
395
- threshold = params.get("threshold", 0.5)
396
- mask_threshold = params.get("mask_threshold", 0.5)
397
- return_masks = params.get("return_masks", False) # Default false for batch
398
-
399
- # Process batch
400
- processor_inputs = self.processor(
401
- images=images,
402
- text=prompts,
403
- return_tensors="pt"
404
- ).to(self.device)
405
-
406
- with torch.no_grad():
407
- outputs = self.model(**processor_inputs)
408
-
409
- # Post-process all results
410
- all_results = self.processor.post_process_instance_segmentation(
411
- outputs,
412
- threshold=threshold,
413
- mask_threshold=mask_threshold,
414
- target_sizes=processor_inputs.get("original_sizes").tolist()
415
- )
416
-
417
- results = []
418
- for idx, (post_results, text_prompt, image) in enumerate(zip(all_results, prompts, images)):
419
- instances = []
420
- for i in range(len(post_results.get("boxes", []))):
421
- instance = {
422
- "box": post_results["boxes"][i].tolist(),
423
- "score": float(post_results["scores"][i])
424
- }
425
-
426
- if return_masks and "masks" in post_results:
427
- mask = post_results["masks"][i].cpu().numpy().astype(np.uint8) * 255
428
- from PIL import Image as PILImage
429
- mask_img = PILImage.fromarray(mask)
430
- buffer = io.BytesIO()
431
- mask_img.save(buffer, format='PNG')
432
- instance["mask"] = base64.b64encode(buffer.getvalue()).decode('utf-8')
433
-
434
- instances.append(instance)
435
-
436
- results.append({
437
- "image_index": idx,
438
- "prompt": text_prompt,
439
- "instances": instances,
440
- "count": len(instances),
441
- "image_size": {"width": image.width, "height": image.height}
442
- })
443
-
444
- return {"results": results}
445
-
446
  def _process_ui_elements(self, image_data: Any, params: Dict) -> Dict[str, Any]:
447
  """
448
  ProofPath-specific mode: Detect multiple UI element types in a screenshot.
@@ -455,38 +306,35 @@ class EndpointHandler:
455
  # Default UI elements to look for
456
  elements = ["button", "text input", "dropdown", "checkbox", "link"]
457
 
458
- threshold = params.get("threshold", 0.5)
459
- mask_threshold = params.get("mask_threshold", 0.5)
460
 
461
  all_detections = {}
462
 
463
  for element_type in elements:
464
- processor_inputs = self.processor(
465
- images=image,
466
- text=element_type,
467
- return_tensors="pt"
468
- ).to(self.device)
469
 
470
- with torch.no_grad():
471
- outputs = self.model(**processor_inputs)
472
 
473
- post_results = self.processor.post_process_instance_segmentation(
474
- outputs,
475
- threshold=threshold,
476
- mask_threshold=mask_threshold,
477
- target_sizes=processor_inputs.get("original_sizes").tolist()
478
- )[0]
479
 
480
  detections = []
481
- for i in range(len(post_results.get("boxes", []))):
482
- box = post_results["boxes"][i].tolist()
483
  detections.append({
484
  "box": box,
485
- "score": float(post_results["scores"][i]),
486
  "center": [
487
  (box[0] + box[2]) / 2,
488
  (box[1] + box[3]) / 2
489
- ]
490
  })
491
 
492
  all_detections[element_type] = {
@@ -503,80 +351,78 @@ class EndpointHandler:
503
  def _process_video(self, video_data: Any, params: Dict) -> Dict[str, Any]:
504
  """
505
  Process video with SAM3 Video for text-prompted tracking.
506
- Tracks all instances of the prompted concept across frames.
507
  """
508
- video_model, video_processor = self._get_video_model()
509
 
510
  prompt = params.get("prompt")
511
  if not prompt:
512
  raise ValueError("Text prompt required for video mode")
513
 
514
  max_frames = params.get("max_frames", 100)
515
- fps = params.get("fps", 2.0)
516
-
517
- # Load video frames
518
- frames, video_metadata = self._load_video_frames(video_data, max_frames, fps)
519
-
520
- if not frames:
521
- raise ValueError("No frames could be extracted from video")
522
-
523
- # Initialize video session
524
- inference_session = video_processor.init_video_session(
525
- video=frames,
526
- inference_device=self.device,
527
- processing_device="cpu",
528
- video_storage_device="cpu",
529
- dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
530
- )
531
-
532
- # Add text prompt
533
- inference_session = video_processor.add_text_prompt(
534
- inference_session=inference_session,
535
- text=prompt,
536
- )
537
-
538
- # Process all frames
539
- outputs_per_frame = {}
540
- for model_outputs in video_model.propagate_in_video_iterator(
541
- inference_session=inference_session,
542
- max_frame_num_to_track=max_frames
543
- ):
544
- processed = video_processor.postprocess_outputs(inference_session, model_outputs)
 
 
 
 
 
 
 
 
 
 
 
545
 
546
- frame_data = {
547
- "frame_idx": model_outputs.frame_idx,
548
- "object_ids": processed["object_ids"].tolist() if hasattr(processed["object_ids"], "tolist") else processed["object_ids"],
549
- "scores": processed["scores"].tolist() if hasattr(processed["scores"], "tolist") else processed["scores"],
550
- "boxes": processed["boxes"].tolist() if hasattr(processed["boxes"], "tolist") else processed["boxes"],
 
 
 
 
 
 
 
 
 
 
551
  }
552
 
553
- outputs_per_frame[model_outputs.frame_idx] = frame_data
554
-
555
- # Compile tracking results
556
- # Group by object_id to show trajectory
557
- object_tracks = {}
558
- for frame_idx, frame_data in outputs_per_frame.items():
559
- for i, obj_id in enumerate(frame_data["object_ids"]):
560
- obj_id_str = str(obj_id)
561
- if obj_id_str not in object_tracks:
562
- object_tracks[obj_id_str] = {
563
- "object_id": obj_id,
564
- "frames": []
565
- }
566
- object_tracks[obj_id_str]["frames"].append({
567
- "frame_idx": frame_idx,
568
- "box": frame_data["boxes"][i] if i < len(frame_data["boxes"]) else None,
569
- "score": frame_data["scores"][i] if i < len(frame_data["scores"]) else None
570
- })
571
-
572
- return {
573
- "prompt": prompt,
574
- "video_metadata": video_metadata,
575
- "frames_processed": len(outputs_per_frame),
576
- "objects_tracked": len(object_tracks),
577
- "tracks": list(object_tracks.values()),
578
- "per_frame_detections": outputs_per_frame
579
- }
580
 
581
 
582
  # For testing locally
@@ -588,7 +434,6 @@ if __name__ == "__main__":
588
  "inputs": "http://images.cocodataset.org/val2017/000000077595.jpg",
589
  "parameters": {
590
  "prompt": "ear",
591
- "threshold": 0.5,
592
  "return_masks": False
593
  }
594
  }
@@ -596,4 +441,4 @@ if __name__ == "__main__":
596
  result = handler(test_data)
597
  print(f"Found {result['results'][0]['count']} instances of '{result['results'][0]['prompt']}'")
598
  for inst in result['results'][0]['instances']:
599
- print(f" Box: {inst['box']}, Score: {inst['score']:.3f}")
 
2
  SAM 3 Custom Inference Handler for Hugging Face Inference Endpoints
3
  Model: facebook/sam3
4
 
5
+ Using the official sam3 package from Meta (pip install sam3)
6
+ NOT the transformers integration.
7
+
8
  For ProofPath video assessment - text-prompted segmentation to find UI elements.
9
  Supports text prompts like "Save button", "dropdown menu", "text input field".
10
 
 
31
  def __init__(self, path: str = ""):
32
  """
33
  Initialize SAM 3 model for text-prompted segmentation.
34
+ Uses the official sam3 package from Meta.
35
 
36
  Args:
37
  path: Path to the model directory (ignored - we load from HF hub)
38
  """
 
 
 
 
 
39
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
 
41
+ # Import from official sam3 package
42
+ from sam3.model_builder import build_sam3_image_model
43
+ from sam3.model.sam3_image_processor import Sam3Processor
 
 
 
 
 
 
 
 
 
 
44
 
45
+ # Build model - this downloads from HuggingFace automatically
46
+ # Requires HF_TOKEN for gated model access
47
+ self.model = build_sam3_image_model()
48
+ self.processor = Sam3Processor(self.model)
49
 
50
+ # Video model will be loaded lazily
51
+ self._video_predictor = None
 
52
 
53
+ def _get_video_predictor(self):
54
+ """Lazy load video predictor only when needed."""
55
+ if self._video_predictor is None:
56
+ from sam3.model_builder import build_sam3_video_predictor
57
+ self._video_predictor = build_sam3_video_predictor()
58
+ return self._video_predictor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  def _load_image(self, image_data: Any):
61
  """Load image from various formats."""
 
81
  else:
82
  raise ValueError(f"Unsupported image input type: {type(image_data)}")
83
 
84
+ def _load_video_frames(self, video_data: Any, max_frames: int = 100, fps: float = 2.0) -> tuple:
85
  """Load video frames from various formats."""
86
  import cv2
87
  from PIL import Image
 
145
  "video_fps": video_fps
146
  }
147
 
148
+ return video_path, metadata
149
 
150
+ except Exception as e:
151
  if os.path.exists(video_path):
152
  os.unlink(video_path)
153
+ raise e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
156
  """
 
163
  "inputs": <image_url_or_base64>,
164
  "parameters": {
165
  "prompt": "Save button",
 
 
166
  "return_masks": true
167
  }
168
  }
 
171
  {
172
  "inputs": <image_url_or_base64>,
173
  "parameters": {
174
+ "prompts": ["button", "text field", "dropdown"]
 
 
 
 
 
 
 
 
 
 
 
 
175
  }
176
  }
177
 
178
+ 3. Video with text prompt (track all instances):
179
  {
180
  "inputs": <video_url_or_base64>,
181
  "parameters": {
182
  "mode": "video",
183
  "prompt": "Submit button",
184
+ "max_frames": 100
 
 
 
 
 
 
 
 
 
 
185
  }
186
  }
187
 
188
+ 4. ProofPath UI element detection:
189
  {
190
  "inputs": <screenshot_base64>,
191
  "parameters": {
192
  "mode": "ui_elements",
193
+ "elements": ["Save button", "Cancel button", "text input"]
 
194
  }
195
  }
196
 
 
223
  return self._process_video(inputs, params)
224
  elif mode == "ui_elements":
225
  return self._process_ui_elements(inputs, params)
 
 
226
  else:
227
  return self._process_single_image(inputs, params)
228
 
229
  def _process_single_image(self, image_data: Any, params: Dict) -> Dict[str, Any]:
230
+ """Process a single image with text prompts using official sam3 API."""
231
  image = self._load_image(image_data)
232
 
 
 
233
  return_masks = params.get("return_masks", True)
234
 
235
  # Get prompts
 
239
  if not prompts:
240
  raise ValueError("No text prompt(s) provided")
241
 
242
+ # Set the image in processor
243
+ inference_state = self.processor.set_image(image)
 
244
 
245
  results = []
246
 
247
  for text_prompt in prompts:
248
+ # Use official sam3 API
249
+ output = self.processor.set_text_prompt(
250
+ state=inference_state,
251
+ prompt=text_prompt
252
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
 
254
+ masks = output.get("masks", [])
255
+ boxes = output.get("boxes", [])
256
+ scores = output.get("scores", [])
 
 
 
 
257
 
258
  instances = []
259
+
260
+ # Convert tensors to lists
261
+ if hasattr(boxes, 'tolist'):
262
+ boxes = boxes.tolist()
263
+ if hasattr(scores, 'tolist'):
264
+ scores = scores.tolist()
265
+
266
+ for i in range(len(boxes)):
267
  instance = {
268
+ "box": boxes[i] if i < len(boxes) else None,
269
+ "score": float(scores[i]) if i < len(scores) else 0.0
270
  }
271
 
272
+ if return_masks and masks is not None and i < len(masks):
273
  # Encode mask as base64 PNG
274
+ mask = masks[i]
275
+ if hasattr(mask, 'cpu'):
276
+ mask = mask.cpu().numpy()
277
+ mask_uint8 = (mask * 255).astype(np.uint8)
278
  from PIL import Image as PILImage
279
+ mask_img = PILImage.fromarray(mask_uint8)
280
  buffer = io.BytesIO()
281
  mask_img.save(buffer, format='PNG')
282
  instance["mask"] = base64.b64encode(buffer.getvalue()).decode('utf-8')
 
294
  "image_size": {"width": image.width, "height": image.height}
295
  }
296
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
297
  def _process_ui_elements(self, image_data: Any, params: Dict) -> Dict[str, Any]:
298
  """
299
  ProofPath-specific mode: Detect multiple UI element types in a screenshot.
 
306
  # Default UI elements to look for
307
  elements = ["button", "text input", "dropdown", "checkbox", "link"]
308
 
309
+ # Set the image once
310
+ inference_state = self.processor.set_image(image)
311
 
312
  all_detections = {}
313
 
314
  for element_type in elements:
315
+ output = self.processor.set_text_prompt(
316
+ state=inference_state,
317
+ prompt=element_type
318
+ )
 
319
 
320
+ boxes = output.get("boxes", [])
321
+ scores = output.get("scores", [])
322
 
323
+ if hasattr(boxes, 'tolist'):
324
+ boxes = boxes.tolist()
325
+ if hasattr(scores, 'tolist'):
326
+ scores = scores.tolist()
 
 
327
 
328
  detections = []
329
+ for i in range(len(boxes)):
330
+ box = boxes[i]
331
  detections.append({
332
  "box": box,
333
+ "score": float(scores[i]) if i < len(scores) else 0.0,
334
  "center": [
335
  (box[0] + box[2]) / 2,
336
  (box[1] + box[3]) / 2
337
+ ] if len(box) >= 4 else None
338
  })
339
 
340
  all_detections[element_type] = {
 
351
  def _process_video(self, video_data: Any, params: Dict) -> Dict[str, Any]:
352
  """
353
  Process video with SAM3 Video for text-prompted tracking.
354
+ Uses the official sam3 video predictor API.
355
  """
356
+ video_predictor = self._get_video_predictor()
357
 
358
  prompt = params.get("prompt")
359
  if not prompt:
360
  raise ValueError("Text prompt required for video mode")
361
 
362
  max_frames = params.get("max_frames", 100)
363
+
364
+ # Load video to temp path
365
+ video_path, video_metadata = self._load_video_frames(video_data, max_frames)
366
+
367
+ try:
368
+ # Start video session
369
+ response = video_predictor.handle_request(
370
+ request=dict(
371
+ type="start_session",
372
+ resource_path=video_path,
373
+ )
374
+ )
375
+ session_id = response.get("session_id")
376
+
377
+ # Add text prompt at frame 0
378
+ response = video_predictor.handle_request(
379
+ request=dict(
380
+ type="add_prompt",
381
+ session_id=session_id,
382
+ frame_index=0,
383
+ text=prompt,
384
+ )
385
+ )
386
+
387
+ output = response.get("outputs", {})
388
+
389
+ # Get tracked objects
390
+ object_ids = output.get("object_ids", [])
391
+ if hasattr(object_ids, 'tolist'):
392
+ object_ids = object_ids.tolist()
393
+
394
+ # Propagate through video
395
+ propagate_response = video_predictor.handle_request(
396
+ request=dict(
397
+ type="propagate",
398
+ session_id=session_id,
399
+ )
400
+ )
401
+
402
+ # Collect results per frame
403
+ per_frame_results = propagate_response.get("per_frame_outputs", {})
404
 
405
+ # Convert to serializable format
406
+ tracks = []
407
+ for obj_id in object_ids:
408
+ track = {
409
+ "object_id": int(obj_id) if hasattr(obj_id, 'item') else obj_id,
410
+ "frames": []
411
+ }
412
+ tracks.append(track)
413
+
414
+ return {
415
+ "prompt": prompt,
416
+ "video_metadata": video_metadata,
417
+ "objects_tracked": len(object_ids),
418
+ "tracks": tracks,
419
+ "session_id": session_id
420
  }
421
 
422
+ finally:
423
+ # Clean up temp file
424
+ if os.path.exists(video_path):
425
+ os.unlink(video_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
426
 
427
 
428
  # For testing locally
 
434
  "inputs": "http://images.cocodataset.org/val2017/000000077595.jpg",
435
  "parameters": {
436
  "prompt": "ear",
 
437
  "return_masks": False
438
  }
439
  }
 
441
  result = handler(test_data)
442
  print(f"Found {result['results'][0]['count']} instances of '{result['results'][0]['prompt']}'")
443
  for inst in result['results'][0]['instances']:
444
+ print(f" Box: {inst['box']}, Score: {inst['score']:.3f}")