Zhen Ye committed on
Commit
5c36daa
·
1 Parent(s): 1eea4fe

Implement Batch Inference & Queue Backpressure Fixes

Browse files
inference.py CHANGED
@@ -22,6 +22,7 @@ from models.detectors.base import ObjectDetector
22
  from models.model_loader import load_detector, load_detector_on_device
23
  from models.segmenters.model_loader import load_segmenter, load_segmenter_on_device
24
  from models.depth_estimators.model_loader import load_depth_estimator, load_depth_estimator_on_device
 
25
  from utils.video import extract_frames, write_video, VideoReader, VideoWriter
26
  from utils.gpt_distance import estimate_distance_gpt
27
  import tempfile
@@ -352,6 +353,149 @@ def infer_frame(
352
  ), detections
353
 
354
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
355
  def infer_segmentation_frame(
356
  frame: np.ndarray,
357
  text_queries: Optional[List[str]] = None,
@@ -557,16 +701,45 @@ def run_inference(
557
  # queue_in: (frame_idx, frame_data)
558
  # queue_out: (frame_idx, processed_frame, detections)
559
  queue_in = Queue(maxsize=16)
560
- queue_out = Queue() # Unbounded, consumed by writer
 
 
 
561
 
562
  # 5. Worker Function
563
  def worker_task(gpu_idx: int):
564
  detector_instance = detectors[gpu_idx]
565
  depth_instance = depth_estimators[gpu_idx] if depth_estimators[gpu_idx] else None
566
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
567
  while True:
568
  item = queue_in.get()
569
  if item is None:
 
570
  queue_in.task_done()
571
  break
572
 
@@ -576,24 +749,22 @@ def run_inference(
576
  logging.debug("Processing frame %d on device %s", frame_idx, "cpu" if num_gpus==0 else f"cuda:{gpu_idx}")
577
 
578
  try:
579
- # Depth strategy: Run every 3 frames
580
- active_depth_name = depth_estimator_name if (frame_idx % 3 == 0) else None
581
- active_depth_instance = depth_instance if (frame_idx % 3 == 0) else None
582
-
583
- processed, frame_dets = infer_frame(
584
- frame_data,
585
- queries,
586
- detector_name=None,
587
- depth_estimator_name=active_depth_name,
588
- depth_scale=depth_scale,
589
- detector_instance=detector_instance,
590
- depth_estimator_instance=active_depth_instance
591
- )
592
- queue_out.put((frame_idx, processed, frame_dets))
593
  except Exception as e:
594
- logging.exception("Error processing frame %d", frame_idx)
595
- # Put placeholders to avoid hanging writer
596
- queue_out.put((frame_idx, frame_data, []))
 
 
 
 
 
 
 
 
 
597
 
598
  queue_in.task_done()
599
 
@@ -626,7 +797,18 @@ def run_inference(
626
  # Fetch from queue
627
  try:
628
  while next_idx not in buffer:
 
 
 
 
 
 
 
 
 
 
629
  item = queue_out.get(timeout=1.0) # wait
 
630
  idx, p_frame, dets = item
631
  buffer[idx] = (p_frame, dets)
632
 
@@ -763,33 +945,68 @@ def run_segmentation(
763
 
764
  # 3. Processing
765
  queue_in = Queue(maxsize=16)
766
- queue_out = Queue()
767
 
768
  def worker_seg(gpu_idx: int):
769
  seg = segmenters[gpu_idx]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
770
  while True:
771
  item = queue_in.get()
772
  if item is None:
 
773
  queue_in.task_done()
774
  break
775
 
776
  idx, frame = item
777
-
778
  if idx % 30 == 0:
779
- logging.debug("Segmenting frame %d (GPU %d)", idx, gpu_idx)
 
 
 
780
 
781
- try:
782
- processed, _ = infer_segmentation_frame(
783
- frame,
784
- text_queries=queries,
785
- segmenter_name=None,
786
- segmenter_instance=seg
787
- )
788
- queue_out.put((idx, processed))
789
- except Exception as e:
790
- logging.error("Segmentation failed frame %d: %s", idx, e)
791
- queue_out.put((idx, frame))
792
-
793
  queue_in.task_done()
794
 
795
  workers = []
@@ -811,6 +1028,10 @@ def run_segmentation(
811
  while next_idx < total_frames:
812
  try:
813
  while next_idx not in buffer:
 
 
 
 
814
  idx, frm = queue_out.get(timeout=1.0)
815
  buffer[idx] = frm
816
 
@@ -1014,49 +1235,82 @@ def run_depth_inference(
1014
  logging.info("Starting Phase 2: Streaming...")
1015
 
1016
  queue_in = Queue(maxsize=16)
1017
- queue_out = Queue()
 
1018
 
1019
  def worker_depth(gpu_idx: int):
1020
  est = estimators[gpu_idx]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1021
  while True:
1022
  item = queue_in.get()
1023
  if item is None:
 
1024
  queue_in.task_done()
1025
  break
1026
 
1027
  idx, frame = item
1028
- try:
1029
- if idx % 30 == 0:
1030
- logging.info("Depth frame %d (GPU %d)", idx, gpu_idx)
1031
-
1032
- with est.lock:
1033
- res = est.predict(frame)
1034
-
1035
- depth_map = res.depth_map
1036
- # Colorize
1037
- colored = colorize_depth_map(depth_map, global_min, global_max)
1038
-
1039
- # Overlay Detections
1040
- # Detections list is [ [det1, det2], [det1, det2] ... ]
1041
- if detections and idx < len(detections):
1042
- frame_dets = detections[idx]
1043
- if frame_dets:
1044
- import cv2
1045
- boxes = []
1046
- labels = []
1047
- for d in frame_dets:
1048
- boxes.append(d.get("bbox"))
1049
- lbl = d.get("label", "obj")
1050
- if d.get("depth_est_m"):
1051
- lbl = f"{lbl} {int(d['depth_est_m'])}m"
1052
- labels.append(lbl)
1053
- colored = draw_boxes(colored, boxes=boxes, label_names=labels)
1054
-
1055
- queue_out.put((idx, colored))
1056
- except Exception as e:
1057
- logging.error("Depth worker failed frame %d: %s", idx, e)
1058
- queue_out.put((idx, frame)) # Fallback to original?
1059
 
 
 
 
 
 
 
1060
  queue_in.task_done()
1061
 
1062
  # Workers
@@ -1081,6 +1335,8 @@ def run_depth_inference(
1081
  while next_idx < total_frames:
1082
  try:
1083
  while next_idx not in buffer:
 
 
1084
  idx, frm = queue_out.get(timeout=1.0)
1085
  buffer[idx] = frm
1086
 
 
22
  from models.model_loader import load_detector, load_detector_on_device
23
  from models.segmenters.model_loader import load_segmenter, load_segmenter_on_device
24
  from models.depth_estimators.model_loader import load_depth_estimator, load_depth_estimator_on_device
25
+ from models.depth_estimators.base import DepthEstimator
26
  from utils.video import extract_frames, write_video, VideoReader, VideoWriter
27
  from utils.gpt_distance import estimate_distance_gpt
28
  import tempfile
 
353
  ), detections
354
 
355
 
356
def infer_batch(
    frames: List[np.ndarray],
    frame_indices: List[int],
    queries: Sequence[str],
    detector_instance: ObjectDetector,
    depth_estimator_instance: Optional[DepthEstimator] = None,
    depth_scale: float = 1.0,
    depth_frame_stride: int = 3,
) -> List[Tuple[int, np.ndarray, List[Dict[str, Any]]]]:
    """Run detection (and optionally depth) over a batch of frames.

    Detection is issued as one batched call when the detector supports it,
    otherwise sequentially, in both cases under the detector's lock.  Depth
    is estimated only for frames whose index is a multiple of
    ``depth_frame_stride`` (same stride the per-frame path used).

    Args:
        frames: BGR frames, parallel to ``frame_indices``.
        frame_indices: absolute frame indices, used for stride selection
            and for ordering the results downstream.
        queries: text queries for the open-vocabulary detector; an empty
            sequence falls back to ``["object"]``.
        detector_instance: detector with ``lock``/``supports_batch``/
            ``predict``/``predict_batch``.
        depth_estimator_instance: optional depth model with the same
            batch interface; ``None`` disables depth entirely.
        depth_scale: scale used when converting raw depth to metres.
        depth_frame_stride: run depth every N-th frame index.

    Returns:
        ``[(frame_idx, annotated_frame, detections), ...]`` in input order.
        On detection failure the whole batch degrades to the raw frames
        with empty detection lists instead of raising.
    """
    text_queries = list(queries) or ["object"]

    # --- Batched detection -------------------------------------------------
    try:
        with detector_instance.lock:
            if detector_instance.supports_batch:
                det_results = detector_instance.predict_batch(frames, text_queries)
            else:
                # Sequential fallback for detectors without batch support.
                det_results = [detector_instance.predict(f, text_queries) for f in frames]
    except Exception:
        logging.exception("Batch detection failed")
        # Degrade gracefully: original frames, no detections.
        return [(idx, f, []) for idx, f in zip(frame_indices, frames)]

    # --- Batched depth (strided) -------------------------------------------
    depth_map_results: Dict[int, Any] = {}  # frame_idx -> depth result
    depth_inputs: List[np.ndarray] = []
    depth_indices: List[int] = []
    for idx, f in zip(frame_indices, frames):
        if idx % depth_frame_stride == 0:
            depth_inputs.append(f)
            depth_indices.append(idx)

    if depth_estimator_instance and depth_inputs:
        try:
            with depth_estimator_instance.lock:
                if depth_estimator_instance.supports_batch:
                    d_results = depth_estimator_instance.predict_batch(depth_inputs)
                else:
                    d_results = [depth_estimator_instance.predict(f) for f in depth_inputs]
            for idx, res in zip(depth_indices, d_results):
                depth_map_results[idx] = res
        except Exception:
            # Depth is best-effort: detections are still emitted without it.
            logging.exception("Batch depth estimation failed")

    # --- Merge, annotate, and emit ------------------------------------------
    outputs: List[Tuple[int, np.ndarray, List[Dict[str, Any]]]] = []
    for idx, frame, det_result in zip(frame_indices, frames, det_results):
        detections = _build_detection_records(
            det_result.boxes, det_result.scores, det_result.labels, text_queries, det_result.label_names
        )

        if idx in depth_map_results:
            try:
                # Depth was precomputed above, so attach metrics from the
                # stored result instead of re-running the estimator.
                _attach_depth_from_result(detections, depth_map_results[idx], depth_scale)
            except Exception:
                logging.warning("Failed to attach depth for frame %d", idx)

        display_labels = [_build_display_label(d) for d in detections]
        processed = draw_boxes(frame, det_result.boxes, label_names=display_labels)
        outputs.append((idx, processed, detections))

    return outputs
432
+
433
def _build_display_label(det):
    """Compose the label text drawn next to a detection box.

    Appends an integer-metre depth suffix (e.g. ``"car 12m"``) when the
    detection record carries a validated depth estimate; otherwise the
    bare label is returned unchanged.
    """
    text = det["label"]
    depth = det.get("depth_est_m")
    if det.get("depth_valid") and depth is not None:
        text = f"{text} {int(depth)}m"
    return text
439
+
440
def _attach_depth_from_result(detections, depth_result, depth_scale):
    """Attach depth metrics to detection records from a precomputed depth result.

    For each detection, samples the median depth over the central half of its
    bounding box and converts it to metres as ``depth_scale / raw_depth``
    (inverse-depth convention).  Sets ``depth_est_m``, ``depth_valid`` and,
    normalised against the frame's other valid detections, ``depth_rel`` in
    [0, 1].  Mutates ``detections`` in place; returns ``None``.

    Args:
        detections: list of detection dicts, each with a ``bbox`` of at
            least ``[x1, y1, x2, y2]``.
        depth_result: object exposing a 2-D ``depth_map`` ndarray
            (``None``/empty maps are a no-op).
        depth_scale: scalar converting inverse raw depth to metres.
    """
    depth_map = depth_result.depth_map
    if depth_map is None or depth_map.size == 0:
        return

    height, width = depth_map.shape[:2]
    valid_depths = []

    for det in detections:
        # Reset first so stale values from earlier passes never leak through.
        det["depth_est_m"] = None
        det["depth_rel"] = None
        det["depth_valid"] = False

        bbox = det.get("bbox")
        if not bbox or len(bbox) < 4:
            continue

        # Clamp to image bounds and guarantee at least a 1x1 patch.
        x1, y1, x2, y2 = (int(coord) for coord in bbox[:4])
        x1 = max(0, min(width - 1, x1))
        y1 = max(0, min(height - 1, y1))
        x2 = max(x1 + 1, min(width, x2))
        y2 = max(y1 + 1, min(height, y2))

        patch = depth_map[y1:y2, x1:x2]
        if patch.size == 0:
            continue

        # Sample the central half of the box to avoid background pixels at
        # the edges; fall back to the full patch when the box is tiny.
        h_p, w_p = patch.shape
        cy, cx = h_p // 2, w_p // 2
        dy, dx = h_p // 4, w_p // 4
        center_patch = patch[cy - dy : cy + dy, cx - dx : cx + dx]
        if center_patch.size == 0:
            center_patch = patch

        finite = center_patch[np.isfinite(center_patch)]
        if finite.size == 0:
            continue

        depth_raw = float(np.median(finite))
        if depth_raw <= 1e-6:
            # Near-zero raw depth would explode the inverse; leave invalid
            # (fields were already reset above).
            continue

        # depth_raw > 1e-6 is guaranteed here, so the division cannot raise.
        depth_est = depth_scale / depth_raw
        det["depth_est_m"] = depth_est
        det["depth_valid"] = True
        valid_depths.append(depth_est)

    if not valid_depths:
        return

    # Normalise valid depths to [0, 1] within this frame.
    min_depth = float(min(valid_depths))
    max_depth = float(max(valid_depths))
    denom = max(max_depth - min_depth, 1e-6)

    for det in detections:
        if det.get("depth_valid"):
            det["depth_rel"] = (float(det["depth_est_m"]) - min_depth) / denom
497
+
498
+
499
  def infer_segmentation_frame(
500
  frame: np.ndarray,
501
  text_queries: Optional[List[str]] = None,
 
701
  # queue_in: (frame_idx, frame_data)
702
  # queue_out: (frame_idx, processed_frame, detections)
703
  queue_in = Queue(maxsize=16)
704
+ # Bound queue_out to prevent OOM
705
+ # Maxsize should be enough to keep writer busy but not explode memory
706
+ queue_out_max = max(32, (len(detectors) if detectors else 1) * 4)
707
+ queue_out = Queue(maxsize=queue_out_max)
708
 
709
  # 5. Worker Function
710
  def worker_task(gpu_idx: int):
711
  detector_instance = detectors[gpu_idx]
712
  depth_instance = depth_estimators[gpu_idx] if depth_estimators[gpu_idx] else None
713
 
714
+ batch_size = detector_instance.max_batch_size if detector_instance.supports_batch else 1
715
+ batch_accum = [] # List[Tuple[idx, frame]]
716
+
717
+ def flush_batch():
718
+ if not batch_accum: return
719
+
720
+ indices = [item[0] for item in batch_accum]
721
+ frames = [item[1] for item in batch_accum]
722
+
723
+ batch_outputs = infer_batch(
724
+ frames, indices, queries, detector_instance,
725
+ depth_estimator_instance=depth_instance,
726
+ depth_scale=depth_scale
727
+ )
728
+
729
+ for out_item in batch_outputs:
730
+ while True:
731
+ try:
732
+ queue_out.put(out_item, timeout=1.0)
733
+ break
734
+ except Full:
735
+ if job_id: _check_cancellation(job_id)
736
+
737
+ batch_accum.clear()
738
+
739
  while True:
740
  item = queue_in.get()
741
  if item is None:
742
+ flush_batch()
743
  queue_in.task_done()
744
  break
745
 
 
749
  logging.debug("Processing frame %d on device %s", frame_idx, "cpu" if num_gpus==0 else f"cuda:{gpu_idx}")
750
 
751
  try:
752
+ batch_accum.append((frame_idx, frame_data))
753
+ if len(batch_accum) >= batch_size:
754
+ flush_batch()
 
 
 
 
 
 
 
 
 
 
 
755
  except Exception as e:
756
+ logging.exception("Error processing batch around frame %d", frame_idx)
757
+ # Fail strictly or soft?
758
+ # If batch fails, we probably lost a chunk.
759
+ # Put placeholders for what we have in accum
760
+ for idx, frm in batch_accum:
761
+ while True:
762
+ try:
763
+ queue_out.put((idx, frm, []), timeout=1.0)
764
+ break
765
+ except Full:
766
+ if job_id: _check_cancellation(job_id)
767
+ batch_accum.clear()
768
 
769
  queue_in.task_done()
770
 
 
797
  # Fetch from queue
798
  try:
799
  while next_idx not in buffer:
800
+ # Backpressure: If buffer gets too big due to out-of-order frames,
801
+ # we might want to warn or just hope for the best.
802
+ # But here we are just consuming.
803
+
804
+ # However, if 'buffer' grows too large (because we are missing next_idx),
805
+ # we are effectively unbounded again if queue_out fills up with future frames.
806
+ # So we should monitor buffer size.
807
+ if len(buffer) > 64:
808
+ logging.warning("Writer buffer large (%d items), waiting for frame %d...", len(buffer), next_idx)
809
+
810
  item = queue_out.get(timeout=1.0) # wait
811
+
812
  idx, p_frame, dets = item
813
  buffer[idx] = (p_frame, dets)
814
 
 
945
 
946
  # 3. Processing
947
  queue_in = Queue(maxsize=16)
948
+ queue_out = Queue(maxsize=max(32, len(segmenters)*4))
949
 
950
  def worker_seg(gpu_idx: int):
951
  seg = segmenters[gpu_idx]
952
+ batch_size = seg.max_batch_size if seg.supports_batch else 1
953
+ batch_accum = []
954
+
955
+ def flush_batch():
956
+ if not batch_accum: return
957
+ indices = [i for i, _ in batch_accum]
958
+ frames = [f for _, f in batch_accum]
959
+
960
+ try:
961
+ # 1. Inference
962
+ if seg.supports_batch:
963
+ with seg.lock:
964
+ results = seg.predict_batch(frames, queries)
965
+ else:
966
+ with seg.lock:
967
+ results = [seg.predict(f, queries) for f in frames]
968
+
969
+ # 2. Post-process loop
970
+ for idx, frm, res in zip(indices, frames, results):
971
+ labels = queries or []
972
+ if len(labels) == 1:
973
+ masks = res.masks if res.masks is not None else []
974
+ labels = [labels[0] for _ in range(len(masks))]
975
+ processed = draw_masks(frm, res.masks, labels=labels)
976
+
977
+ while True:
978
+ try:
979
+ queue_out.put((idx, processed), timeout=1.0)
980
+ break
981
+ except Full:
982
+ if job_id: _check_cancellation(job_id)
983
+
984
+ except Exception as e:
985
+ logging.error("Batch seg failed: %s", e)
986
+ for idx, frm in batch_accum:
987
+ while True:
988
+ try:
989
+ queue_out.put((idx, frm), timeout=1.0) # Fallback
990
+ break
991
+ except Full:
992
+ if job_id: _check_cancellation(job_id)
993
+ batch_accum.clear()
994
+
995
  while True:
996
  item = queue_in.get()
997
  if item is None:
998
+ flush_batch()
999
  queue_in.task_done()
1000
  break
1001
 
1002
  idx, frame = item
1003
+ batch_accum.append(item)
1004
  if idx % 30 == 0:
1005
+ logging.debug("Seg frame %d (GPU %d)", idx, gpu_idx)
1006
+
1007
+ if len(batch_accum) >= batch_size:
1008
+ flush_batch()
1009
 
 
 
 
 
 
 
 
 
 
 
 
 
1010
  queue_in.task_done()
1011
 
1012
  workers = []
 
1028
  while next_idx < total_frames:
1029
  try:
1030
  while next_idx not in buffer:
1031
+ # Check buffer size
1032
+ if len(buffer) > 64:
1033
+ logging.warning("Writer buffer large (%d), waiting for %d", len(buffer), next_idx)
1034
+
1035
  idx, frm = queue_out.get(timeout=1.0)
1036
  buffer[idx] = frm
1037
 
 
1235
  logging.info("Starting Phase 2: Streaming...")
1236
 
1237
  queue_in = Queue(maxsize=16)
1238
+ queue_out_max = max(32, (len(estimators) if estimators else 1) * 4)
1239
+ queue_out = Queue(maxsize=queue_out_max)
1240
 
1241
  def worker_depth(gpu_idx: int):
1242
  est = estimators[gpu_idx]
1243
+ batch_size = est.max_batch_size if est.supports_batch else 1
1244
+ batch_accum = []
1245
+
1246
+ def flush_batch():
1247
+ if not batch_accum: return
1248
+ indices = [i for i, _ in batch_accum]
1249
+ frames = [f for _, f in batch_accum]
1250
+
1251
+ try:
1252
+ # 1. Inference
1253
+ if est.supports_batch:
1254
+ with est.lock:
1255
+ results = est.predict_batch(frames)
1256
+ else:
1257
+ with est.lock:
1258
+ results = [est.predict(f) for f in frames]
1259
+
1260
+ # 2. Post-process loop
1261
+ for idx, frm, res in zip(indices, frames, results):
1262
+ depth_map = res.depth_map
1263
+ colored = colorize_depth_map(depth_map, global_min, global_max)
1264
+
1265
+ # Overlay Detections
1266
+ if detections and idx < len(detections):
1267
+ frame_dets = detections[idx]
1268
+ if frame_dets:
1269
+ import cv2
1270
+ boxes = []
1271
+ labels = []
1272
+ for d in frame_dets:
1273
+ boxes.append(d.get("bbox"))
1274
+ lbl = d.get("label", "obj")
1275
+ if d.get("depth_est_m"):
1276
+ lbl = f"{lbl} {int(d['depth_est_m'])}m"
1277
+ labels.append(lbl)
1278
+ colored = draw_boxes(colored, boxes=boxes, label_names=labels)
1279
+
1280
+ while True:
1281
+ try:
1282
+ queue_out.put((idx, colored), timeout=1.0)
1283
+ break
1284
+ except Full:
1285
+ if job_id: _check_cancellation(job_id)
1286
+
1287
+ except Exception as e:
1288
+ logging.error("Batch depth failed: %s", e)
1289
+ for idx, frm in batch_accum:
1290
+ while True:
1291
+ try:
1292
+ queue_out.put((idx, frm), timeout=1.0)
1293
+ break
1294
+ except Full:
1295
+ if job_id: _check_cancellation(job_id)
1296
+ batch_accum.clear()
1297
+
1298
  while True:
1299
  item = queue_in.get()
1300
  if item is None:
1301
+ flush_batch()
1302
  queue_in.task_done()
1303
  break
1304
 
1305
  idx, frame = item
1306
+ batch_accum.append(item)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1307
 
1308
+ if idx % 30 == 0:
1309
+ logging.info("Depth frame %d (GPU %d)", idx, gpu_idx)
1310
+
1311
+ if len(batch_accum) >= batch_size:
1312
+ flush_batch()
1313
+
1314
  queue_in.task_done()
1315
 
1316
  # Workers
 
1335
  while next_idx < total_frames:
1336
  try:
1337
  while next_idx not in buffer:
1338
+ if len(buffer) > 64:
1339
+ logging.warning("Writer buffer large (%d), waiting for %d", len(buffer), next_idx)
1340
  idx, frm = queue_out.get(timeout=1.0)
1341
  buffer[idx] = frm
1342
 
models/depth_estimators/base.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import NamedTuple
2
 
3
  import numpy as np
4
 
@@ -13,6 +13,8 @@ class DepthEstimator:
13
  """Base interface for depth estimation models."""
14
 
15
  name: str
 
 
16
 
17
  def predict(self, frame: np.ndarray) -> DepthResult:
18
  """
@@ -25,3 +27,6 @@ class DepthEstimator:
25
  DepthResult with depth_map and focal_length
26
  """
27
  raise NotImplementedError
 
 
 
 
1
+ from typing import NamedTuple, Sequence, List
2
 
3
  import numpy as np
4
 
 
13
  """Base interface for depth estimation models."""
14
 
15
  name: str
16
+ supports_batch: bool = False
17
+ max_batch_size: int = 1
18
 
19
  def predict(self, frame: np.ndarray) -> DepthResult:
20
  """
 
27
  DepthResult with depth_map and focal_length
28
  """
29
  raise NotImplementedError
30
+
31
+ def predict_batch(self, frames: Sequence[np.ndarray]) -> Sequence[DepthResult]:
32
+ return [self.predict(f) for f in frames]
models/depth_estimators/depth_anything_v2.py CHANGED
@@ -1,4 +1,5 @@
1
  import logging
 
2
 
3
  import numpy as np
4
  import torch
@@ -12,6 +13,24 @@ class DepthAnythingV2Estimator(DepthEstimator):
12
  """Depth-Anything depth estimator (Transformers-compatible)."""
13
 
14
  name = "depth"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  def __init__(self, device: str = None) -> None:
17
  logging.info("Loading Depth-Anything model from Hugging Face (transformers)...")
@@ -50,25 +69,30 @@ class DepthAnythingV2Estimator(DepthEstimator):
50
  outputs = self.model(**inputs)
51
 
52
  raw_depth = outputs.predicted_depth
53
- if raw_depth.dim() == 2:
54
- raw_depth = raw_depth.unsqueeze(0).unsqueeze(0)
55
- elif raw_depth.dim() == 3:
56
- raw_depth = raw_depth.unsqueeze(1) if raw_depth.shape[0] == 1 else raw_depth.unsqueeze(0)
57
-
58
- if raw_depth.shape[-2:] != (height, width):
59
- import torch.nn.functional as F
60
-
61
- raw_depth = F.interpolate(
62
- raw_depth,
63
- size=(height, width),
64
- mode="bilinear",
65
- align_corners=False,
66
- )
67
-
68
- depth_map = raw_depth.squeeze().cpu().numpy().astype(np.float32, copy=False)
69
  except Exception as exc:
70
  logging.error("Depth-Anything inference failed: %s", exc)
71
  h, w = frame.shape[:2]
72
  return DepthResult(depth_map=np.zeros((h, w), dtype=np.float32), focal_length=1.0)
73
 
74
  return DepthResult(depth_map=depth_map, focal_length=1.0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import logging
2
+ from typing import Sequence
3
 
4
  import numpy as np
5
  import torch
 
13
  """Depth-Anything depth estimator (Transformers-compatible)."""
14
 
15
  name = "depth"
16
+ supports_batch = True
17
+ max_batch_size = 4
18
+
19
+ def _resize_depth(self, raw_depth, height, width):
20
+ if raw_depth.dim() == 2:
21
+ raw_depth = raw_depth.unsqueeze(0).unsqueeze(0)
22
+ elif raw_depth.dim() == 3:
23
+ raw_depth = raw_depth.unsqueeze(1) if raw_depth.shape[0] == 1 else raw_depth.unsqueeze(0)
24
+
25
+ if raw_depth.shape[-2:] != (height, width):
26
+ import torch.nn.functional as F
27
+ raw_depth = F.interpolate(
28
+ raw_depth,
29
+ size=(height, width),
30
+ mode="bilinear",
31
+ align_corners=False,
32
+ )
33
+ return raw_depth.squeeze().cpu().numpy().astype(np.float32, copy=False)
34
 
35
  def __init__(self, device: str = None) -> None:
36
  logging.info("Loading Depth-Anything model from Hugging Face (transformers)...")
 
69
  outputs = self.model(**inputs)
70
 
71
  raw_depth = outputs.predicted_depth
72
+ depth_map = self._resize_depth(raw_depth, height, width)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  except Exception as exc:
74
  logging.error("Depth-Anything inference failed: %s", exc)
75
  h, w = frame.shape[:2]
76
  return DepthResult(depth_map=np.zeros((h, w), dtype=np.float32), focal_length=1.0)
77
 
78
  return DepthResult(depth_map=depth_map, focal_length=1.0)
79
+
80
+ def predict_batch(self, frames: Sequence[np.ndarray]) -> Sequence[DepthResult]:
81
+ # Convert frames to PIL images
82
+ pil_images = [Image.fromarray(f[:, :, ::-1]) for f in frames] # BGR->RGB
83
+ sizes = [(img.height, img.width) for img in pil_images]
84
+
85
+ inputs = self.image_processor(images=pil_images, return_tensors="pt").to(self.device)
86
+
87
+ with torch.no_grad():
88
+ outputs = self.model(**inputs)
89
+
90
+ # outputs.predicted_depth is (B, H, W)
91
+ depths = outputs.predicted_depth
92
+
93
+ results = []
94
+ for i, (h, w) in enumerate(sizes):
95
+ depth_map = self._resize_depth(depths[i], h, w)
96
+ results.append(DepthResult(depth_map=depth_map, focal_length=1.0))
97
+
98
+ return results
models/detectors/base.py CHANGED
@@ -14,6 +14,12 @@ class ObjectDetector:
14
  """Detector interface to keep inference agnostic to model details."""
15
 
16
  name: str
 
 
17
 
18
  def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
19
  raise NotImplementedError
 
 
 
 
 
14
  """Detector interface to keep inference agnostic to model details."""
15
 
16
  name: str
17
+ supports_batch: bool = False
18
+ max_batch_size: int = 1
19
 
20
  def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
21
  raise NotImplementedError
22
+
23
+ def predict_batch(self, frames: Sequence[np.ndarray], queries: Sequence[str]) -> Sequence[DetectionResult]:
24
+ """Default: sequential fallback"""
25
+ return [self.predict(f, queries) for f in frames]
models/detectors/detr.py CHANGED
@@ -26,17 +26,10 @@ class DetrDetector(ObjectDetector):
26
  self.model.to(self.device)
27
  self.model.eval()
28
 
29
- def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
30
- inputs = self.processor(images=frame, return_tensors="pt")
31
- inputs = {key: value.to(self.device) for key, value in inputs.items()}
32
- with torch.no_grad():
33
- outputs = self.model(**inputs)
34
- target_sizes = torch.tensor([frame.shape[:2]], device=self.device)
35
- processed = self.processor.post_process_object_detection(
36
- outputs,
37
- threshold=self.score_threshold,
38
- target_sizes=target_sizes,
39
- )[0]
40
  boxes = processed["boxes"].cpu().numpy()
41
  scores = processed["scores"].cpu().tolist()
42
  labels = processed["labels"].cpu().tolist()
@@ -49,3 +42,31 @@ class DetrDetector(ObjectDetector):
49
  labels=labels,
50
  label_names=label_names,
51
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  self.model.to(self.device)
27
  self.model.eval()
28
 
29
+ supports_batch = True
30
+ max_batch_size = 4
31
+
32
+ def _parse_single_result(self, processed) -> DetectionResult:
 
 
 
 
 
 
 
33
  boxes = processed["boxes"].cpu().numpy()
34
  scores = processed["scores"].cpu().tolist()
35
  labels = processed["labels"].cpu().tolist()
 
42
  labels=labels,
43
  label_names=label_names,
44
  )
45
+
46
+ def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
47
+ inputs = self.processor(images=frame, return_tensors="pt")
48
+ inputs = {key: value.to(self.device) for key, value in inputs.items()}
49
+ with torch.no_grad():
50
+ outputs = self.model(**inputs)
51
+ target_sizes = torch.tensor([frame.shape[:2]], device=self.device)
52
+ processed = self.processor.post_process_object_detection(
53
+ outputs,
54
+ threshold=self.score_threshold,
55
+ target_sizes=target_sizes,
56
+ )[0]
57
+ return self._parse_single_result(processed)
58
+
59
+ def predict_batch(self, frames: Sequence[np.ndarray], queries: Sequence[str]) -> Sequence[DetectionResult]:
60
+ inputs = self.processor(images=frames, return_tensors="pt", padding=True)
61
+ inputs = {key: value.to(self.device) for key, value in inputs.items()}
62
+
63
+ with torch.no_grad():
64
+ outputs = self.model(**inputs)
65
+
66
+ target_sizes = torch.tensor([f.shape[:2] for f in frames], device=self.device)
67
+ processed_list = self.processor.post_process_object_detection(
68
+ outputs,
69
+ threshold=self.score_threshold,
70
+ target_sizes=target_sizes,
71
+ )
72
+ return [self._parse_single_result(p) for p in processed_list]
models/detectors/drone_yolo.py CHANGED
@@ -15,6 +15,8 @@ class DroneYoloDetector(ObjectDetector):
15
 
16
  REPO_ID = "rujutashashikanjoshi/yolo12-drone-detection-0205-100m"
17
  DEFAULT_WEIGHT = "best.pt"
 
 
18
 
19
  def __init__(self, score_threshold: float = 0.3, device: str = None) -> None:
20
  self.name = "drone_yolo"
@@ -42,15 +44,7 @@ class DroneYoloDetector(ObjectDetector):
42
  keep = [idx for idx, name in enumerate(label_names) if name.lower() in allowed]
43
  return keep or list(range(len(label_names)))
44
 
45
- def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
46
- device_arg = self.device
47
- results = self.model.predict(
48
- source=frame,
49
- device=device_arg,
50
- conf=self.score_threshold,
51
- verbose=False,
52
- )
53
- result = results[0]
54
  boxes = result.boxes
55
  if boxes is None or boxes.xyxy is None:
56
  empty = np.empty((0, 4), dtype=np.float32)
@@ -71,3 +65,22 @@ class DroneYoloDetector(ObjectDetector):
71
  labels=label_ids,
72
  label_names=label_names,
73
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  REPO_ID = "rujutashashikanjoshi/yolo12-drone-detection-0205-100m"
17
  DEFAULT_WEIGHT = "best.pt"
18
+ supports_batch = True
19
+ max_batch_size = 8
20
 
21
  def __init__(self, score_threshold: float = 0.3, device: str = None) -> None:
22
  self.name = "drone_yolo"
 
44
  keep = [idx for idx, name in enumerate(label_names) if name.lower() in allowed]
45
  return keep or list(range(len(label_names)))
46
 
47
+ def _parse_single_result(self, result, queries: Sequence[str]) -> DetectionResult:
 
 
 
 
 
 
 
 
48
  boxes = result.boxes
49
  if boxes is None or boxes.xyxy is None:
50
  empty = np.empty((0, 4), dtype=np.float32)
 
65
  labels=label_ids,
66
  label_names=label_names,
67
  )
68
+
69
+ def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
70
+ device_arg = self.device
71
+ results = self.model.predict(
72
+ source=frame,
73
+ device=device_arg,
74
+ conf=self.score_threshold,
75
+ verbose=False,
76
+ )
77
+ return self._parse_single_result(results[0], queries)
78
+
79
+ def predict_batch(self, frames: Sequence[np.ndarray], queries: Sequence[str]) -> Sequence[DetectionResult]:
80
+ results = self.model.predict(
81
+ source=frames,
82
+ device=self.device,
83
+ conf=self.score_threshold,
84
+ verbose=False,
85
+ )
86
+ return [self._parse_single_result(r, queries) for r in results]
models/detectors/grounding_dino.py CHANGED
@@ -33,36 +33,35 @@ class GroundingDinoDetector(ObjectDetector):
33
  return "object."
34
  return " ".join(f"{term}." for term in filtered)
35
 
36
- def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
37
- prompt = self._build_prompt(queries)
38
- inputs = self.processor(images=frame, text=prompt, return_tensors="pt")
39
- inputs = {key: value.to(self.device) for key, value in inputs.items()}
40
- with torch.no_grad():
41
- outputs = self.model(**inputs)
42
- target_sizes = torch.tensor([frame.shape[:2]], device=self.device)
43
  try:
44
- processed = self.processor.post_process_grounded_object_detection(
45
  outputs,
46
- inputs["input_ids"],
47
  box_threshold=self.box_threshold,
48
  text_threshold=self.text_threshold,
49
  target_sizes=target_sizes,
50
- )[0]
51
  except TypeError:
52
  try:
53
- processed = self.processor.post_process_grounded_object_detection(
54
  outputs,
55
- inputs["input_ids"],
56
  threshold=self.box_threshold,
57
  text_threshold=self.text_threshold,
58
  target_sizes=target_sizes,
59
- )[0]
60
  except TypeError:
61
- processed = self.processor.post_process_grounded_object_detection(
62
  outputs,
63
- inputs["input_ids"],
64
  target_sizes=target_sizes,
65
- )[0]
 
 
66
  boxes = processed["boxes"].cpu().numpy()
67
  scores = processed["scores"].cpu().tolist()
68
  label_names = list(processed.get("labels") or [])
@@ -73,3 +72,26 @@ class GroundingDinoDetector(ObjectDetector):
73
  labels=label_ids,
74
  label_names=label_names,
75
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  return "object."
34
  return " ".join(f"{term}." for term in filtered)
35
 
36
+ supports_batch = True
37
+ max_batch_size = 4
38
+
39
+ def _post_process(self, outputs, input_ids, target_sizes):
 
 
 
40
  try:
41
+ return self.processor.post_process_grounded_object_detection(
42
  outputs,
43
+ input_ids,
44
  box_threshold=self.box_threshold,
45
  text_threshold=self.text_threshold,
46
  target_sizes=target_sizes,
47
+ )
48
  except TypeError:
49
  try:
50
+ return self.processor.post_process_grounded_object_detection(
51
  outputs,
52
+ input_ids,
53
  threshold=self.box_threshold,
54
  text_threshold=self.text_threshold,
55
  target_sizes=target_sizes,
56
+ )
57
  except TypeError:
58
+ return self.processor.post_process_grounded_object_detection(
59
  outputs,
60
+ input_ids,
61
  target_sizes=target_sizes,
62
+ )
63
+
64
+ def _parse_single_result(self, processed) -> DetectionResult:
65
  boxes = processed["boxes"].cpu().numpy()
66
  scores = processed["scores"].cpu().tolist()
67
  label_names = list(processed.get("labels") or [])
 
72
  labels=label_ids,
73
  label_names=label_names,
74
  )
75
+
76
def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
    """Detect objects matching `queries` in a single frame.

    Builds one text prompt from the queries, runs the grounded detector,
    and parses the first post-processed result.
    """
    text = self._build_prompt(queries)
    encoded = self.processor(images=frame, text=text, return_tensors="pt")
    # Move every tensor in the encoding onto the model's device.
    encoded = {name: tensor.to(self.device) for name, tensor in encoded.items()}
    with torch.no_grad():
        model_out = self.model(**encoded)
    # Single-image batch: one (H, W) target size.
    sizes = torch.tensor([frame.shape[:2]], device=self.device)
    processed = self._post_process(model_out, encoded["input_ids"], sizes)
    return self._parse_single_result(processed[0])
85
+
86
def predict_batch(self, frames: Sequence[np.ndarray], queries: Sequence[str]) -> Sequence[DetectionResult]:
    """Detect objects matching `queries` across a batch of frames.

    Every frame shares the same text prompt; results come back in input
    order, one DetectionResult per frame.

    Returns:
        [] for an empty batch (the processor cannot build a zero-image
        tensor batch), otherwise one parsed result per frame.
    """
    # Guard against an empty batch before touching the processor/model.
    if not frames:
        return []
    prompt = self._build_prompt(queries)
    # Replicate the prompt per image; padding aligns token lengths.
    inputs = self.processor(
        images=list(frames),
        text=[prompt] * len(frames),
        return_tensors="pt",
        padding=True,
    )
    inputs = {key: value.to(self.device) for key, value in inputs.items()}

    with torch.no_grad():
        outputs = self.model(**inputs)

    target_sizes = torch.tensor([f.shape[:2] for f in frames], device=self.device)
    processed_list = self._post_process(outputs, inputs["input_ids"], target_sizes)
    return [self._parse_single_result(p) for p in processed_list]
models/detectors/yolov8.py CHANGED
@@ -14,6 +14,8 @@ class HuggingFaceYoloV8Detector(ObjectDetector):
14
 
15
  REPO_ID = "Ultralytics/YOLOv8"
16
  WEIGHT_FILE = "yolov8s.pt"
 
 
17
 
18
  def __init__(self, score_threshold: float = 0.3, device: str = None) -> None:
19
  self.name = "hf_yolov8"
@@ -40,14 +42,7 @@ class HuggingFaceYoloV8Detector(ObjectDetector):
40
  keep = [idx for idx, name in enumerate(label_names) if name.lower() in allowed]
41
  return keep or list(range(len(label_names)))
42
 
43
- def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
44
- results = self.model.predict(
45
- source=frame,
46
- device=self.device,
47
- conf=self.score_threshold,
48
- verbose=False,
49
- )
50
- result = results[0]
51
  boxes = result.boxes
52
  if boxes is None or boxes.xyxy is None:
53
  empty = np.empty((0, 4), dtype=np.float32)
@@ -69,3 +64,21 @@ class HuggingFaceYoloV8Detector(ObjectDetector):
69
  label_names=label_names,
70
  )
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  REPO_ID = "Ultralytics/YOLOv8"
16
  WEIGHT_FILE = "yolov8s.pt"
17
+ supports_batch = True
18
+ max_batch_size = 8
19
 
20
  def __init__(self, score_threshold: float = 0.3, device: str = None) -> None:
21
  self.name = "hf_yolov8"
 
42
  keep = [idx for idx, name in enumerate(label_names) if name.lower() in allowed]
43
  return keep or list(range(len(label_names)))
44
 
45
+ def _parse_single_result(self, result, queries: Sequence[str]) -> DetectionResult:
 
 
 
 
 
 
 
46
  boxes = result.boxes
47
  if boxes is None or boxes.xyxy is None:
48
  empty = np.empty((0, 4), dtype=np.float32)
 
64
  label_names=label_names,
65
  )
66
 
67
def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
    """Run YOLOv8 detection on one frame and parse the single result."""
    # One model call; `conf` applies the score threshold inside Ultralytics.
    raw_results = self.model.predict(
        source=frame,
        device=self.device,
        conf=self.score_threshold,
        verbose=False,
    )
    first = raw_results[0]  # single source -> exactly one result
    return self._parse_single_result(first, queries)
75
+
76
def predict_batch(self, frames: Sequence[np.ndarray], queries: Sequence[str]) -> Sequence[DetectionResult]:
    """Run YOLOv8 on a batch of frames with a single model call.

    Args:
        frames: sequence of image arrays forwarded together as the source.
        queries: labels forwarded to the per-result parser.

    Returns:
        One DetectionResult per frame in input order; [] for an empty batch.
    """
    # Ultralytics raises on an empty source list — return early instead.
    if not frames:
        return []
    results = self.model.predict(
        source=list(frames),
        device=self.device,
        conf=self.score_threshold,
        verbose=False,
    )
    return [self._parse_single_result(r, queries) for r in results]
84
+
models/segmenters/base.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import NamedTuple, Optional
2
 
3
  import numpy as np
4
 
@@ -14,6 +14,8 @@ class Segmenter:
14
  """Base interface for segmentation models."""
15
 
16
  name: str
 
 
17
 
18
  def predict(self, frame: np.ndarray, text_prompts: Optional[list] = None) -> SegmentationResult:
19
  """
@@ -27,3 +29,6 @@ class Segmenter:
27
  SegmentationResult with masks and optional metadata
28
  """
29
  raise NotImplementedError
 
 
 
 
1
+ from typing import NamedTuple, Optional, Sequence, List
2
 
3
  import numpy as np
4
 
 
14
  """Base interface for segmentation models."""
15
 
16
  name: str
17
+ supports_batch: bool = False
18
+ max_batch_size: int = 1
19
 
20
  def predict(self, frame: np.ndarray, text_prompts: Optional[list] = None) -> SegmentationResult:
21
  """
 
29
  SegmentationResult with masks and optional metadata
30
  """
31
  raise NotImplementedError
32
+
33
+ def predict_batch(self, frames: Sequence[np.ndarray], text_prompts: Optional[list] = None) -> Sequence[SegmentationResult]:
34
+ return [self.predict(f, text_prompts) for f in frames]
models/segmenters/sam3.py CHANGED
@@ -1,5 +1,5 @@
1
  import logging
2
- from typing import Optional
3
 
4
  import numpy as np
5
  import torch
@@ -55,6 +55,38 @@ class SAM3Segmenter(Segmenter):
55
 
56
  logging.info("SAM3 model loaded successfully")
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  def predict(self, frame: np.ndarray, text_prompts: Optional[list] = None) -> SegmentationResult:
59
  """
60
  Run SAM3 segmentation on a frame.
@@ -95,34 +127,7 @@ class SAM3Segmenter(Segmenter):
95
  mask_threshold=self.mask_threshold,
96
  target_sizes=inputs.get("original_sizes").tolist(),
97
  )[0]
98
-
99
- # Extract results
100
- masks = results.get("masks", [])
101
- scores = results.get("scores", None)
102
- boxes = results.get("boxes", None)
103
-
104
- # Convert to numpy arrays
105
- if len(masks) > 0:
106
- # Stack masks: list of (H, W) -> (N, H, W)
107
- masks_array = np.stack([m.cpu().numpy() for m in masks])
108
- else:
109
- # No objects detected
110
- masks_array = np.zeros(
111
- (0, frame.shape[0], frame.shape[1]), dtype=bool
112
- )
113
-
114
- scores_array = (
115
- scores.cpu().numpy() if scores is not None else None
116
- )
117
- boxes_array = (
118
- boxes.cpu().numpy() if boxes is not None else None
119
- )
120
-
121
- return SegmentationResult(
122
- masks=masks_array,
123
- scores=scores_array,
124
- boxes=boxes_array,
125
- )
126
 
127
  except Exception:
128
  logging.exception("SAM3 post-processing failed")
@@ -132,3 +137,38 @@ class SAM3Segmenter(Segmenter):
132
  scores=None,
133
  boxes=None,
134
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import logging
2
+ from typing import Optional, Sequence
3
 
4
  import numpy as np
5
  import torch
 
55
 
56
  logging.info("SAM3 model loaded successfully")
57
 
58
+ supports_batch = True
59
+ max_batch_size = 4
60
+
61
def _parse_single_result(self, results, frame_shape) -> SegmentationResult:
    """Convert one post-processed SAM3 result dict into a SegmentationResult.

    Args:
        results: mapping with optional "masks"/"scores"/"boxes" tensor
            entries produced by the processor's post-processing step.
        frame_shape: source frame shape, used for the (0, H, W) placeholder
            when nothing was detected.
    """
    raw_masks = results.get("masks", [])
    raw_scores = results.get("scores", None)
    raw_boxes = results.get("boxes", None)

    if len(raw_masks) > 0:
        # Stack per-instance (H, W) masks into one (N, H, W) array.
        mask_stack = np.stack([m.cpu().numpy() for m in raw_masks])
    else:
        # No detections: keep a well-typed empty boolean array.
        mask_stack = np.zeros((0, frame_shape[0], frame_shape[1]), dtype=bool)

    score_array = None if raw_scores is None else raw_scores.cpu().numpy()
    box_array = None if raw_boxes is None else raw_boxes.cpu().numpy()

    return SegmentationResult(
        masks=mask_stack,
        scores=score_array,
        boxes=box_array,
    )
89
+
90
  def predict(self, frame: np.ndarray, text_prompts: Optional[list] = None) -> SegmentationResult:
91
  """
92
  Run SAM3 segmentation on a frame.
 
127
  mask_threshold=self.mask_threshold,
128
  target_sizes=inputs.get("original_sizes").tolist(),
129
  )[0]
130
+ return self._parse_single_result(results, frame.shape)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
  except Exception:
133
  logging.exception("SAM3 post-processing failed")
 
137
  scores=None,
138
  boxes=None,
139
  )
140
+
141
def predict_batch(self, frames: Sequence[np.ndarray], text_prompts: Optional[list] = None) -> Sequence[SegmentationResult]:
    """Run SAM3 text-prompted segmentation on a batch of frames.

    All frames share the same prompt list; returns one SegmentationResult
    per frame in input order. Mirrors the single-frame path's behavior of
    degrading to empty masks (never raising) on post-processing failure.

    Returns:
        [] for an empty batch; otherwise one result per frame.
    """
    # Guard: nothing to encode for an empty batch; avoid a zero-sized
    # processor/model call.
    if not frames:
        return []

    pil_images = []
    for f in frames:
        if f.dtype == np.uint8:
            pil_images.append(Image.fromarray(f))
        else:
            # NOTE(review): assumes non-uint8 frames are floats in [0, 1]
            # — confirm against the frame producers before relying on it.
            pil_images.append(Image.fromarray((f * 255).astype(np.uint8)))

    prompts = text_prompts or ["object"]

    # The same prompt list is replicated for every image in the batch.
    inputs = self.processor(
        images=pil_images,
        text=[prompts] * len(frames),
        return_tensors="pt",
    ).to(self.device)

    with torch.no_grad():
        outputs = self.model(**inputs)

    try:
        results_list = self.processor.post_process_instance_segmentation(
            outputs,
            threshold=self.threshold,
            mask_threshold=self.mask_threshold,
            target_sizes=inputs.get("original_sizes").tolist(),
        )
        return [
            self._parse_single_result(r, f.shape)
            for r, f in zip(results_list, frames)
        ]
    except Exception:
        # Best-effort: log and return empty masks per frame, matching the
        # single-frame predict() failure behavior.
        logging.exception("SAM3 batch post-processing failed")
        return [
            SegmentationResult(
                masks=np.zeros((0, f.shape[0], f.shape[1]), dtype=bool),
                scores=None,
                boxes=None,
            )
            for f in frames
        ]