Zhen Ye commited on
Commit
7f8fcb7
·
1 Parent(s): 43ec7b4

update:inference pipeline optimization

Browse files
Files changed (2) hide show
  1. inference.py +542 -404
  2. utils/video.py +77 -0
inference.py CHANGED
@@ -1,6 +1,8 @@
1
  import logging
2
  import os
3
- from threading import RLock
 
 
4
  from typing import Any, Dict, List, Optional, Sequence, Tuple
5
 
6
  import cv2
@@ -12,7 +14,7 @@ from models.detectors.base import ObjectDetector
12
  from models.model_loader import load_detector, load_detector_on_device
13
  from models.segmenters.model_loader import load_segmenter, load_segmenter_on_device
14
  from models.depth_estimators.model_loader import load_depth_estimator, load_depth_estimator_on_device
15
- from utils.video import extract_frames, write_video
16
 
17
 
18
  def _check_cancellation(job_id: Optional[str]) -> None:
@@ -27,7 +29,7 @@ def _check_cancellation(job_id: Optional[str]) -> None:
27
  raise RuntimeError("Job cancelled by user")
28
 
29
 
30
- def _color_for_label(label: str) -> tuple[int, int, int]:
31
  # Deterministic BGR color from label text.
32
  value = abs(hash(label)) % 0xFFFFFF
33
  blue = value & 0xFF
@@ -277,7 +279,7 @@ def infer_frame(
277
  depth_scale: float = 1.0,
278
  detector_instance: Optional[ObjectDetector] = None,
279
  depth_estimator_instance: Optional[Any] = None,
280
- ) -> tuple[np.ndarray, List[Dict[str, Any]]]:
281
  if detector_instance:
282
  detector = detector_instance
283
  else:
@@ -332,7 +334,7 @@ def infer_segmentation_frame(
332
  text_queries: Optional[List[str]] = None,
333
  segmenter_name: Optional[str] = None,
334
  segmenter_instance: Optional[Any] = None,
335
- ) -> tuple[np.ndarray, Any]:
336
  if segmenter_instance:
337
  segmenter = segmenter_instance
338
  # Use instance lock if available
@@ -406,175 +408,219 @@ def run_inference(
406
  job_id: Optional[str] = None,
407
  depth_estimator_name: Optional[str] = None,
408
  depth_scale: float = 1.0,
409
- ) -> tuple[str, List[List[Dict[str, Any]]]]:
410
- """
411
- Run object detection inference on a video.
412
-
413
- Args:
414
- input_video_path: Path to input video
415
- output_video_path: Path to write processed video
416
- queries: List of object classes to detect (e.g., ["person", "car"])
417
- max_frames: Optional frame limit for testing
418
- detector_name: Detector to use (default: hf_yolov8)
419
- job_id: Optional job ID for cancellation support
420
- depth_estimator_name: Optional depth estimator name
421
- depth_scale: Scale factor for depth estimation
422
- """
423
  try:
424
- frames, fps, width, height = extract_frames(input_video_path)
425
- except ValueError as exc:
426
- logging.exception("Failed to decode video at %s", input_video_path)
427
  raise
428
 
429
- # Use provided queries or default to common objects
 
 
 
 
 
 
 
 
430
  if not queries:
431
  queries = ["person", "car", "truck", "motorcycle", "bicycle", "bus", "train", "airplane"]
432
  logging.info("No queries provided, using defaults: %s", queries)
433
-
434
  logging.info("Detection queries: %s", queries)
435
-
436
- # Select detector
437
  active_detector = detector_name or "hf_yolov8"
438
- logging.info("Using detector: %s", active_detector)
439
-
440
- # Detect GPUs
441
- # Debug/Fix: Ensure internal restrictions don't hide GPUs
 
 
 
442
  if "CUDA_VISIBLE_DEVICES" in os.environ:
443
- logging.warning("Found CUDA_VISIBLE_DEVICES=%s in run_inference! Unsetting it.", os.environ["CUDA_VISIBLE_DEVICES"])
444
- del os.environ["CUDA_VISIBLE_DEVICES"]
445
 
446
- num_gpus = torch.cuda.device_count()
447
- detectors = None
448
- depth_estimators = None
449
-
450
- # DIAGNOSTICS
451
- logging.info("--- GPU DIAGNOSTICS ---")
452
- logging.info("Torch version: %s", torch.__version__)
453
- logging.info("CUDA available: %s", torch.cuda.is_available())
454
- logging.info("Device count: %d", torch.cuda.device_count())
455
- logging.info("Current device: %s", torch.cuda.current_device() if torch.cuda.is_available() else "N/A")
456
- for k, v in os.environ.items():
457
- if "CUDA" in k or "NVIDIA" in k:
458
- logging.info("Env %s=%s", k, v)
459
- logging.info("-----------------------")
460
-
461
- if num_gpus > 1:
462
- logging.info("Detected %d GPUs. Enabling Multi-GPU inference.", num_gpus)
463
- # Initialize one detector per GPU
464
- detectors = []
465
- depth_estimators = []
466
- for i in range(num_gpus):
467
- device_str = f"cuda:{i}"
468
- logging.info("Loading detector/depth on %s", device_str)
469
 
470
- # Detector
471
- det = load_detector_on_device(active_detector, device_str)
472
- det.lock = RLock()
473
- detectors.append(det)
474
-
475
- # Depth (if requested)
476
- if depth_estimator_name:
477
- depth = load_depth_estimator_on_device(depth_estimator_name, device_str)
478
- depth.lock = RLock()
479
  depth_estimators.append(depth)
480
- else:
481
- depth_estimators.append(None)
482
-
483
  else:
484
- logging.info("Single device detected. Using standard inference.")
485
- detectors = None
 
 
 
 
 
 
 
 
486
 
487
- processed_frames_map = {}
488
- all_detections_map = {}
 
 
 
489
 
490
- # Process frames
491
- if detectors:
492
- # Multi-GPU Parallel Processing
493
- def process_frame_task(frame_idx: int, frame_data: np.ndarray) -> tuple[int, np.ndarray, List[Dict[str, Any]]]:
494
- # Determine which GPU to use based on frame index (round-robin)
495
- gpu_idx = frame_idx % len(detectors)
496
- detector_instance = detectors[gpu_idx]
497
- depth_instance = depth_estimators[gpu_idx] if depth_estimators else None
 
 
 
 
498
 
499
  if frame_idx % 30 == 0:
500
- logging.info("Processing frame %d on GPU %d (cuda:%d)", frame_idx, gpu_idx, gpu_idx)
501
-
502
- # Run depth estimation every 3 frames if configured
503
- active_depth_name = depth_estimator_name if (frame_idx % 3 == 0) else None
504
- active_depth_instance = depth_instance if (frame_idx % 3 == 0) else None
505
-
506
- processed, frame_dets = infer_frame(
507
- frame_data,
508
- queries,
509
- detector_name=None, # Use instance
510
- depth_estimator_name=active_depth_name,
511
- depth_scale=depth_scale,
512
- detector_instance=detector_instance,
513
- depth_estimator_instance=active_depth_instance
514
- )
515
- return frame_idx, processed, frame_dets
516
 
517
- # Thread pool with more workers than GPUs to keep them fed
518
- max_workers = min(len(detectors) * 2, 8)
519
- with ThreadPoolExecutor(max_workers=max_workers) as executor:
520
- futures = []
521
- for idx, frame in enumerate(frames):
522
- _check_cancellation(job_id)
523
- if max_frames is not None and idx >= max_frames:
524
- break
525
- futures.append(executor.submit(process_frame_task, idx, frame))
 
 
 
 
 
 
 
 
 
 
526
 
527
- for future in futures:
528
- idx, result_frame, result_dets = future.result() # Wait for completion (in order or not, but we verify order)
529
- processed_frames_map[idx] = result_frame
530
- all_detections_map[idx] = result_dets
531
-
532
- # Reasemble in order
533
- processed_frames = [processed_frames_map[i] for i in range(len(processed_frames_map))]
534
- all_detections = [all_detections_map[i] for i in range(len(all_detections_map))]
535
-
536
- else:
537
- # Standard Single-Threaded Loop
538
- # Pre-load models to ensure they are loaded once
539
- detector_instance = load_detector(active_detector)
540
- detector_instance.lock = _get_model_lock("detector", detector_instance.name)
 
 
 
 
 
 
 
 
 
 
541
 
542
- depth_estimator_instance = None
543
- if depth_estimator_name:
544
- depth_estimator_instance = load_depth_estimator(depth_estimator_name)
545
- depth_estimator_instance.lock = _get_model_lock("depth", depth_estimator_instance.name)
546
-
547
- processed_frames = []
548
- all_detections = []
549
- for idx, frame in enumerate(frames):
550
- # Check for cancellation every frame
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
551
  _check_cancellation(job_id)
552
-
553
- if max_frames is not None and idx >= max_frames:
554
  break
555
- logging.debug("Processing frame %d", idx)
556
 
557
- # Run depth estimation every 3 frames if configured
558
- active_depth_name = depth_estimator_name if (idx % 3 == 0) else None
559
- active_depth_instance = depth_estimator_instance if (idx % 3 == 0) else None
560
 
561
- processed_frame, frame_dets = infer_frame(
562
- frame,
563
- queries,
564
- detector_name=None,
565
- depth_estimator_name=active_depth_name,
566
- depth_scale=depth_scale,
567
- detector_instance=detector_instance,
568
- depth_estimator_instance=active_depth_instance
569
- )
570
- processed_frames.append(processed_frame)
571
- all_detections.append(frame_dets)
 
 
 
 
572
 
573
- # Write output video
574
- write_video(processed_frames, output_video_path, fps=fps, width=width, height=height)
575
- logging.info("Processed video written to: %s", output_video_path)
 
 
 
 
 
 
576
 
577
- return output_video_path, all_detections
578
 
579
 
