Zhen Ye committed on
Commit
c97a5f9
Β·
1 Parent(s): 05bd36a

Eliminate redundant JPEG frame loading via shared frame store

Browse files
inference.py CHANGED
@@ -1209,14 +1209,18 @@ def _gsam2_render_frame(
1209
  height: int,
1210
  width: int,
1211
  masks_only: bool = False,
 
1212
  ) -> np.ndarray:
1213
  """Render a single GSAM2 tracking frame (masks + boxes). CPU-only.
1214
 
1215
  When *masks_only* is True, skip box rendering so the writer thread can
1216
  draw boxes later with enriched (GPT) labels.
1217
  """
1218
- frame_path = os.path.join(frame_dir, frame_names[frame_idx])
1219
- frame = cv2.imread(frame_path)
 
 
 
1220
  if frame is None:
1221
  return np.zeros((height, width, 3), dtype=np.uint8)
1222
 
@@ -1290,6 +1294,7 @@ def run_grounded_sam2_tracking(
1290
  from PIL import Image as PILImage
1291
 
1292
  from utils.video import extract_frames_to_jpeg_dir
 
1293
  from models.segmenters.grounded_sam2 import MaskDictionary, ObjectInfo, LazyFrameObjects
1294
 
1295
  active_segmenter = segmenter_name or "GSAM2-L"
@@ -1305,26 +1310,40 @@ def run_grounded_sam2_tracking(
1305
  active_segmenter, queries, step,
1306
  )
1307
 
1308
- # 1. Extract frames to JPEG directory
1309
- frame_dir = tempfile.mkdtemp(prefix="gsam2_frames_")
 
 
1310
  try:
1311
- if _perf_metrics is not None:
1312
- _t_e2e = time.perf_counter()
1313
- if torch.cuda.is_available():
1314
- torch.cuda.reset_peak_memory_stats()
1315
-
1316
- if _perf_metrics is not None:
1317
- _t_ext = time.perf_counter()
1318
-
 
 
 
 
 
 
 
1319
  frame_names, fps, width, height = extract_frames_to_jpeg_dir(
1320
  input_video_path, frame_dir, max_frames=max_frames,
1321
  )
 
1322
 
 
1323
  if _perf_metrics is not None:
 
 
 
1324
  _perf_metrics["frame_extraction_ms"] = (time.perf_counter() - _t_ext) * 1000.0
1325
- total_frames = len(frame_names)
1326
- _ttfs(f"frame_extraction done ({total_frames} frames)")
1327
- logging.info("Extracted %d frames to %s", total_frames, frame_dir)
1328
 
1329
  num_gpus = torch.cuda.device_count()
1330
 
@@ -1358,6 +1377,7 @@ def run_grounded_sam2_tracking(
1358
  frame_dir, frame_names, fidx, fobjs,
1359
  height, width,
1360
  masks_only=enable_gpt,
 
1361
  )
1362
 
1363
  if _perf_metrics is not None:
@@ -1824,6 +1844,7 @@ def run_grounded_sam2_tracking(
1824
  on_segment_output=_feed_segment_gpu,
1825
  _ttfs_t0=_ttfs_t0,
1826
  _ttfs_job_id=job_id,
 
1827
  )
1828
 
1829
  if _perf_metrics is not None:
@@ -1875,13 +1896,46 @@ def run_grounded_sam2_tracking(
1875
  if _perf_metrics is not None:
1876
  _t_init = time.perf_counter()
1877
 
1878
- def _init_seg_state(seg):
1879
- seg._ensure_models_loaded()
1880
- return seg._video_predictor.init_state(
1881
- video_path=frame_dir,
1882
- offload_video_to_cpu=True,
1883
- async_loading_frames=True,
1884
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1885
 
1886
  with ThreadPoolExecutor(max_workers=len(segmenters)) as pool:
1887
  futs = [pool.submit(_init_seg_state, seg) for seg in segmenters]
@@ -1937,11 +1991,14 @@ def run_grounded_sam2_tracking(
1937
  "GPU %d processing segment %d (frame %d)",
1938
  gpu_idx, seg_idx, start_idx,
1939
  )
1940
- img_path = os.path.join(
1941
- frame_dir, frame_names[start_idx]
1942
- )
1943
- with PILImage.open(img_path) as pil_img:
1944
- image = pil_img.convert("RGB")
 
 
 
1945
 
1946
  if job_id:
1947
  _check_cancellation(job_id)
 
1209
  height: int,
1210
  width: int,
1211
  masks_only: bool = False,
1212
+ frame_store=None,
1213
  ) -> np.ndarray:
1214
  """Render a single GSAM2 tracking frame (masks + boxes). CPU-only.
1215
 
1216
  When *masks_only* is True, skip box rendering so the writer thread can
1217
  draw boxes later with enriched (GPT) labels.
1218
  """
1219
+ if frame_store is not None:
1220
+ frame = frame_store.get_bgr(frame_idx).copy() # .copy() β€” render mutates
1221
+ else:
1222
+ frame_path = os.path.join(frame_dir, frame_names[frame_idx])
1223
+ frame = cv2.imread(frame_path)
1224
  if frame is None:
1225
  return np.zeros((height, width, 3), dtype=np.uint8)
1226
 
 
1294
  from PIL import Image as PILImage
1295
 
1296
  from utils.video import extract_frames_to_jpeg_dir
1297
+ from utils.frame_store import SharedFrameStore, MemoryBudgetExceeded
1298
  from models.segmenters.grounded_sam2 import MaskDictionary, ObjectInfo, LazyFrameObjects
1299
 
1300
  active_segmenter = segmenter_name or "GSAM2-L"
 
1310
  active_segmenter, queries, step,
1311
  )
1312
 
1313
+ # 1. Load frames β€” prefer in-memory SharedFrameStore, fall back to JPEG dir
1314
+ _use_frame_store = True
1315
+ frame_store = None
1316
+ _t_ext = time.perf_counter()
1317
  try:
1318
+ frame_store = SharedFrameStore(input_video_path, max_frames=max_frames)
1319
+ fps, width, height = frame_store.fps, frame_store.width, frame_store.height
1320
+ total_frames = len(frame_store)
1321
+ frame_names = [f"{i:06d}.jpg" for i in range(total_frames)]
1322
+
1323
+ # Write single dummy JPEG for init_state bootstrapping
1324
+ dummy_frame_dir = tempfile.mkdtemp(prefix="gsam2_dummy_")
1325
+ cv2.imwrite(os.path.join(dummy_frame_dir, "000000.jpg"), frame_store.get_bgr(0))
1326
+ frame_dir = dummy_frame_dir
1327
+ logging.info("SharedFrameStore: %d frames in memory (dummy dir: %s)", total_frames, frame_dir)
1328
+ except MemoryBudgetExceeded:
1329
+ logging.info("Memory budget exceeded, falling back to JPEG extraction")
1330
+ _use_frame_store = False
1331
+ frame_store = None
1332
+ frame_dir = tempfile.mkdtemp(prefix="gsam2_frames_")
1333
  frame_names, fps, width, height = extract_frames_to_jpeg_dir(
1334
  input_video_path, frame_dir, max_frames=max_frames,
1335
  )
1336
+ total_frames = len(frame_names)
1337
 
1338
+ try:
1339
  if _perf_metrics is not None:
1340
+ _t_e2e = time.perf_counter()
1341
+ if torch.cuda.is_available():
1342
+ torch.cuda.reset_peak_memory_stats()
1343
  _perf_metrics["frame_extraction_ms"] = (time.perf_counter() - _t_ext) * 1000.0
1344
+
1345
+ _ttfs(f"frame_extraction done ({total_frames} frames, in_memory={_use_frame_store})")
1346
+ logging.info("Loaded %d frames (in_memory=%s)", total_frames, _use_frame_store)
1347
 
1348
  num_gpus = torch.cuda.device_count()
1349
 
 
1377
  frame_dir, frame_names, fidx, fobjs,
1378
  height, width,
1379
  masks_only=enable_gpt,
1380
+ frame_store=frame_store,
1381
  )
