Spaces:
Paused
Paused
| """In-memory shared frame store to eliminate redundant JPEG encoding/decoding. | |
| Replaces the pipeline: | |
| MP4 β cv2 decode β JPEG encode to disk β N GPUs each decode all JPEGs back | |
| With: | |
| MP4 β cv2 decode once β SharedFrameStore in RAM β all GPUs read from same memory | |
| """ | |
import logging
from typing import Optional

import cv2
import numpy as np
import torch
from PIL import Image
class MemoryBudgetExceeded(Exception):
    """Raised when estimated memory usage exceeds the configured ceiling."""

    def __init__(self, estimated_bytes: int):
        # Keep the estimate on the exception so callers can log it or
        # decide how to fall back.
        self.estimated_bytes = estimated_bytes
        gib = estimated_bytes / 1024**3
        super().__init__(f"Estimated memory {gib:.1f} GiB exceeds budget")
class SharedFrameStore:
    """Read-only in-memory store for decoded video frames (BGR uint8).

    Decodes the video once via cv2.VideoCapture and holds all frames in a
    list. Safe for concurrent reads (the frames list is never mutated after
    construction).

    Raises MemoryBudgetExceeded BEFORE decoding if estimated memory exceeds
    the budget ceiling, giving callers a chance to fall back to the JPEG path.
    """

    MAX_BUDGET_BYTES = 12 * 1024**3  # 12 GiB ceiling

    def __init__(self, video_path: str, max_frames: Optional[int] = None):
        """Decode ``video_path`` fully into RAM.

        Args:
            video_path: Path to a video readable by cv2.VideoCapture.
            max_frames: Optional cap on the number of frames decoded.

        Raises:
            RuntimeError: If the video cannot be opened or yields no frames.
            MemoryBudgetExceeded: If the pre-decode estimate exceeds
                MAX_BUDGET_BYTES.
        """
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise RuntimeError(f"Cannot open video: {video_path}")
        try:
            self.fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
            self.width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            self.height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

            # Estimate frame count BEFORE decoding to check the memory budget.
            reported_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if reported_count <= 0:
                reported_count = 10000  # container did not report; be conservative
            # FIX: test `is not None` (not truthiness) so the estimate agrees
            # with the decode loop below; previously max_frames=0 estimated
            # the whole video and could spuriously raise MemoryBudgetExceeded.
            if max_frames is not None:
                est_frames = min(reported_count, max_frames)
            else:
                est_frames = reported_count

            # Budget: raw BGR frames + worst-case SAM2 adapter tensors
            # (image_size=1024, float32).
            per_frame_raw = self.height * self.width * 3  # uint8 BGR
            per_frame_adapter = 3 * 1024 * 1024 * 4  # float32, worst-case 1024x1024
            total_est = est_frames * (per_frame_raw + per_frame_adapter)
            if total_est > self.MAX_BUDGET_BYTES:
                logging.warning(
                    "SharedFrameStore: estimated ~%.1f GiB for %d frames exceeds "
                    "%.1f GiB budget; skipping in-memory path",
                    total_est / 1024**3, est_frames, self.MAX_BUDGET_BYTES / 1024**3,
                )
                raise MemoryBudgetExceeded(total_est)

            frames = []
            while True:
                if max_frames is not None and len(frames) >= max_frames:
                    break
                ret, frame = cap.read()
                if not ret:
                    break
                frames.append(frame)
        finally:
            # FIX: always release the capture, even if estimation or decoding
            # raises part-way through.
            cap.release()
        if not frames:
            raise RuntimeError(f"No frames decoded from: {video_path}")
        self.frames = frames
        self._adapters = {}  # image_size -> SAM2FrameAdapter cache
        logging.info(
            "SharedFrameStore: %d frames, %dx%d, %.1f fps",
            len(self.frames), self.width, self.height, self.fps,
        )

    def __len__(self) -> int:
        return len(self.frames)

    def get_bgr(self, idx: int) -> np.ndarray:
        """Return the BGR frame at ``idx``. Caller must .copy() if mutating."""
        return self.frames[idx]

    def get_pil_rgb(self, idx: int) -> Image.Image:
        """Return a PIL RGB Image for the given frame index."""
        bgr = self.frames[idx]
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        return Image.fromarray(rgb)

    def sam2_adapter(self, image_size: int) -> "SAM2FrameAdapter":
        """Return a SAM2-compatible frame adapter, cached per image_size."""
        # hasattr guard kept for instances created without __init__
        # (e.g. via __new__ or unpickling).
        if not hasattr(self, "_adapters"):
            self._adapters = {}
        if image_size not in self._adapters:
            self._adapters[image_size] = SAM2FrameAdapter(self, image_size)
        return self._adapters[image_size]
class SAM2FrameAdapter:
    """Drop-in replacement for SAM2's AsyncVideoFrameLoader.

    Matches the interface that SAM2's init_state / propagate_in_video expects:
      - __len__()        -> number of frames
      - __getitem__(idx) -> normalized float32 tensor (3, H, W)
      - .images          list (SAM2 accesses this directly in some paths)
      - .video_height, .video_width
      - .exception       (AsyncVideoFrameLoader compat)

    Transform parity: uses PIL Image.resize() with BICUBIC (the default),
    matching SAM2's _load_img_as_tensor exactly.
    """

    def __init__(self, store: SharedFrameStore, image_size: int):
        self._store = store
        self._image_size = image_size
        # SAM2 reads .images directly; entries are filled lazily on access.
        self.images = [None] * len(store)
        self.video_height = store.height
        self.video_width = store.width
        self.exception = None  # AsyncVideoFrameLoader compat
        # ImageNet normalization constants (must match SAM2's _load_img_as_tensor).
        self._mean = torch.tensor([0.485, 0.456, 0.406]).reshape(3, 1, 1)
        self._std = torch.tensor([0.229, 0.224, 0.225]).reshape(3, 1, 1)

    def __len__(self) -> int:
        return len(self._store)

    def __getitem__(self, idx: int) -> torch.Tensor:
        cached = self.images[idx]
        if cached is not None:
            return cached
        # TRANSFORM PARITY: must match SAM2's _load_img_as_tensor exactly:
        # PIL Image -> convert("RGB") -> resize((size, size)) -> /255 -> permute
        # -> normalize. PIL resize default = BICUBIC, so we must use PIL resize,
        # NOT cv2.resize.
        rgb = cv2.cvtColor(self._store.get_bgr(idx), cv2.COLOR_BGR2RGB)
        target = (self._image_size, self._image_size)
        resized = Image.fromarray(rgb).resize(target)  # BICUBIC default
        arr = np.array(resized) / 255.0
        tensor = torch.from_numpy(arr).permute(2, 0, 1).float()
        tensor = (tensor - self._mean) / self._std
        self.images[idx] = tensor
        return tensor