580
  def run_segmentation(
@@ -585,83 +631,139 @@ def run_segmentation(
585
  segmenter_name: Optional[str] = None,
586
  job_id: Optional[str] = None,
587
  ) -> str:
 
588
  try:
589
- frames, fps, width, height = extract_frames(input_video_path)
590
- except ValueError as exc:
591
- logging.exception("Failed to decode video at %s", input_video_path)
592
  raise
593
 
 
 
 
 
 
 
 
 
594
  active_segmenter = segmenter_name or "sam3"
595
  logging.info("Using segmenter: %s with queries: %s", active_segmenter, queries)
596
 
597
- # Detect GPUs
598
  num_gpus = torch.cuda.device_count()
599
- segmenters = None
600
- if num_gpus > 1:
601
- logging.info("Detected %d GPUs. Enabling Multi-GPU segmentation.", num_gpus)
602
- segmenters = []
603
- for i in range(num_gpus):
604
- device_str = f"cuda:{i}"
605
- logging.info("Loading segmenter on %s", device_str)
 
 
606
  seg = load_segmenter_on_device(active_segmenter, device_str)
607
  seg.lock = RLock()
608
- segmenters.append(seg)
 
 
 
 
 
 
609
  else:
610
- logging.info("Single device detected. Using standard segmentation.")
611
- segmenters = None
 
612
 
613
- processed_frames_map = {}
614
-
615
- if segmenters:
616
- # Multi-GPU Parallel Processing
617
- def process_segmentation_task(frame_idx: int, frame_data: np.ndarray) -> tuple[int, np.ndarray]:
618
- gpu_idx = frame_idx % len(segmenters)
619
- segmenter_instance = segmenters[gpu_idx]
 
 
 
 
620
 
621
- if frame_idx % 30 == 0:
622
- logging.info("Segmenting frame %d on GPU %d (cuda:%d)", frame_idx, gpu_idx, gpu_idx)
623
 
624
- processed, _ = infer_segmentation_frame(
625
- frame_data,
626
- text_queries=queries,
627
- segmenter_name=None,
628
- segmenter_instance=segmenter_instance
629
- )
630
- return frame_idx, processed
631
 
632
- max_workers = min(len(segmenters) * 2, 8)
633
- with ThreadPoolExecutor(max_workers=max_workers) as executor:
634
- futures = []
635
- for idx, frame in enumerate(frames):
636
- _check_cancellation(job_id)
637
- if max_frames is not None and idx >= max_frames:
638
- break
639
- futures.append(executor.submit(process_segmentation_task, idx, frame))
 
 
 
640
 
641
- for future in futures:
642
- idx, result_frame = future.result()
643
- processed_frames_map[idx] = result_frame
644
-
645
- processed_frames = [processed_frames_map[i] for i in range(len(processed_frames_map))]
646
 
647
- else:
648
- processed_frames: List[np.ndarray] = []
649
- for idx, frame in enumerate(frames):
650
- # Check for cancellation every frame
651
- _check_cancellation(job_id)
652
 
653
- if max_frames is not None and idx >= max_frames:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
654
  break
655
- logging.debug("Processing frame %d", idx)
656
- processed_frame, _ = infer_segmentation_frame(frame, text_queries=queries, segmenter_name=active_segmenter)
657
- processed_frames.append(processed_frame)
658
-
659
- write_video(processed_frames, output_video_path, fps=fps, width=width, height=height)
 
 
 
 
 
 
660
  logging.info("Segmented video written to: %s", output_video_path)
661
-
662
  return output_video_path
663
 
664
 
 
665
  def run_depth_inference(
666
  input_video_path: str,
667
  output_video_path: str,
@@ -671,227 +773,263 @@ def run_depth_inference(
671
  first_frame_depth_path: Optional[str] = None,
672
  job_id: Optional[str] = None,
673
  ) -> str:
674
- """
675
- Run depth estimation on a video.
676
-
677
- Args:
678
- input_video_path: Path to input video
679
- output_video_path: Path to write depth visualization video
680
- max_frames: Optional frame limit for testing
681
- depth_estimator_name: Depth estimator to use (default: depth)
682
- first_frame_depth_path: Optional path to save the first depth visualization frame
683
- job_id: Optional job ID for cancellation support
684
-
685
- Returns:
686
- Path to depth visualization video
687
- """
688
  try:
689
- frames, fps, width, height = extract_frames(input_video_path)
690
- except ValueError as exc:
691
- logging.exception("Failed to decode video at %s", input_video_path)
692
  raise
693
 
694
- logging.info("Using depth estimator: %s", depth_estimator_name)
695
-
696
- # Limit frames if requested
 
 
697
  if max_frames is not None:
698
- frames = frames[:max_frames]
699
-
700
- # Process depth with stable normalization and overlay
701
- processed_frames = process_frames_depth(frames, depth_estimator_name, detections=detections, job_id=job_id)
702
-
703
- # Write output video
704
- write_video(processed_frames, output_video_path, fps=fps, width=width, height=height)
705
- logging.info("Depth video written to: %s", output_video_path)
706
-
707
- if first_frame_depth_path and processed_frames:
708
- import cv2
709
-
710
- if not cv2.imwrite(first_frame_depth_path, processed_frames[0]):
711
- logging.warning("Failed to write first frame depth image to: %s", first_frame_depth_path)
712
-
713
- return output_video_path
714
-
715
-
716
- def process_frames_depth(
717
- frames: List[np.ndarray],
718
- depth_estimator_name: str,
719
- detections: Optional[List[List[Dict[str, Any]]]] = None,
720
- job_id: Optional[str] = None,
721
- ) -> List[np.ndarray]:
722
- """
723
- Process all frames through depth estimator with stable normalization.
724
-
725
- Two-pass approach:
726
- 1. Compute depth for all frames and find global min/max
727
- 2. Colorize using global range to avoid flicker
728
-
729
- Args:
730
- frames: List of frames (HxWx3 BGR uint8)
731
- depth_estimator_name: Name of depth estimator to use
732
- job_id: Optional job ID for cancellation
733
-
734
- Returns:
735
- List of depth visualization frames (HxWx3 RGB uint8)
736
- """
737
- from models.depth_estimators.model_loader import load_depth_estimator, load_depth_estimator_on_device
738
 
739
- # Detect GPUs
740
  num_gpus = torch.cuda.device_count()
741
- estimators = None
742
- if num_gpus > 1:
743
- logging.info("Detected %d GPUs. Enabling Multi-GPU depth estimation.", num_gpus)
744
- estimators = []
745
- for i in range(num_gpus):
746
- device_str = f"cuda:{i}"
747
- logging.info("Loading depth estimator on %s", device_str)
 
 
748
  est = load_depth_estimator_on_device(depth_estimator_name, device_str)
749
  est.lock = RLock()
750
- estimators.append(est)
 
 
 
 
 
 
751
  else:
752
- logging.info("Single device detected. Using standard depth estimation.")
753
- estimators = None
754
- # Fallback to single estimator
755
- single_estimator = load_depth_estimator(depth_estimator_name)
756
-
757
- # First pass: Compute all depth maps and find global range
758
- depth_maps_map = {}
759
- all_values = []
760
-
761
- if estimators:
762
- # Multi-GPU Parallel Processing
763
- def compute_depth_task(frame_idx: int, frame_data: np.ndarray) -> tuple[int, Any]:
764
- gpu_idx = frame_idx % len(estimators)
765
- estimator_instance = estimators[gpu_idx]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
766
 
767
- if frame_idx % 30 == 0:
768
- logging.info("Estimating depth for frame %d on GPU %d (cuda:%d)", frame_idx, gpu_idx, gpu_idx)
 
 
 
 
 
 
 
 
 
 
 
 
 
769
 
770
- # Use instance lock
771
- if hasattr(estimator_instance, "lock"):
772
- lock = estimator_instance.lock
773
- else:
774
- # Should have been assigned above
775
- lock = RLock()
 
 
 
 
 
 
 
 
 
 
 
 
 
776
 
777
- with lock:
778
- result = estimator_instance.predict(frame_data)
779
- return frame_idx, result
780
-
781
- max_workers = min(len(estimators) * 2, 8)
782
- with ThreadPoolExecutor(max_workers=max_workers) as executor:
783
- futures = []
784
- for idx, frame in enumerate(frames):
785
- _check_cancellation(job_id)
786
- futures.append(executor.submit(compute_depth_task, idx, frame))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
787
 
788
- for future in futures:
789
- idx, res = future.result()
790
- depth_maps_map[idx] = res.depth_map
791
- # We need to collect values for global min/max.
792
- # Doing this here or later? doing it later to keep thread clean
 
793
 
794
- # Reassemble
795
- depth_maps = [depth_maps_map[i] for i in range(len(depth_maps_map))]
796
- all_values = [dm.ravel() for dm in depth_maps]
797
-
798
- else:
799
- # Single threaded
800
- estimator = single_estimator
801
- depth_maps = []
802
- for idx, frame in enumerate(frames):
803
- _check_cancellation(job_id)
804
-
805
- lock = _get_model_lock("depth", estimator.name)
806
- with lock:
807
- depth_result = estimator.predict(frame)
808
-
809
- depth_maps.append(depth_result.depth_map)
810
- all_values.append(depth_result.depth_map.ravel())
811
-
812
- if idx % 10 == 0:
813
- logging.debug("Computed depth for frame %d/%d", idx + 1, len(frames))
814
-
815
- # Compute global min/max (using percentiles to handle outliers)
816
- all_depths = np.concatenate(all_values).astype(np.float32, copy=False)
817
-
818
- # Filter out NaN and inf values
819
- valid_depths = all_depths[np.isfinite(all_depths)]
820
-
821
- if len(valid_depths) == 0:
822
- logging.warning("All depth values are NaN/inf - using fallback range")
823
- global_min = 0.0
824
- global_max = 1.0
825
- else:
826
- valid_depths = valid_depths.astype(np.float64, copy=False)
827
- global_min = float(np.percentile(valid_depths, 1)) # 1st percentile to clip outliers
828
- global_max = float(np.percentile(valid_depths, 99)) # 99th percentile
829
-
830
- if not np.isfinite(global_min) or not np.isfinite(global_max):
831
- logging.warning("Depth percentiles are non-finite - using min/max fallback")
832
- global_min = float(valid_depths.min())
833
- global_max = float(valid_depths.max())
834
-
835
- # Handle edge case where min == max
836
- if abs(global_max - global_min) < 1e-6:
837
- global_min = float(valid_depths.min())
838
- global_max = float(valid_depths.max())
839
- if abs(global_max - global_min) < 1e-6:
840
- global_max = global_min + 1.0
841
-
842
- logging.info(
843
- "Depth range: %.2f - %.2f meters (1st-99th percentile)",
844
- global_min,
845
- global_max,
846
- )
847
-
848
- # Second pass: Apply colormap and overlay detections
849
- visualization_frames = []
850
 
851
- # draw_boxes is defined in this module, so we can use it directly.
852
- # Ensure cv2 is imported
853
- import cv2
854
-
855
- for i, depth_map in enumerate(depth_maps):
856
- _check_cancellation(job_id)
857
-
858
- # Norm: (val - min) / (max - min) -> 0..1
859
- # Clip to ensure range
860
- norm_map = np.clip(depth_map, global_min, global_max)
861
- norm_map = (norm_map - global_min) / (global_max - global_min + 1e-6)
862
 
863
- # Invert intensity? Usually Near(High val) -> Bright(1.0).
864
- # Our val is high for near. So direct map is fine.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
865
 
866
- # Colorize
867
- norm_map_u8 = (norm_map * 255).astype(np.uint8)
868
- heatmap = cv2.applyColorMap(norm_map_u8, cv2.COLORMAP_INFERNO)
869
 
870
- # Overlay detections if available
871
- if detections and i < len(detections):
872
- frame_dets = detections[i]
873
- # Convert list of dicts to format for draw_boxes
874
- if frame_dets:
875
- boxes = []
876
- labels = []
877
- display_labels = []
878
-
879
- for d in frame_dets:
880
- boxes.append(d.get("bbox"))
881
- # Create label "Class Dist"
882
- lbl = d.get("label", "obj")
883
- # If we have depth info that was calculated in inference:
884
- if d.get("depth_est_m"):
885
- lbl = f"{lbl} {int(d['depth_est_m'])}m"
886
-
887
- labels.append(lbl) # used for color
888
- display_labels.append(lbl)
889
-
890
- heatmap = draw_boxes(heatmap, boxes, label_names=display_labels)
891
 
892
- visualization_frames.append(heatmap)
 
 
893
 
894
- return visualization_frames
895
 
896
 
897
  def colorize_depth_map(
 
1
  import logging
2
  import os
3
+ import time
4
+ from threading import RLock, Thread
5
+ from queue import Queue, PriorityQueue
6
  from typing import Any, Dict, List, Optional, Sequence, Tuple
7
 
8
  import cv2
 
14
  from models.model_loader import load_detector, load_detector_on_device
15
  from models.segmenters.model_loader import load_segmenter, load_segmenter_on_device
16
  from models.depth_estimators.model_loader import load_depth_estimator, load_depth_estimator_on_device
17
+ from utils.video import extract_frames, write_video, VideoReader, VideoWriter
18
 
19
 
20
  def _check_cancellation(job_id: Optional[str]) -> None:
 
29
  raise RuntimeError("Job cancelled by user")
30
 
31
 
32
+ def _color_for_label(label: str) -> Tuple[int, int, int]:
33
  # Deterministic BGR color from label text.
34
  value = abs(hash(label)) % 0xFFFFFF
35
  blue = value & 0xFF
 
279
  depth_scale: float = 1.0,
280
  detector_instance: Optional[ObjectDetector] = None,
281
  depth_estimator_instance: Optional[Any] = None,
282
+ ) -> Tuple[np.ndarray, List[Dict[str, Any]]]:
283
  if detector_instance:
284
  detector = detector_instance
285
  else:
 
334
  text_queries: Optional[List[str]] = None,
335
  segmenter_name: Optional[str] = None,
336
  segmenter_instance: Optional[Any] = None,
337
+ ) -> Tuple[np.ndarray, Any]:
338
  if segmenter_instance:
339
  segmenter = segmenter_instance
340
  # Use instance lock if available
 
408
  job_id: Optional[str] = None,
409
  depth_estimator_name: Optional[str] = None,
410
  depth_scale: float = 1.0,
411
+ ) -> Tuple[str, List[List[Dict[str, Any]]]]:
412
+
413
+ # 1. Setup Video Reader
 
 
 
 
 
 
 
 
 
 
 
