detection_base/utils/frame_store.py
Zhen Ye
Eliminate redundant JPEG frame loading via shared frame store
c97a5f9
"""In-memory shared frame store to eliminate redundant JPEG encoding/decoding.
Replaces the pipeline:
MP4 β†’ cv2 decode β†’ JPEG encode to disk β†’ N GPUs each decode all JPEGs back
With:
MP4 β†’ cv2 decode once β†’ SharedFrameStore in RAM β†’ all GPUs read from same memory
"""
import logging
from typing import Optional
import cv2
import numpy as np
import torch
from PIL import Image
class MemoryBudgetExceeded(Exception):
    """Signals that the projected frame-store footprint is over the ceiling.

    Carries the raw byte estimate on `estimated_bytes` so callers can log it
    or decide on a fallback path.
    """

    def __init__(self, estimated_bytes: int):
        gib = estimated_bytes / 1024**3
        super().__init__(f"Estimated memory {gib:.1f} GiB exceeds budget")
        self.estimated_bytes = estimated_bytes
class SharedFrameStore:
    """Read-only in-memory store for decoded video frames (BGR uint8).

    Decodes the video once via cv2.VideoCapture and holds all frames in a
    list. Safe for concurrent reads: the frames list is never mutated after
    __init__ completes.

    Raises MemoryBudgetExceeded BEFORE decoding if the estimated memory
    exceeds the budget ceiling, giving callers a chance to fall back to the
    JPEG-on-disk path.
    """

    MAX_BUDGET_BYTES = 12 * 1024**3  # 12 GiB ceiling

    def __init__(self, video_path: str, max_frames: Optional[int] = None):
        """Decode `video_path` fully into memory.

        Args:
            video_path: Path to a video readable by cv2.VideoCapture.
            max_frames: Optional cap on the number of frames decoded.

        Raises:
            RuntimeError: the video cannot be opened, or yields no frames.
            MemoryBudgetExceeded: projected footprint exceeds MAX_BUDGET_BYTES.
        """
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise RuntimeError(f"Cannot open video: {video_path}")
        # try/finally guarantees cap.release() on every exit path (budget
        # error, decode exception, or success) so the handle is never leaked.
        try:
            self.fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # 0 fps -> assume 30
            self.width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            self.height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

            # Estimate frame count BEFORE decoding to check memory budget.
            # CAP_PROP_FRAME_COUNT can report <= 0 for some containers;
            # fall back to a conservative guess in that case.
            reported_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if reported_count <= 0:
                reported_count = 10000  # conservative fallback
            # `is not None` (not truthiness) keeps this consistent with the
            # decode loop below when max_frames == 0.
            if max_frames is not None:
                est_frames = min(reported_count, max_frames)
            else:
                est_frames = reported_count

            # Budget: raw BGR frames + worst-case SAM2 adapter tensors
            # (image_size=1024, float32).
            per_frame_raw = self.height * self.width * 3  # uint8 BGR
            per_frame_adapter = 3 * 1024 * 1024 * 4  # float32, worst-case 1024x1024
            total_est = est_frames * (per_frame_raw + per_frame_adapter)
            if total_est > self.MAX_BUDGET_BYTES:
                logging.warning(
                    "SharedFrameStore: estimated ~%.1f GiB for %d frames exceeds "
                    "%.1f GiB budget; skipping in-memory path",
                    total_est / 1024**3, est_frames, self.MAX_BUDGET_BYTES / 1024**3,
                )
                raise MemoryBudgetExceeded(total_est)

            frames = []
            while True:
                if max_frames is not None and len(frames) >= max_frames:
                    break
                ret, frame = cap.read()
                if not ret:
                    break
                frames.append(frame)
        finally:
            cap.release()
        if not frames:
            raise RuntimeError(f"No frames decoded from: {video_path}")
        self.frames = frames
        # Adapter cache keyed by image_size; created eagerly here rather than
        # lazily via hasattr() in sam2_adapter().
        self._adapters: dict = {}
        logging.info(
            "SharedFrameStore: %d frames, %dx%d, %.1f fps",
            len(self.frames), self.width, self.height, self.fps,
        )

    def __len__(self) -> int:
        """Number of decoded frames."""
        return len(self.frames)

    def get_bgr(self, idx: int) -> np.ndarray:
        """Return BGR frame. Caller must .copy() if mutating."""
        return self.frames[idx]

    def get_pil_rgb(self, idx: int) -> Image.Image:
        """Return PIL RGB Image for the given frame index."""
        bgr = self.frames[idx]
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        return Image.fromarray(rgb)

    def sam2_adapter(self, image_size: int) -> "SAM2FrameAdapter":
        """Factory for SAM2-compatible frame adapter. Returns same adapter for same size."""
        if image_size not in self._adapters:
            self._adapters[image_size] = SAM2FrameAdapter(self, image_size)
        return self._adapters[image_size]
class SAM2FrameAdapter:
    """Drop-in replacement for SAM2's AsyncVideoFrameLoader.

    Exposes the surface that SAM2's init_state / propagate_in_video expects:
      - __len__()        -> number of frames
      - __getitem__(idx) -> normalized float32 tensor of shape (3, H, W)
      - .images          -> per-frame cache list (SAM2 reads it directly)
      - .video_height / .video_width
      - .exception       -> AsyncVideoFrameLoader compat, always None here

    Transform parity: frames are resized with PIL Image.resize() using its
    BICUBIC default, exactly matching SAM2's _load_img_as_tensor.
    """

    def __init__(self, store: SharedFrameStore, image_size: int):
        self._store = store
        self._image_size = image_size
        # SAM2 accesses .images directly in some code paths, so the cache
        # must be a plain list indexed by frame number.
        self.images = [None] * len(store)
        self.video_height = store.height
        self.video_width = store.width
        self.exception = None  # AsyncVideoFrameLoader compat
        # ImageNet mean/std, identical to SAM2's _load_img_as_tensor.
        self._mean = torch.tensor([0.485, 0.456, 0.406]).reshape(3, 1, 1)
        self._std = torch.tensor([0.229, 0.224, 0.225]).reshape(3, 1, 1)

    def __len__(self) -> int:
        return len(self._store)

    def __getitem__(self, idx: int) -> torch.Tensor:
        cached = self.images[idx]
        if cached is not None:
            return cached
        # TRANSFORM PARITY: mirror SAM2's _load_img_as_tensor step for step:
        # PIL Image -> RGB -> .resize((size, size)) (BICUBIC default) ->
        # /255 -> permute to CHW -> ImageNet normalize. cv2.resize is NOT
        # bit-identical to PIL resize, so PIL is mandatory here.
        rgb = cv2.cvtColor(self._store.get_bgr(idx), cv2.COLOR_BGR2RGB)
        resized = Image.fromarray(rgb).resize(
            (self._image_size, self._image_size)
        )  # BICUBIC default
        arr = np.array(resized) / 255.0
        tensor = torch.from_numpy(arr).permute(2, 0, 1).float()
        tensor = (tensor - self._mean) / self._std
        self.images[idx] = tensor
        return tensor