"""In-memory shared frame store to eliminate redundant JPEG encoding/decoding. Replaces the pipeline: MP4 → cv2 decode → JPEG encode to disk → N GPUs each decode all JPEGs back With: MP4 → cv2 decode once → SharedFrameStore in RAM → all GPUs read from same memory """ import logging from typing import Optional import cv2 import numpy as np import torch from PIL import Image class MemoryBudgetExceeded(Exception): """Raised when estimated memory usage exceeds the configured ceiling.""" def __init__(self, estimated_bytes: int): self.estimated_bytes = estimated_bytes super().__init__( f"Estimated memory {estimated_bytes / 1024**3:.1f} GiB exceeds budget" ) class SharedFrameStore: """Read-only in-memory store for decoded video frames (BGR uint8). Decodes the video once via cv2.VideoCapture and holds all frames in a list. Thread-safe for concurrent reads (frames list is never mutated after init). Raises MemoryBudgetExceeded BEFORE decoding if estimated memory exceeds the budget ceiling, giving callers a chance to fall back to JPEG path. """ MAX_BUDGET_BYTES = 12 * 1024**3 # 12 GiB ceiling def __init__(self, video_path: str, max_frames: Optional[int] = None): cap = cv2.VideoCapture(video_path) if not cap.isOpened(): raise RuntimeError(f"Cannot open video: {video_path}") self.fps = cap.get(cv2.CAP_PROP_FPS) or 30.0 self.width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) self.height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # Estimate frame count BEFORE decoding to check memory budget reported_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) if reported_count <= 0: reported_count = 10000 # conservative fallback est_frames = min(reported_count, max_frames) if max_frames else reported_count # Budget: raw BGR frames + worst-case SAM2 adapter tensors (image_size=1024) per_frame_raw = self.height * self.width * 3 # uint8 BGR per_frame_adapter = 3 * 1024 * 1024 * 4 # float32, worst-case 1024x1024 total_est = est_frames * (per_frame_raw + per_frame_adapter) if total_est > self.MAX_BUDGET_BYTES: cap.release() logging.warning( "SharedFrameStore: estimated ~%.1f GiB for %d frames exceeds " "%.1f GiB budget; skipping in-memory path", total_est / 1024**3, est_frames, self.MAX_BUDGET_BYTES / 1024**3, ) raise MemoryBudgetExceeded(total_est) frames = [] while True: if max_frames is not None and len(frames) >= max_frames: break ret, frame = cap.read() if not ret: break frames.append(frame) cap.release() if not frames: raise RuntimeError(f"No frames decoded from: {video_path}") self.frames = frames logging.info( "SharedFrameStore: %d frames, %dx%d, %.1f fps", len(self.frames), self.width, self.height, self.fps, ) def __len__(self) -> int: return len(self.frames) def get_bgr(self, idx: int) -> np.ndarray: """Return BGR frame. Caller must .copy() if mutating.""" return self.frames[idx] def get_pil_rgb(self, idx: int) -> Image.Image: """Return PIL RGB Image for the given frame index.""" bgr = self.frames[idx] rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) return Image.fromarray(rgb) def sam2_adapter(self, image_size: int) -> "SAM2FrameAdapter": """Factory for SAM2-compatible frame adapter. Returns same adapter for same size.""" if not hasattr(self, "_adapters"): self._adapters = {} if image_size not in self._adapters: self._adapters[image_size] = SAM2FrameAdapter(self, image_size) return self._adapters[image_size] class SAM2FrameAdapter: """Drop-in replacement for SAM2's AsyncVideoFrameLoader. Matches the interface that SAM2's init_state / propagate_in_video expects: - __len__() → number of frames - __getitem__(idx) → normalized float32 tensor (3, H, W) - .images list (SAM2 accesses this directly in some paths) - .video_height, .video_width - .exception (AsyncVideoFrameLoader compat) Transform parity: uses PIL Image.resize() with BICUBIC (the default), matching SAM2's _load_img_as_tensor exactly. """ def __init__(self, store: SharedFrameStore, image_size: int): self._store = store self._image_size = image_size self.images = [None] * len(store) # SAM2 accesses .images directly self.video_height = store.height self.video_width = store.width self.exception = None # AsyncVideoFrameLoader compat # ImageNet normalization constants (must match SAM2's _load_img_as_tensor) self._mean = torch.tensor([0.485, 0.456, 0.406]).reshape(3, 1, 1) self._std = torch.tensor([0.229, 0.224, 0.225]).reshape(3, 1, 1) def __len__(self) -> int: return len(self._store) def __getitem__(self, idx: int) -> torch.Tensor: if self.images[idx] is not None: return self.images[idx] # TRANSFORM PARITY: Must match SAM2's _load_img_as_tensor exactly. # SAM2 does: PIL Image → .convert("RGB") → .resize((size, size)) → /255 → permute → normalize # PIL.resize default = BICUBIC. We must use PIL resize, NOT cv2.resize. bgr = self._store.get_bgr(idx) pil_img = Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)) pil_resized = pil_img.resize( (self._image_size, self._image_size) ) # BICUBIC default img_np = np.array(pil_resized) / 255.0 img = torch.from_numpy(img_np).permute(2, 0, 1).float() img = (img - self._mean) / self._std self.images[idx] = img return img