1382
 
1383
  if _perf_metrics is not None:
 
1844
  on_segment_output=_feed_segment_gpu,
1845
  _ttfs_t0=_ttfs_t0,
1846
  _ttfs_job_id=job_id,
1847
+ frame_store=frame_store,
1848
  )
1849
 
1850
  if _perf_metrics is not None:
 
1896
  if _perf_metrics is not None:
1897
  _t_init = time.perf_counter()
1898
 
1899
+ if frame_store is not None:
1900
+ # Models are lazy-loaded; ensure at least one is ready so we
1901
+ # can read image_size. Phase 1 (load_segmenter_on_device)
1902
+ # only constructs the object β€” _video_predictor is still None.
1903
+ segmenters[0]._ensure_models_loaded()
1904
+ sam2_img_size = segmenters[0]._video_predictor.image_size
1905
+
1906
+ # Pre-create the shared adapter (validates memory budget)
1907
+ shared_adapter = frame_store.sam2_adapter(image_size=sam2_img_size)
1908
+
1909
+ _REQUIRED_KEYS = {"images", "num_frames", "video_height", "video_width", "cached_features"}
1910
+
1911
+ def _init_seg_state(seg):
1912
+ seg._ensure_models_loaded()
1913
+ state = seg._video_predictor.init_state(
1914
+ video_path=frame_dir, # dummy dir with 1 JPEG
1915
+ offload_video_to_cpu=True,
1916
+ async_loading_frames=False, # 1 dummy frame, instant
1917
+ )
1918
+ # Validate expected keys exist before patching
1919
+ missing = _REQUIRED_KEYS - set(state.keys())
1920
+ if missing:
1921
+ raise RuntimeError(f"SAM2 init_state missing expected keys: {missing}")
1922
+ # CRITICAL: Clear cached_features BEFORE patching images
1923
+ # init_state caches dummy frame 0's backbone features β€” must evict
1924
+ state["cached_features"] = {}
1925
+ # Patch in real frame data
1926
+ state["images"] = shared_adapter
1927
+ state["num_frames"] = total_frames
1928
+ state["video_height"] = height
1929
+ state["video_width"] = width
1930
+ return state
1931
+ else:
1932
+ def _init_seg_state(seg):
1933
+ seg._ensure_models_loaded()
1934
+ return seg._video_predictor.init_state(
1935
+ video_path=frame_dir,
1936
+ offload_video_to_cpu=True,
1937
+ async_loading_frames=True,
1938
+ )
1939
 