414
  try:
415
+ reader = VideoReader(input_video_path)
416
+ except ValueError:
417
+ logging.exception("Failed to open video at %s", input_video_path)
418
  raise
419
 
420
+ fps = reader.fps
421
+ width = reader.width
422
+ height = reader.height
423
+ total_frames = reader.total_frames
424
+
425
+ if max_frames is not None:
426
+ total_frames = min(total_frames, max_frames)
427
+
428
+ # 2. Defaults and Config
429
  if not queries:
430
  queries = ["person", "car", "truck", "motorcycle", "bicycle", "bus", "train", "airplane"]
431
  logging.info("No queries provided, using defaults: %s", queries)
432
+
433
  logging.info("Detection queries: %s", queries)
 
 
434
  active_detector = detector_name or "hf_yolov8"
435
+
436
+ # 3. Parallel Model Loading
437
+ num_gpus = torch.cuda.device_count()
438
+ detectors = []
439
+ depth_estimators = []
440
+
441
+ # Clear CUDA_VISIBLE_DEVICES to ensure we see all GPUs if not already handled
442
  if "CUDA_VISIBLE_DEVICES" in os.environ:
443
+ del os.environ["CUDA_VISIBLE_DEVICES"]
 
444
 
445
+ if num_gpus > 0:
446
+ logging.info("Detected %d GPUs. Loading models in parallel...", num_gpus)
447
+
448
+ def load_models_on_gpu(gpu_id: int):
449
+ device_str = f"cuda:{gpu_id}"
450
+ try:
451
+ det = load_detector_on_device(active_detector, device_str)
452
+ det.lock = RLock()
453
+
454
+ depth = None
455
+ if depth_estimator_name:
456
+ depth = load_depth_estimator_on_device(depth_estimator_name, device_str)
457
+ depth.lock = RLock()
458
+ return (gpu_id, det, depth)
459
+ except Exception as e:
460
+ logging.error(f"Failed to load models on GPU {gpu_id}: {e}")
461
+ raise
462
+
463
+ with ThreadPoolExecutor(max_workers=num_gpus) as loader_pool:
464
+ futures = [loader_pool.submit(load_models_on_gpu, i) for i in range(num_gpus)]
465
+ results = [f.result() for f in futures]
 
 
466
 
