File size: 6,011 Bytes
c97a5f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""In-memory shared frame store to eliminate redundant JPEG encoding/decoding.

Replaces the pipeline:
    MP4 β†’ cv2 decode β†’ JPEG encode to disk β†’ N GPUs each decode all JPEGs back
With:
    MP4 β†’ cv2 decode once β†’ SharedFrameStore in RAM β†’ all GPUs read from same memory
"""

import logging
from typing import Optional

import cv2
import numpy as np
import torch
from PIL import Image


class MemoryBudgetExceeded(Exception):
    """Signals that decoding a video would exceed the in-memory budget ceiling.

    Carries the offending estimate in ``estimated_bytes`` so callers can log
    it or decide on a fallback path.
    """

    def __init__(self, estimated_bytes: int):
        self.estimated_bytes = estimated_bytes
        gib = estimated_bytes / 1024**3
        super().__init__(f"Estimated memory {gib:.1f} GiB exceeds budget")


class SharedFrameStore:
    """Read-only in-memory store for decoded video frames (BGR uint8).

    Decodes the video once via cv2.VideoCapture and holds all frames in a list.
    Thread-safe for concurrent reads (frames list is never mutated after init).

    Raises MemoryBudgetExceeded BEFORE decoding if estimated memory exceeds
    the budget ceiling, giving callers a chance to fall back to JPEG path.
    """

    MAX_BUDGET_BYTES = 12 * 1024**3  # 12 GiB ceiling

    def __init__(self, video_path: str, max_frames: Optional[int] = None):
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise RuntimeError(f"Cannot open video: {video_path}")

        self.fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
        self.width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Estimate frame count BEFORE decoding to check memory budget
        reported_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if reported_count <= 0:
            reported_count = 10000  # conservative fallback
        est_frames = min(reported_count, max_frames) if max_frames else reported_count

        # Budget: raw BGR frames + worst-case SAM2 adapter tensors (image_size=1024)
        per_frame_raw = self.height * self.width * 3  # uint8 BGR
        per_frame_adapter = 3 * 1024 * 1024 * 4  # float32, worst-case 1024x1024
        total_est = est_frames * (per_frame_raw + per_frame_adapter)
        if total_est > self.MAX_BUDGET_BYTES:
            cap.release()
            logging.warning(
                "SharedFrameStore: estimated ~%.1f GiB for %d frames exceeds "
                "%.1f GiB budget; skipping in-memory path",
                total_est / 1024**3, est_frames, self.MAX_BUDGET_BYTES / 1024**3,
            )
            raise MemoryBudgetExceeded(total_est)

        frames = []
        while True:
            if max_frames is not None and len(frames) >= max_frames:
                break
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
        cap.release()

        if not frames:
            raise RuntimeError(f"No frames decoded from: {video_path}")

        self.frames = frames
        logging.info(
            "SharedFrameStore: %d frames, %dx%d, %.1f fps",
            len(self.frames), self.width, self.height, self.fps,
        )

    def __len__(self) -> int:
        return len(self.frames)

    def get_bgr(self, idx: int) -> np.ndarray:
        """Return BGR frame. Caller must .copy() if mutating."""
        return self.frames[idx]

    def get_pil_rgb(self, idx: int) -> Image.Image:
        """Return PIL RGB Image for the given frame index."""
        bgr = self.frames[idx]
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        return Image.fromarray(rgb)

    def sam2_adapter(self, image_size: int) -> "SAM2FrameAdapter":
        """Factory for SAM2-compatible frame adapter. Returns same adapter for same size."""
        if not hasattr(self, "_adapters"):
            self._adapters = {}
        if image_size not in self._adapters:
            self._adapters[image_size] = SAM2FrameAdapter(self, image_size)
        return self._adapters[image_size]


class SAM2FrameAdapter:
    """Drop-in replacement for SAM2's AsyncVideoFrameLoader.

    Exposes the interface SAM2's init_state / propagate_in_video expects:

    - ``__len__()`` -> frame count
    - ``__getitem__(idx)`` -> normalized float32 tensor of shape (3, H, W)
    - ``.images`` list (SAM2 accesses this attribute directly in some paths)
    - ``.video_height`` / ``.video_width``
    - ``.exception`` (AsyncVideoFrameLoader compatibility)

    Transform parity: resizing goes through PIL's Image.resize with its
    BICUBIC default, exactly mirroring SAM2's _load_img_as_tensor.
    """

    def __init__(self, store: SharedFrameStore, image_size: int):
        self._store = store
        self._image_size = image_size
        # Lazily filled per-frame cache; SAM2 reads .images directly.
        self.images = [None] * len(store)
        self.video_height = store.height
        self.video_width = store.width
        self.exception = None  # AsyncVideoFrameLoader compat

        # ImageNet normalization constants, identical to SAM2's
        # _load_img_as_tensor.
        self._mean = torch.tensor([0.485, 0.456, 0.406]).reshape(3, 1, 1)
        self._std = torch.tensor([0.229, 0.224, 0.225]).reshape(3, 1, 1)

    def __len__(self) -> int:
        return len(self._store)

    def __getitem__(self, idx: int) -> torch.Tensor:
        cached = self.images[idx]
        if cached is not None:
            return cached

        # TRANSFORM PARITY: must replicate SAM2's _load_img_as_tensor
        # byte-for-byte: PIL RGB image -> .resize((size, size)) with the
        # BICUBIC default -> /255 -> permute to CHW -> ImageNet normalize.
        # cv2.resize would produce slightly different pixels, so PIL it is.
        size = self._image_size
        rgb = cv2.cvtColor(self._store.get_bgr(idx), cv2.COLOR_BGR2RGB)
        resized = Image.fromarray(rgb).resize((size, size))  # BICUBIC default
        scaled = np.array(resized) / 255.0
        tensor = torch.from_numpy(scaled).permute(2, 0, 1).float()
        tensor = (tensor - self._mean) / self._std
        self.images[idx] = tensor
        return tensor