1940
  with ThreadPoolExecutor(max_workers=len(segmenters)) as pool:
1941
  futs = [pool.submit(_init_seg_state, seg) for seg in segmenters]
 
1991
  "GPU %d processing segment %d (frame %d)",
1992
  gpu_idx, seg_idx, start_idx,
1993
  )
1994
+ if frame_store is not None:
1995
+ image = frame_store.get_pil_rgb(start_idx)
1996
+ else:
1997
+ img_path = os.path.join(
1998
+ frame_dir, frame_names[start_idx]
1999
+ )
2000
+ with PILImage.open(img_path) as pil_img:
2001
+ image = pil_img.convert("RGB")
2002
 
2003
  if job_id:
2004
  _check_cancellation(job_id)
models/segmenters/grounded_sam2.py CHANGED
@@ -717,6 +717,7 @@ class GroundedSAM2Segmenter(Segmenter):
717
  on_segment_output: Optional[Callable[["SegmentOutput"], None]] = None,
718
  _ttfs_t0: Optional[float] = None,
719
  _ttfs_job_id: Optional[str] = None,
 
720
  ) -> Dict[int, Dict[int, ObjectInfo]]:
721
  """Run full Grounded-SAM-2 tracking pipeline on extracted JPEG frames.
722
 
@@ -758,11 +759,26 @@ class GroundedSAM2Segmenter(Segmenter):
758
  if _pm is not None:
759
  _t_init = time.perf_counter()
760
 
761
- inference_state = self._video_predictor.init_state(
762
- video_path=frame_dir,
763
- offload_video_to_cpu=True,
764
- async_loading_frames=True,
765
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
766
 
767
  if _pm is not None:
768
  _pl = getattr(self, '_perf_lock', None)
@@ -775,8 +791,10 @@ class GroundedSAM2Segmenter(Segmenter):
775
  for start_idx in range(0, total_frames, step):
776
  logging.info("Processing keyframe %d / %d", start_idx, total_frames)
777
 
778
- img_path = os.path.join(frame_dir, frame_names[start_idx])
779
- image = Image.open(img_path).convert("RGB")
 
 
780
 
781
  mask_dict = MaskDictionary()
782
 
 
717
  on_segment_output: Optional[Callable[["SegmentOutput"], None]] = None,
718
  _ttfs_t0: Optional[float] = None,
719
  _ttfs_job_id: Optional[str] = None,
720
+ frame_store=None,
721
  ) -> Dict[int, Dict[int, ObjectInfo]]:
722
  """Run full Grounded-SAM-2 tracking pipeline on extracted JPEG frames.