467
+ # Sort by GPU ID to ensure consistent indexing
468
+ results.sort(key=lambda x: x[0])
469
+ for _, det, depth in results:
470
+ detectors.append(det)
 
 
 
 
 
471
  depth_estimators.append(depth)
 
 
 
472
  else:
473
+ logging.info("No GPUs detected. Loading CPU models...")
474
+ det = load_detector(active_detector)
475
+ det.lock = RLock()
476
+ detectors.append(det)
477
+ if depth_estimator_name:
478
+ depth = load_depth_estimator(depth_estimator_name)
479
+ depth.lock = RLock()
480
+ depth_estimators.append(depth)
481
+ else:
482
+ depth_estimators.append(None)
483
 
484
+ # 4. Processing Queues
485
+ # queue_in: (frame_idx, frame_data)
486
+ # queue_out: (frame_idx, processed_frame, detections)
487
+ queue_in = Queue(maxsize=16)
488
+ queue_out = Queue() # Unbounded, consumed by writer
489
 
490
+ # 5. Worker Function
491
+ def worker_task(gpu_idx: int):
492
+ detector_instance = detectors[gpu_idx]
493
+ depth_instance = depth_estimators[gpu_idx] if depth_estimators[gpu_idx] else None
494
+
495
+ while True:
496
+ item = queue_in.get()
497
+ if item is None:
498
+ queue_in.task_done()
499
+ break
500
+
501
+ frame_idx, frame_data = item
502
 
503
  if frame_idx % 30 == 0:
504
+ logging.info("Processing frame %d on device %s", frame_idx, "cpu" if num_gpus==0 else f"cuda:{gpu_idx}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
505
 
506
+ try:
507
+ # Depth strategy: Run every 3 frames
508
+ active_depth_name = depth_estimator_name if (frame_idx % 3 == 0) else None
509
+ active_depth_instance = depth_instance if (frame_idx % 3 == 0) else None
510
+
511
+ processed, frame_dets = infer_frame(
512
+ frame_data,
513
+ queries,
514
+ detector_name=None,
515
+ depth_estimator_name=active_depth_name,
516
+ depth_scale=depth_scale,
517
+ detector_instance=detector_instance,
518
+ depth_estimator_instance=active_depth_instance
519
+ )
520
+ queue_out.put((frame_idx, processed, frame_dets))
521
+ except Exception as e:
522
+ logging.exception("Error processing frame %d", frame_idx)
523
+ # Put placeholders to avoid hanging writer
524
+ queue_out.put((frame_idx, frame_data, []))
525
 
526
+ queue_in.task_done()
527
+
528
+ # 6. Start Workers
529
+ workers = []
530
+ num_workers = len(detectors)
531
+ # If using CPU, maybe use more threads? No, CPU models usually multithread internally.
532
+ # If using GPU, 1 thread per GPU is efficient.
533
+ for i in range(num_workers):
534
+ t = Thread(target=worker_task, args=(i,), daemon=True)
535
+ t.start()
536
+ workers.append(t)
537
+
538
+ # 7. Start Writer / Output Collection (Main Thread or separate)
539
+ # We will run writer logic in the main thread after feeding is done?
540
+ # No, we must write continuously.
541
+
542
+ all_detections_map = {}
543
+
544
+ writer_finished = False
545
+
546
+ def writer_loop():
547
+ nonlocal writer_finished
548
+ next_idx = 0
549
+ buffer = {}
550
 
551
+ try:
552
+ with VideoWriter(output_video_path, fps, width, height) as writer:
553
+ while next_idx < total_frames:
554
+ # Fetch from queue
555
+ try:
556
+ while next_idx not in buffer:
557
+ item = queue_out.get(timeout=1.0) # wait
558
+ idx, p_frame, dets = item
559
+ buffer[idx] = (p_frame, dets)
560
+
561
+ # Write next_idx
562
+ p_frame, dets = buffer.pop(next_idx)
563
+ writer.write(p_frame)
564
+ all_detections_map[next_idx] = dets
565
+ next_idx += 1
566
+
567
+ if next_idx % 30 == 0:
568
+ logging.debug("Wrote frame %d/%d", next_idx, total_frames)
569
+
570
+ except Exception as e:
571
+ # Check cancellation or timeout
572
+ if job_id and _check_cancellation(job_id): # This raises
573
+ pass
574
+ if not any(w.is_alive() for w in workers) and queue_out.empty():
575
+ # Workers dead, queue empty, but not finished? prevent infinite loop
576
+ logging.error("Workers stopped unexpectedly.")
577
+ break
578
+ continue
579
+ except Exception as e:
580
+ logging.exception("Writer loop failed")
581
+ finally:
582
+ writer_finished = True
583
+
584
+ writer_thread = Thread(target=writer_loop, daemon=True)
585
+ writer_thread.start()
586
+
587
+ # 8. Feed Frames (Main Thread)
588
+ try:
589
+ frames_fed = 0
590
+ for i, frame in enumerate(reader):
591
  _check_cancellation(job_id)
592
+ if max_frames is not None and i >= max_frames:
 
593
  break
 
594
 
595
+ queue_in.put((i, frame)) # Blocks if full
596
+ frames_fed += 1
 
597
 
598
+ # Signal workers to stop
599
+ for _ in range(num_workers):
600
+ queue_in.put(None)
601
+
602
+ # Wait for queue to process
603
+ queue_in.join()
604
+
605
+ except Exception as e:
606
+ logging.exception("Feeding frames failed")
607
+ raise
608
+ finally:
609
+ reader.close()
610
+
611
+ # Wait for writer
612
+ writer_thread.join()
613
 
614
+ # Sort detections
615
+ sorted_detections = []
616
+ # If we crashed early, we return what we have
617
+ max_key = max(all_detections_map.keys()) if all_detections_map else -1
618
+ for i in range(max_key + 1):
619
+ sorted_detections.append(all_detections_map.get(i, []))
620
+
621
+ logging.info("Inference complete. Output: %s", output_video_path)
622
+ return output_video_path, sorted_detections
623
 
 
624
 
625
 
626
  def run_segmentation(
 
631
  segmenter_name: Optional[str] = None,
632
  job_id: Optional[str] = None,
633
  ) -> str:
634
+ # 1. Setup Reader
635
  try:
636
+ reader = VideoReader(input_video_path)
637
+ except ValueError:
638
+ logging.exception("Failed to open video at %s", input_video_path)
639
  raise
640
 
641
+ fps = reader.fps
642
+ width = reader.width
643
+ height = reader.height
644
+ total_frames = reader.total_frames
645
+
646
+ if max_frames is not None:
647
+ total_frames = min(total_frames, max_frames)
648
+
649
  active_segmenter = segmenter_name or "sam3"
650
  logging.info("Using segmenter: %s with queries: %s", active_segmenter, queries)
651
 
652
+ # 2. Load Segmenters (Parallel)
653
  num_gpus = torch.cuda.device_count()
654
+ segmenters = []
655
+
656
+ if "CUDA_VISIBLE_DEVICES" in os.environ:
657
+ del os.environ["CUDA_VISIBLE_DEVICES"]
658
+
659
+ if num_gpus > 0:
660
+ logging.info("Detected %d GPUs. Loading segmenters...", num_gpus)
661
+ def load_seg(gpu_id: int):
662
+ device_str = f"cuda:{gpu_id}"
663
  seg = load_segmenter_on_device(active_segmenter, device_str)
664
  seg.lock = RLock()
665
+ return (gpu_id, seg)
666
+
667
+ with ThreadPoolExecutor(max_workers=num_gpus) as loader:
668
+ futures = [loader.submit(load_seg, i) for i in range(num_gpus)]
669
+ results = [f.result() for f in futures]
670
+ results.sort(key=lambda x: x[0])
671
+ segmenters = [r[1] for r in results]
672
  else:
673
+ seg = load_segmenter(active_segmenter)
674
+ seg.lock = RLock()
675
+ segmenters.append(seg)
676
 
677
+ # 3. Processing
678
+ queue_in = Queue(maxsize=16)
679
+ queue_out = Queue()
680
+
681
+ def worker_seg(gpu_idx: int):
682
+ seg = segmenters[gpu_idx]
683
+ while True:
684
+ item = queue_in.get()
685
+ if item is None:
686
+ queue_in.task_done()
687
+ break
688
 
689
+ idx, frame = item
 
690
 
691
+ if idx % 30 == 0:
692
+ logging.info("Segmenting frame %d (GPU %d)", idx, gpu_idx)
 
 
 
 
 
693
 
694
+ try:
695
+ processed, _ = infer_segmentation_frame(
696
+ frame,
697
+ text_queries=queries,
698
+ segmenter_name=None,
699
+ segmenter_instance=seg
700
+ )
701
+ queue_out.put((idx, processed))
702
+ except Exception as e:
703
+ logging.error("Segmentation failed frame %d: %s", idx, e)
704
+ queue_out.put((idx, frame))
705
 
706
+ queue_in.task_done()
 
 
 
 
707
 
708
+ workers = []
709
+ for i in range(len(segmenters)):
710
+ t = Thread(target=worker_seg, args=(i,), daemon=True)
711
+ t.start()
712
+ workers.append(t)
713
 
714
+ # Writer
715
+ writer_finished = False
716
+
717
+ def writer_loop():
718
+ nonlocal writer_finished
719
+ next_idx = 0
720
+ buffer = {}
721
+
722
+ try:
723
+ with VideoWriter(output_video_path, fps, width, height) as writer:
724
+ while next_idx < total_frames:
725
+ try:
726
+ while next_idx not in buffer:
727
+ idx, frm = queue_out.get(timeout=1.0)
728
+ buffer[idx] = frm
729
+
730
+ frm = buffer.pop(next_idx)
731
+ writer.write(frm)
732
+ next_idx += 1
733
+ except Exception:
734
+ if job_id and _check_cancellation(job_id): pass
735
+ if not any(w.is_alive() for w in workers) and queue_out.empty():
736
+ break
737
+ continue
738
+ finally:
739
+ writer_finished = True
740
+
741
+ w_thread = Thread(target=writer_loop, daemon=True)
742
+ w_thread.start()
743
+
744
+ # Feeder
745
+ try:
746
+ reader = VideoReader(input_video_path)
747
+ for i, frame in enumerate(reader):
748
+ _check_cancellation(job_id)
749
+ if max_frames is not None and i >= max_frames:
750
  break
751
+ queue_in.put((i, frame))
752
+
753
+ for _ in workers:
754
+ queue_in.put(None)
755
+ queue_in.join()
756
+
757
+ finally:
758
+ reader.close()
759
+
760
+ w_thread.join()
761
+
762
  logging.info("Segmented video written to: %s", output_video_path)
 
763
  return output_video_path
764
 
765
 
766
+
767
  def run_depth_inference(
768
  input_video_path: str,
769
  output_video_path: str,
 
773
  first_frame_depth_path: Optional[str] = None,
774
  job_id: Optional[str] = None,
775
  ) -> str:
776
+ # 1. Setup Reader
 
 
 
 
 
 
 
 
 
 
 
 
 
777
  try:
778
+ reader = VideoReader(input_video_path)
779
+ except ValueError:
780
+ logging.exception("Failed to open video at %s", input_video_path)
781
  raise
782
 
783
+ fps = reader.fps
784
+ width = reader.width
785
+ height = reader.height
786
+ total_frames = reader.total_frames
787
+
788
  if max_frames is not None:
789
+ total_frames = min(total_frames, max_frames)
790
+
791
+ logging.info("Using depth estimator: %s", depth_estimator_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
792
 
793
+ # 2. Load Estimators (Parallel)
794
  num_gpus = torch.cuda.device_count()
795
+ estimators = []
796
+
797
+ if "CUDA_VISIBLE_DEVICES" in os.environ:
798
+ del os.environ["CUDA_VISIBLE_DEVICES"]
799
+
800
+ if num_gpus > 0:
801
+ logging.info("Detected %d GPUs. Loading depth estimators...", num_gpus)
802
+ def load_est(gpu_id: int):
803
+ device_str = f"cuda:{gpu_id}"
804
  est = load_depth_estimator_on_device(depth_estimator_name, device_str)
805
  est.lock = RLock()
806
+ return (gpu_id, est)
807
+
808
+ with ThreadPoolExecutor(max_workers=num_gpus) as loader:
809
+ futures = [loader.submit(load_est, i) for i in range(num_gpus)]
810
+ results = [f.result() for f in futures]
811
+ results.sort(key=lambda x: x[0])
812
+ estimators = [r[1] for r in results]
813
  else:
814
+ est = load_depth_estimator(depth_estimator_name)
815
+ est.lock = RLock()
816
+ estimators.append(est)
817
+
818
+ # 3. Phase 1: Pre-scan for Stats
819
+ # We sample ~5% of frames or at least 20 frames distributed evenly
820
+ stride = max(1, total_frames // 20)
821
+ logging.info("Starting Phase 1: Pre-scan (stride=%d)...", stride)
822
+
823
+ scan_values = []
824
+
825
+ def scan_task(gpu_idx: int, frame_data: np.ndarray):
826
+ est = estimators[gpu_idx]
827
+ with est.lock:
828
+ result = est.predict(frame_data)
829
+ return result.depth_map
830
+
831
+ # Run scan
832
+ # We can just run this sequentially or with pool? Pool is better.
833
+ # We need to construct a list of frames to scan.
834
+ scan_indices = list(range(0, total_frames, stride))
835
+
836
+ # We need to read specific frames. VideoReader is sequential.
837
+ # So we iterate and skip.
838
+ scan_frames = []
839
+
840
+ # Optimization: If total frames is huge, reading simply to skip might be slow?
841
+ # VideoReader uses cv2.read() which decodes.
842
+ # If we need random access, we should use set(cv2.CAP_PROP_POS_FRAMES).
843
+ # But for now, simple skip logic:
844
+
845
+ current_idx = 0
846
+ # To avoid re-opening multiple times or complex seeking, let's just use the Reader
847
+ # and skip if not in indices.
848
+ # BUT, if video is 1 hour, skipping 99% frames is wastage of decode.
849
+ # Re-opening with set POS is better for sparse sampling.
850
+
851
+ # Actually, for robustness, let's just stick to VideoReader sequential read but only process selective frames.
852
+ # If the video is truly huge, we might want to optimize this later.
853
+ # Given the constraints, let's just scan the first N frames + some middle ones?
854
+ # User agreed to "Small startup delay".
855
+
856
+ # Let's try to just grab the frames we want.
857
+ scan_frames_data = []
858
+
859
+ # Just grab first 50 frames? No, distribution is better.
860
+ # Let's use a temporary reader for scanning
861
+
862
+ try:
863
+ from concurrent.futures import as_completed
864
+
865
+ # Simple Approach: Process first 30 frames to get a baseline.
866
+ # This is usually enough for a "rough" estimation unless scenes change drastically.
867
+ # But for stability, spread is better.
868
+
869
+ # Let's read first 10, middle 10, last 10.
870
+ target_indices = set(list(range(0, 10)) +
871
+ list(range(total_frames//2, total_frames//2 + 10)) +
872
+ list(range(max(0, total_frames-10), total_frames)))
873
+
874
+ # Filter valid
875
+ target_indices = sorted([i for i in target_indices if i < total_frames])
876
+
877
+ # Manual read with seek is tricky with cv2 (unreliable keyframes).
878
+ # We will iterate and pick.
879
+
880
+ cnt = 0
881
+ reader_scan = VideoReader(input_video_path)
882
+ for i, frame in enumerate(reader_scan):
883
+ if i in target_indices:
884
+ scan_frames_data.append(frame)
885
+ if i > max(target_indices):
886
+ break
887
+ reader_scan.close()
888
+
889
+ # Run inference on these frames
890
+ with ThreadPoolExecutor(max_workers=min(len(estimators)*2, 8)) as pool:
891
+ futures = []
892
+ for i, frm in enumerate(scan_frames_data):
893
+ gpu = i % len(estimators)
894
+ futures.append(pool.submit(scan_task, gpu, frm))
895
 
896
+ for f in as_completed(futures):
897
+ dm = f.result()
898
+ scan_values.append(dm)
899
+
900
+ except Exception as e:
901
+ logging.warning("Pre-scan failed, falling back to default range: %s", e)
902
+
903
+ # Compute stats
904
+ global_min, global_max = 0.0, 1.0
905
+ if scan_values:
906
+ all_vals = np.concatenate([v.ravel() for v in scan_values])
907
+ valid = all_vals[np.isfinite(all_vals)]
908
+ if valid.size > 0:
909
+ global_min = float(np.percentile(valid, 1))
910
+ global_max = float(np.percentile(valid, 99))
911
 
912
+ # Safety
913
+ if abs(global_max - global_min) < 1e-6:
914
+ global_max = global_min + 1.0
915
+
916
+ logging.info("Global Depth Range: %.2f - %.2f", global_min, global_max)
917
+
918
+ # 4. Phase 2: Streaming Inference
919
+ logging.info("Starting Phase 2: Streaming...")
920
+
921
+ queue_in = Queue(maxsize=16)
922
+ queue_out = Queue()
923
+
924
+ def worker_depth(gpu_idx: int):
925
+ est = estimators[gpu_idx]
926
+ while True:
927
+ item = queue_in.get()
928
+ if item is None:
929
+ queue_in.task_done()
930
+ break
931
 
932
+ idx, frame = item
933
+ try:
934
+ if idx % 30 == 0:
935
+ logging.info("Depth frame %d (GPU %d)", idx, gpu_idx)
936
+
937
+ with est.lock:
938
+ res = est.predict(frame)
939
+
940
+ depth_map = res.depth_map
941
+ # Colorize
942
+ colored = colorize_depth_map(depth_map, global_min, global_max)
943
+
944
+ # Overlay Detections
945
+ # Detections list is [ [det1, det2], [det1, det2] ... ]
946
+ if detections and idx < len(detections):
947
+ frame_dets = detections[idx]
948
+ if frame_dets:
949
+ import cv2
950
+ boxes = []
951
+ labels = []
952
+ for d in frame_dets:
953
+ boxes.append(d.get("bbox"))
954
+ lbl = d.get("label", "obj")
955
+ if d.get("depth_est_m"):
956
+ lbl = f"{lbl} {int(d['depth_est_m'])}m"
957
+ labels.append(lbl)
958
+ colored = draw_boxes(colored, boxes=boxes, label_names=labels)
959
+
960
+ queue_out.put((idx, colored))
961
+ except Exception as e:
962
+ logging.error("Depth worker failed frame %d: %s", idx, e)
963
+ queue_out.put((idx, frame)) # Fallback to original?
964
+
965
+ queue_in.task_done()
966
 
967
+ # Workers
968
+ workers = []
969
+ for i in range(len(estimators)):
970
+ t = Thread(target=worker_depth, args=(i,), daemon=True)
971
+ t.start()
972
+ workers.append(t)
973
 
974
+ # Writer
975
+ writer_finished = False
976
+ first_frame_saved = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
977
 
978
+ def writer_loop():
979
+ nonlocal writer_finished, first_frame_saved
980
+ next_idx = 0
981
+ buffer = {}
982
+ processed_frames_subset = [] # Keep first frame for saving if needed
 
 
 
 
 
 
983
 
984
+ try:
985
+ with VideoWriter(output_video_path, fps, width, height) as writer:
986
+ while next_idx < total_frames:
987
+ try:
988
+ while next_idx not in buffer:
989
+ idx, frm = queue_out.get(timeout=1.0)
990
+ buffer[idx] = frm
991
+
992
+ frm = buffer.pop(next_idx)
993
+ writer.write(frm)
994
+
995
+ if first_frame_depth_path and not first_frame_saved and next_idx == 0:
996
+ cv2.imwrite(first_frame_depth_path, frm)
997
+ first_frame_saved = True
998
+
999
+ next_idx += 1
1000
+ if next_idx % 30 == 0:
1001
+ logging.debug("Wrote depth frame %d/%d", next_idx, total_frames)
1002
+ except Exception:
1003
+ if job_id and _check_cancellation(job_id): pass
1004
+ if not any(w.is_alive() for w in workers) and queue_out.empty():
1005
+ break
1006
+ continue
1007
+ finally:
1008
+ writer_finished = True
1009
+
1010
+ w_thread = Thread(target=writer_loop, daemon=True)
1011
+ w_thread.start()
1012
+
1013
+ # Feeder
1014
+ try:
1015
+ reader = VideoReader(input_video_path)
1016
+ for i, frame in enumerate(reader):
1017
+ _check_cancellation(job_id)
1018
+ if max_frames is not None and i >= max_frames:
1019
+ break
1020
+ queue_in.put((i, frame))
1021
 
1022
+ for _ in workers:
1023
+ queue_in.put(None)
1024
+ queue_in.join()
1025
 
1026
+ finally:
1027
+ reader.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1028
 
1029
+ w_thread.join()
1030
+
1031
+ return output_video_path
1032
 
 
1033
 
1034
 
1035
  def colorize_depth_map(
utils/video.py CHANGED
@@ -77,3 +77,80 @@ def write_video(frames: List[np.ndarray], output_path: str, fps: float, width: i
77
  except RuntimeError as exc:
78
  logging.warning("ffmpeg transcode failed (%s); serving fallback MP4V output.", exc)
79
  shutil.move(temp_path, output_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  except RuntimeError as exc:
78
  logging.warning("ffmpeg transcode failed (%s); serving fallback MP4V output.", exc)
79
  shutil.move(temp_path, output_path)
80
+
81
class VideoReader:
    """Sequential frame reader over a video file.

    Usable both as an iterator (yields BGR frames as ``np.ndarray``) and as a
    context manager that releases the underlying capture on exit.
    """

    def __init__(self, video_path: str):
        self.video_path = video_path
        self.cap = cv2.VideoCapture(video_path)
        if not self.cap.isOpened():
            raise ValueError("Unable to open video.")

        # Cache stream metadata up front; fall back to 30 fps when the
        # container reports 0 (some streams have no fps metadata).
        self.fps = self.cap.get(cv2.CAP_PROP_FPS) or 30.0
        self.width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

    def __iter__(self):
        return self

    def __next__(self) -> np.ndarray:
        # A capture that was already closed simply signals exhaustion.
        if not self.cap.isOpened():
            raise StopIteration

        ok, frame = self.cap.read()
        if ok:
            return frame

        # End of stream: release the handle before signalling exhaustion so
        # repeated iteration attempts stay cheap.
        self.cap.release()
        raise StopIteration

    def close(self):
        """Release the capture handle; safe to call multiple times."""
        if self.cap.isOpened():
            self.cap.release()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
115
+
116
+
117
class VideoWriter:
    """Streaming video writer with browser-friendly output.

    Frames are written with the fast MP4V codec to a temp file; on close the
    temp file is transcoded to H.264 at ``output_path`` (falling back to the
    raw MP4V file when ffmpeg is unavailable or fails).
    """

    def __init__(self, output_path: str, fps: float, width: int, height: int):
        self.output_path = output_path
        self.fps = fps
        self.width = width
        self.height = height
        # Guards close() so it is safe to call more than once (e.g. an
        # explicit close() followed by context-manager __exit__).
        self._closed = False

        self.temp_fd, self.temp_path = tempfile.mkstemp(prefix="raw_", suffix=".mp4")
        os.close(self.temp_fd)

        # Use mp4v for speed during writing, then transcode
        self.writer = cv2.VideoWriter(
            self.temp_path,
            cv2.VideoWriter_fourcc(*"mp4v"),
            self.fps,
            (self.width, self.height),
        )
        if not self.writer.isOpened():
            os.remove(self.temp_path)
            raise ValueError("Failed to open VideoWriter.")

    def write(self, frame: np.ndarray):
        """Append one BGR frame; must match the width/height given to __init__."""
        self.writer.write(frame)

    def close(self):
        """Finalize the temp file and transcode it to the destination path.

        Idempotent: without the ``_closed`` guard a second call would re-run
        the transcode phase after ``temp_path`` had already been removed or
        moved, crashing on the missing file.
        """
        if self._closed:
            return
        self._closed = True

        if self.writer.isOpened():
            self.writer.release()

        # Transcode phase
        try:
            _transcode_with_ffmpeg(self.temp_path, self.output_path)
            logging.debug("Transcoded video to H.264 for browser compatibility.")
            os.remove(self.temp_path)
        except FileNotFoundError:
            logging.warning("ffmpeg not found; serving fallback MP4V output.")
            shutil.move(self.temp_path, self.output_path)
        except RuntimeError as exc:
            logging.warning("ffmpeg transcode failed (%s); serving fallback MP4V output.", exc)
            shutil.move(self.temp_path, self.output_path)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()