723
 
 
759
  if _pm is not None:
760
  _t_init = time.perf_counter()
761
 
762
+ if frame_store is not None:
763
+ inference_state = self._video_predictor.init_state(
764
+ video_path=frame_dir, # dummy dir with 1 JPEG
765
+ offload_video_to_cpu=True,
766
+ async_loading_frames=False,
767
+ )
768
+ # Clear cached_features (dummy frame 0's backbone features)
769
+ inference_state["cached_features"] = {}
770
+ # Patch in real frame data
771
+ img_size = self._video_predictor.image_size
772
+ inference_state["images"] = frame_store.sam2_adapter(image_size=img_size)
773
+ inference_state["num_frames"] = len(frame_store)
774
+ inference_state["video_height"] = frame_store.height
775
+ inference_state["video_width"] = frame_store.width
776
+ else:
777
+ inference_state = self._video_predictor.init_state(
778
+ video_path=frame_dir,
779
+ offload_video_to_cpu=True,
780
+ async_loading_frames=True,
781
+ )
782
 
783
  if _pm is not None:
784
  _pl = getattr(self, '_perf_lock', None)
 
791
  for start_idx in range(0, total_frames, step):
792
  logging.info("Processing keyframe %d / %d", start_idx, total_frames)
793
 
794
+ if frame_store is not None:
795
+ image = frame_store.get_pil_rgb(start_idx)
796
+ else:
797
+ image = Image.open(os.path.join(frame_dir, frame_names[start_idx])).convert("RGB")
798
 
799
  mask_dict = MaskDictionary()
800
 
utils/frame_store.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """In-memory shared frame store to eliminate redundant JPEG encoding/decoding.
2
+
3
+ Replaces the pipeline:
4
+ MP4 β†’ cv2 decode β†’ JPEG encode to disk β†’ N GPUs each decode all JPEGs back
5
+ With:
6
+ MP4 β†’ cv2 decode once β†’ SharedFrameStore in RAM β†’ all GPUs read from same memory
7
+ """
8
+
9
+ import logging
10
+ from typing import Optional
11
+
12
+ import cv2
13
+ import numpy as np
14
+ import torch
15
+ from PIL import Image
16
+
17
+
18
class MemoryBudgetExceeded(Exception):
    """Signal that the projected RAM footprint is above the allowed ceiling.

    Carries the raw byte estimate on ``estimated_bytes`` so callers can log
    it or choose a fallback (e.g. the JPEG-on-disk path).
    """

    def __init__(self, estimated_bytes: int):
        gib = estimated_bytes / 1024**3
        super().__init__(f"Estimated memory {gib:.1f} GiB exceeds budget")
        self.estimated_bytes = estimated_bytes
26
+
27
+
28
class SharedFrameStore:
    """Read-only in-memory store for decoded video frames (BGR uint8).

    Decodes the video exactly once via ``cv2.VideoCapture`` and keeps every
    frame in a plain list.  The list is never mutated after construction, so
    concurrent reads from multiple GPU worker threads are safe without
    locking.

    Raises:
        MemoryBudgetExceeded: BEFORE decoding, if the estimated footprint
            (raw frames plus worst-case SAM2 adapter tensors) exceeds
            ``MAX_BUDGET_BYTES``, giving callers a chance to fall back to
            the JPEG-extraction path.
        RuntimeError: if the video cannot be opened or yields no frames.
    """

    # Hard ceiling for raw BGR frames + cached SAM2 adapter tensors.
    MAX_BUDGET_BYTES = 12 * 1024**3  # 12 GiB ceiling

    def __init__(self, video_path: str, max_frames: Optional[int] = None):
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise RuntimeError(f"Cannot open video: {video_path}")

        # 0.0 / None from a broken container falls back to a sane default.
        self.fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
        self.width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Adapter cache keyed by SAM2 image_size — created eagerly here
        # (not lazily in sam2_adapter) to avoid an attribute-creation race
        # when several GPU threads request adapters concurrently.
        self._adapters = {}

        # Estimate frame count BEFORE decoding to check memory budget.
        # CAP_PROP_FRAME_COUNT is container metadata and may be missing.
        reported_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if reported_count <= 0:
            reported_count = 10000  # conservative fallback
        est_frames = min(reported_count, max_frames) if max_frames else reported_count

        # Budget: raw BGR frames + worst-case SAM2 adapter tensors.  The
        # adapter caches one float32 (3, S, S) tensor per frame; assume the
        # largest common SAM2 input size (S = 1024).
        per_frame_raw = self.height * self.width * 3  # uint8 BGR
        per_frame_adapter = 3 * 1024 * 1024 * 4  # float32, worst-case 1024x1024
        total_est = est_frames * (per_frame_raw + per_frame_adapter)
        if total_est > self.MAX_BUDGET_BYTES:
            cap.release()
            logging.warning(
                "SharedFrameStore: estimated ~%.1f GiB for %d frames exceeds "
                "%.1f GiB budget; skipping in-memory path",
                total_est / 1024**3, est_frames, self.MAX_BUDGET_BYTES / 1024**3,
            )
            raise MemoryBudgetExceeded(total_est)

        frames = []
        while max_frames is None or len(frames) < max_frames:
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
        cap.release()

        if not frames:
            raise RuntimeError(f"No frames decoded from: {video_path}")

        self.frames = frames
        logging.info(
            "SharedFrameStore: %d frames, %dx%d, %.1f fps",
            len(self.frames), self.width, self.height, self.fps,
        )

    def __len__(self) -> int:
        return len(self.frames)

    def get_bgr(self, idx: int) -> np.ndarray:
        """Return the BGR frame at *idx* (shared; caller must .copy() to mutate)."""
        return self.frames[idx]

    def get_pil_rgb(self, idx: int) -> "Image.Image":
        """Return the frame at *idx* as a PIL RGB image (converted from BGR)."""
        rgb = cv2.cvtColor(self.frames[idx], cv2.COLOR_BGR2RGB)
        return Image.fromarray(rgb)

    def sam2_adapter(self, image_size: int) -> "SAM2FrameAdapter":
        """Return the SAM2-compatible adapter for *image_size* (one per size, cached)."""
        if image_size not in self._adapters:
            self._adapters[image_size] = SAM2FrameAdapter(self, image_size)
        return self._adapters[image_size]
107
+
108
+
109
class SAM2FrameAdapter:
    """Drop-in replacement for SAM2's AsyncVideoFrameLoader.

    Matches the interface that SAM2's init_state / propagate_in_video expects:
      - ``__len__()``        -> number of frames
      - ``__getitem__(idx)`` -> normalized float32 tensor (3, S, S)
      - ``.images``          list (SAM2 accesses this directly in some paths)
      - ``.video_height`` / ``.video_width``
      - ``.exception``       (AsyncVideoFrameLoader compat; always None here)

    Transform parity: resizes with PIL and an EXPLICIT BICUBIC filter — the
    filter SAM2's ``_load_img_as_tensor`` gets implicitly for RGB images.
    Passing it explicitly guards against PIL's mode-dependent default
    (NEAREST for "1"/"P" images) and future default changes.
    """

    def __init__(self, store: "SharedFrameStore", image_size: int):
        self._store = store
        self._image_size = image_size
        # Lazily filled per-frame tensor cache; SAM2 reads .images directly.
        self.images = [None] * len(store)
        self.video_height = store.height
        self.video_width = store.width
        self.exception = None  # AsyncVideoFrameLoader compat

        # ImageNet normalization constants (must match SAM2's loader).
        self._mean = torch.tensor([0.485, 0.456, 0.406]).reshape(3, 1, 1)
        self._std = torch.tensor([0.229, 0.224, 0.225]).reshape(3, 1, 1)

    def __len__(self) -> int:
        return len(self._store)

    def __getitem__(self, idx: int) -> torch.Tensor:
        cached = self.images[idx]
        if cached is not None:
            return cached

        # TRANSFORM PARITY: must match SAM2's _load_img_as_tensor exactly:
        # PIL RGB -> resize((S, S), BICUBIC) -> /255 -> CHW -> normalize.
        bgr = self._store.get_bgr(idx)
        pil_img = Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
        pil_resized = pil_img.resize(
            (self._image_size, self._image_size), resample=Image.BICUBIC
        )
        img_np = np.array(pil_resized) / 255.0
        img = torch.from_numpy(img_np).permute(2, 0, 1).float()
        img = (img - self._mean) / self._std
        # Benign race: two threads may compute the same idx; the result is
        # identical, so the last writer simply wins.
        self.images[idx] = img
        return img