Spaces:
Paused
Paused
Zhen Ye committed on
Commit Β·
c97a5f9
1
Parent(s): 05bd36a
Eliminate redundant JPEG frame loading via shared frame store
Browse files- inference.py +84 -27
- models/segmenters/grounded_sam2.py +25 -7
- utils/frame_store.py +154 -0
inference.py
CHANGED
|
@@ -1209,14 +1209,18 @@ def _gsam2_render_frame(
|
|
| 1209 |
height: int,
|
| 1210 |
width: int,
|
| 1211 |
masks_only: bool = False,
|
|
|
|
| 1212 |
) -> np.ndarray:
|
| 1213 |
"""Render a single GSAM2 tracking frame (masks + boxes). CPU-only.
|
| 1214 |
|
| 1215 |
When *masks_only* is True, skip box rendering so the writer thread can
|
| 1216 |
draw boxes later with enriched (GPT) labels.
|
| 1217 |
"""
|
| 1218 |
-
|
| 1219 |
-
|
|
|
|
|
|
|
|
|
|
| 1220 |
if frame is None:
|
| 1221 |
return np.zeros((height, width, 3), dtype=np.uint8)
|
| 1222 |
|
|
@@ -1290,6 +1294,7 @@ def run_grounded_sam2_tracking(
|
|
| 1290 |
from PIL import Image as PILImage
|
| 1291 |
|
| 1292 |
from utils.video import extract_frames_to_jpeg_dir
|
|
|
|
| 1293 |
from models.segmenters.grounded_sam2 import MaskDictionary, ObjectInfo, LazyFrameObjects
|
| 1294 |
|
| 1295 |
active_segmenter = segmenter_name or "GSAM2-L"
|
|
@@ -1305,26 +1310,40 @@ def run_grounded_sam2_tracking(
|
|
| 1305 |
active_segmenter, queries, step,
|
| 1306 |
)
|
| 1307 |
|
| 1308 |
-
# 1.
|
| 1309 |
-
|
|
|
|
|
|
|
| 1310 |
try:
|
| 1311 |
-
|
| 1312 |
-
|
| 1313 |
-
|
| 1314 |
-
|
| 1315 |
-
|
| 1316 |
-
|
| 1317 |
-
|
| 1318 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1319 |
frame_names, fps, width, height = extract_frames_to_jpeg_dir(
|
| 1320 |
input_video_path, frame_dir, max_frames=max_frames,
|
| 1321 |
)
|
|
|
|
| 1322 |
|
|
|
|
| 1323 |
if _perf_metrics is not None:
|
|
|
|
|
|
|
|
|
|
| 1324 |
_perf_metrics["frame_extraction_ms"] = (time.perf_counter() - _t_ext) * 1000.0
|
| 1325 |
-
|
| 1326 |
-
_ttfs(f"frame_extraction done ({total_frames} frames)")
|
| 1327 |
-
logging.info("
|
| 1328 |
|
| 1329 |
num_gpus = torch.cuda.device_count()
|
| 1330 |
|
|
@@ -1358,6 +1377,7 @@ def run_grounded_sam2_tracking(
|
|
| 1358 |
frame_dir, frame_names, fidx, fobjs,
|
| 1359 |
height, width,
|
| 1360 |
masks_only=enable_gpt,
|
|
|
|
| 1361 |
)
|
| 1362 |
|
| 1363 |
if _perf_metrics is not None:
|
|
@@ -1824,6 +1844,7 @@ def run_grounded_sam2_tracking(
|
|
| 1824 |
on_segment_output=_feed_segment_gpu,
|
| 1825 |
_ttfs_t0=_ttfs_t0,
|
| 1826 |
_ttfs_job_id=job_id,
|
|
|
|
| 1827 |
)
|
| 1828 |
|
| 1829 |
if _perf_metrics is not None:
|
|
@@ -1875,13 +1896,46 @@ def run_grounded_sam2_tracking(
|
|
| 1875 |
if _perf_metrics is not None:
|
| 1876 |
_t_init = time.perf_counter()
|
| 1877 |
|
| 1878 |
-
|
| 1879 |
-
|
| 1880 |
-
|
| 1881 |
-
|
| 1882 |
-
|
| 1883 |
-
|
| 1884 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1885 |
|
| 1886 |
with ThreadPoolExecutor(max_workers=len(segmenters)) as pool:
|
| 1887 |
futs = [pool.submit(_init_seg_state, seg) for seg in segmenters]
|
|
@@ -1937,11 +1991,14 @@ def run_grounded_sam2_tracking(
|
|
| 1937 |
"GPU %d processing segment %d (frame %d)",
|
| 1938 |
gpu_idx, seg_idx, start_idx,
|
| 1939 |
)
|
| 1940 |
-
|
| 1941 |
-
|
| 1942 |
-
|
| 1943 |
-
|
| 1944 |
-
|
|
|
|
|
|
|
|
|
|
| 1945 |
|
| 1946 |
if job_id:
|
| 1947 |
_check_cancellation(job_id)
|
|
|
|
| 1209 |
height: int,
|
| 1210 |
width: int,
|
| 1211 |
masks_only: bool = False,
|
| 1212 |
+
frame_store=None,
|
| 1213 |
) -> np.ndarray:
|
| 1214 |
"""Render a single GSAM2 tracking frame (masks + boxes). CPU-only.
|
| 1215 |
|
| 1216 |
When *masks_only* is True, skip box rendering so the writer thread can
|
| 1217 |
draw boxes later with enriched (GPT) labels.
|
| 1218 |
"""
|
| 1219 |
+
if frame_store is not None:
|
| 1220 |
+
frame = frame_store.get_bgr(frame_idx).copy() # .copy() β render mutates
|
| 1221 |
+
else:
|
| 1222 |
+
frame_path = os.path.join(frame_dir, frame_names[frame_idx])
|
| 1223 |
+
frame = cv2.imread(frame_path)
|
| 1224 |
if frame is None:
|
| 1225 |
return np.zeros((height, width, 3), dtype=np.uint8)
|
| 1226 |
|
|
|
|
| 1294 |
from PIL import Image as PILImage
|
| 1295 |
|
| 1296 |
from utils.video import extract_frames_to_jpeg_dir
|
| 1297 |
+
from utils.frame_store import SharedFrameStore, MemoryBudgetExceeded
|
| 1298 |
from models.segmenters.grounded_sam2 import MaskDictionary, ObjectInfo, LazyFrameObjects
|
| 1299 |
|
| 1300 |
active_segmenter = segmenter_name or "GSAM2-L"
|
|
|
|
| 1310 |
active_segmenter, queries, step,
|
| 1311 |
)
|
| 1312 |
|
| 1313 |
+
# 1. Load frames β prefer in-memory SharedFrameStore, fall back to JPEG dir
|
| 1314 |
+
_use_frame_store = True
|
| 1315 |
+
frame_store = None
|
| 1316 |
+
_t_ext = time.perf_counter()
|
| 1317 |
try:
|
| 1318 |
+
frame_store = SharedFrameStore(input_video_path, max_frames=max_frames)
|
| 1319 |
+
fps, width, height = frame_store.fps, frame_store.width, frame_store.height
|
| 1320 |
+
total_frames = len(frame_store)
|
| 1321 |
+
frame_names = [f"{i:06d}.jpg" for i in range(total_frames)]
|
| 1322 |
+
|
| 1323 |
+
# Write single dummy JPEG for init_state bootstrapping
|
| 1324 |
+
dummy_frame_dir = tempfile.mkdtemp(prefix="gsam2_dummy_")
|
| 1325 |
+
cv2.imwrite(os.path.join(dummy_frame_dir, "000000.jpg"), frame_store.get_bgr(0))
|
| 1326 |
+
frame_dir = dummy_frame_dir
|
| 1327 |
+
logging.info("SharedFrameStore: %d frames in memory (dummy dir: %s)", total_frames, frame_dir)
|
| 1328 |
+
except MemoryBudgetExceeded:
|
| 1329 |
+
logging.info("Memory budget exceeded, falling back to JPEG extraction")
|
| 1330 |
+
_use_frame_store = False
|
| 1331 |
+
frame_store = None
|
| 1332 |
+
frame_dir = tempfile.mkdtemp(prefix="gsam2_frames_")
|
| 1333 |
frame_names, fps, width, height = extract_frames_to_jpeg_dir(
|
| 1334 |
input_video_path, frame_dir, max_frames=max_frames,
|
| 1335 |
)
|
| 1336 |
+
total_frames = len(frame_names)
|
| 1337 |
|
| 1338 |
+
try:
|
| 1339 |
if _perf_metrics is not None:
|
| 1340 |
+
_t_e2e = time.perf_counter()
|
| 1341 |
+
if torch.cuda.is_available():
|
| 1342 |
+
torch.cuda.reset_peak_memory_stats()
|
| 1343 |
_perf_metrics["frame_extraction_ms"] = (time.perf_counter() - _t_ext) * 1000.0
|
| 1344 |
+
|
| 1345 |
+
_ttfs(f"frame_extraction done ({total_frames} frames, in_memory={_use_frame_store})")
|
| 1346 |
+
logging.info("Loaded %d frames (in_memory=%s)", total_frames, _use_frame_store)
|
| 1347 |
|
| 1348 |
num_gpus = torch.cuda.device_count()
|
| 1349 |
|
|
|
|
| 1377 |
frame_dir, frame_names, fidx, fobjs,
|
| 1378 |
height, width,
|
| 1379 |
masks_only=enable_gpt,
|
| 1380 |
+
frame_store=frame_store,
|
| 1381 |
)
|
| 1382 |
|
| 1383 |
if _perf_metrics is not None:
|
|
|
|
| 1844 |
on_segment_output=_feed_segment_gpu,
|
| 1845 |
_ttfs_t0=_ttfs_t0,
|
| 1846 |
_ttfs_job_id=job_id,
|
| 1847 |
+
frame_store=frame_store,
|
| 1848 |
)
|
| 1849 |
|
| 1850 |
if _perf_metrics is not None:
|
|
|
|
| 1896 |
if _perf_metrics is not None:
|
| 1897 |
_t_init = time.perf_counter()
|
| 1898 |
|
| 1899 |
+
if frame_store is not None:
|
| 1900 |
+
# Models are lazy-loaded; ensure at least one is ready so we
|
| 1901 |
+
# can read image_size. Phase 1 (load_segmenter_on_device)
|
| 1902 |
+
# only constructs the object β _video_predictor is still None.
|
| 1903 |
+
segmenters[0]._ensure_models_loaded()
|
| 1904 |
+
sam2_img_size = segmenters[0]._video_predictor.image_size
|
| 1905 |
+
|
| 1906 |
+
# Pre-create the shared adapter (validates memory budget)
|
| 1907 |
+
shared_adapter = frame_store.sam2_adapter(image_size=sam2_img_size)
|
| 1908 |
+
|
| 1909 |
+
_REQUIRED_KEYS = {"images", "num_frames", "video_height", "video_width", "cached_features"}
|
| 1910 |
+
|
| 1911 |
+
def _init_seg_state(seg):
|
| 1912 |
+
seg._ensure_models_loaded()
|
| 1913 |
+
state = seg._video_predictor.init_state(
|
| 1914 |
+
video_path=frame_dir, # dummy dir with 1 JPEG
|
| 1915 |
+
offload_video_to_cpu=True,
|
| 1916 |
+
async_loading_frames=False, # 1 dummy frame, instant
|
| 1917 |
+
)
|
| 1918 |
+
# Validate expected keys exist before patching
|
| 1919 |
+
missing = _REQUIRED_KEYS - set(state.keys())
|
| 1920 |
+
if missing:
|
| 1921 |
+
raise RuntimeError(f"SAM2 init_state missing expected keys: {missing}")
|
| 1922 |
+
# CRITICAL: Clear cached_features BEFORE patching images
|
| 1923 |
+
# init_state caches dummy frame 0's backbone features β must evict
|
| 1924 |
+
state["cached_features"] = {}
|
| 1925 |
+
# Patch in real frame data
|
| 1926 |
+
state["images"] = shared_adapter
|
| 1927 |
+
state["num_frames"] = total_frames
|
| 1928 |
+
state["video_height"] = height
|
| 1929 |
+
state["video_width"] = width
|
| 1930 |
+
return state
|
| 1931 |
+
else:
|
| 1932 |
+
def _init_seg_state(seg):
|
| 1933 |
+
seg._ensure_models_loaded()
|
| 1934 |
+
return seg._video_predictor.init_state(
|
| 1935 |
+
video_path=frame_dir,
|
| 1936 |
+
offload_video_to_cpu=True,
|
| 1937 |
+
async_loading_frames=True,
|
| 1938 |
+
)
|
| 1939 |
|
| 1940 |
with ThreadPoolExecutor(max_workers=len(segmenters)) as pool:
|
| 1941 |
futs = [pool.submit(_init_seg_state, seg) for seg in segmenters]
|
|
|
|
| 1991 |
"GPU %d processing segment %d (frame %d)",
|
| 1992 |
gpu_idx, seg_idx, start_idx,
|
| 1993 |
)
|
| 1994 |
+
if frame_store is not None:
|
| 1995 |
+
image = frame_store.get_pil_rgb(start_idx)
|
| 1996 |
+
else:
|
| 1997 |
+
img_path = os.path.join(
|
| 1998 |
+
frame_dir, frame_names[start_idx]
|
| 1999 |
+
)
|
| 2000 |
+
with PILImage.open(img_path) as pil_img:
|
| 2001 |
+
image = pil_img.convert("RGB")
|
| 2002 |
|
| 2003 |
if job_id:
|
| 2004 |
_check_cancellation(job_id)
|
models/segmenters/grounded_sam2.py
CHANGED
|
@@ -717,6 +717,7 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 717 |
on_segment_output: Optional[Callable[["SegmentOutput"], None]] = None,
|
| 718 |
_ttfs_t0: Optional[float] = None,
|
| 719 |
_ttfs_job_id: Optional[str] = None,
|
|
|
|
| 720 |
) -> Dict[int, Dict[int, ObjectInfo]]:
|
| 721 |
"""Run full Grounded-SAM-2 tracking pipeline on extracted JPEG frames.
|
| 722 |
|
|
@@ -758,11 +759,26 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 758 |
if _pm is not None:
|
| 759 |
_t_init = time.perf_counter()
|
| 760 |
|
| 761 |
-
|
| 762 |
-
|
| 763 |
-
|
| 764 |
-
|
| 765 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 766 |
|
| 767 |
if _pm is not None:
|
| 768 |
_pl = getattr(self, '_perf_lock', None)
|
|
@@ -775,8 +791,10 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 775 |
for start_idx in range(0, total_frames, step):
|
| 776 |
logging.info("Processing keyframe %d / %d", start_idx, total_frames)
|
| 777 |
|
| 778 |
-
|
| 779 |
-
|
|
|
|
|
|
|
| 780 |
|
| 781 |
mask_dict = MaskDictionary()
|
| 782 |
|
|
|
|
| 717 |
on_segment_output: Optional[Callable[["SegmentOutput"], None]] = None,
|
| 718 |
_ttfs_t0: Optional[float] = None,
|
| 719 |
_ttfs_job_id: Optional[str] = None,
|
| 720 |
+
frame_store=None,
|
| 721 |
) -> Dict[int, Dict[int, ObjectInfo]]:
|
| 722 |
"""Run full Grounded-SAM-2 tracking pipeline on extracted JPEG frames.
|
| 723 |
|
|
|
|
| 759 |
if _pm is not None:
|
| 760 |
_t_init = time.perf_counter()
|
| 761 |
|
| 762 |
+
if frame_store is not None:
|
| 763 |
+
inference_state = self._video_predictor.init_state(
|
| 764 |
+
video_path=frame_dir, # dummy dir with 1 JPEG
|
| 765 |
+
offload_video_to_cpu=True,
|
| 766 |
+
async_loading_frames=False,
|
| 767 |
+
)
|
| 768 |
+
# Clear cached_features (dummy frame 0's backbone features)
|
| 769 |
+
inference_state["cached_features"] = {}
|
| 770 |
+
# Patch in real frame data
|
| 771 |
+
img_size = self._video_predictor.image_size
|
| 772 |
+
inference_state["images"] = frame_store.sam2_adapter(image_size=img_size)
|
| 773 |
+
inference_state["num_frames"] = len(frame_store)
|
| 774 |
+
inference_state["video_height"] = frame_store.height
|
| 775 |
+
inference_state["video_width"] = frame_store.width
|
| 776 |
+
else:
|
| 777 |
+
inference_state = self._video_predictor.init_state(
|
| 778 |
+
video_path=frame_dir,
|
| 779 |
+
offload_video_to_cpu=True,
|
| 780 |
+
async_loading_frames=True,
|
| 781 |
+
)
|
| 782 |
|
| 783 |
if _pm is not None:
|
| 784 |
_pl = getattr(self, '_perf_lock', None)
|
|
|
|
| 791 |
for start_idx in range(0, total_frames, step):
|
| 792 |
logging.info("Processing keyframe %d / %d", start_idx, total_frames)
|
| 793 |
|
| 794 |
+
if frame_store is not None:
|
| 795 |
+
image = frame_store.get_pil_rgb(start_idx)
|
| 796 |
+
else:
|
| 797 |
+
image = Image.open(os.path.join(frame_dir, frame_names[start_idx])).convert("RGB")
|
| 798 |
|
| 799 |
mask_dict = MaskDictionary()
|
| 800 |
|
utils/frame_store.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""In-memory shared frame store to eliminate redundant JPEG encoding/decoding.
|
| 2 |
+
|
| 3 |
+
Replaces the pipeline:
|
| 4 |
+
MP4 β cv2 decode β JPEG encode to disk β N GPUs each decode all JPEGs back
|
| 5 |
+
With:
|
| 6 |
+
MP4 β cv2 decode once β SharedFrameStore in RAM β all GPUs read from same memory
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
from typing import Optional
|
| 11 |
+
|
| 12 |
+
import cv2
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
from PIL import Image
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class MemoryBudgetExceeded(Exception):
    """Raised when estimated memory usage exceeds the configured ceiling."""

    def __init__(self, estimated_bytes: int):
        # Keep the raw estimate available to callers deciding on a fallback.
        self.estimated_bytes = estimated_bytes
        gib = estimated_bytes / 1024**3
        super().__init__(f"Estimated memory {gib:.1f} GiB exceeds budget")
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class SharedFrameStore:
    """Read-only in-memory store for decoded video frames (BGR uint8).

    Decodes the video once via cv2.VideoCapture and holds all frames in a list.
    Thread-safe for concurrent reads (frames list is never mutated after init).

    Raises MemoryBudgetExceeded BEFORE decoding if estimated memory exceeds
    the budget ceiling, giving callers a chance to fall back to JPEG path.
    """

    MAX_BUDGET_BYTES = 12 * 1024**3  # 12 GiB ceiling

    def __init__(self, video_path: str, max_frames: Optional[int] = None):
        # Created eagerly (not lazily via hasattr in sam2_adapter) so that
        # concurrent sam2_adapter() calls never race on attribute creation —
        # the class advertises thread-safe reads.
        self._adapters: dict = {}

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise RuntimeError(f"Cannot open video: {video_path}")

        self.fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
        self.width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Estimate frame count BEFORE decoding to check memory budget.
        # CAP_PROP_FRAME_COUNT can report 0/negative for some containers.
        reported_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if reported_count <= 0:
            reported_count = 10000  # conservative fallback
        est_frames = min(reported_count, max_frames) if max_frames else reported_count

        # Budget: raw BGR frames + worst-case SAM2 adapter tensors (image_size=1024)
        per_frame_raw = self.height * self.width * 3  # uint8 BGR
        per_frame_adapter = 3 * 1024 * 1024 * 4  # float32, worst-case 1024x1024
        total_est = est_frames * (per_frame_raw + per_frame_adapter)
        if total_est > self.MAX_BUDGET_BYTES:
            cap.release()
            logging.warning(
                "SharedFrameStore: estimated ~%.1f GiB for %d frames exceeds "
                "%.1f GiB budget; skipping in-memory path",
                total_est / 1024**3, est_frames, self.MAX_BUDGET_BYTES / 1024**3,
            )
            raise MemoryBudgetExceeded(total_est)

        frames = []
        try:
            # try/finally guarantees the capture handle is released even if
            # cv2 raises mid-decode (the original leaked it on that path).
            while True:
                if max_frames is not None and len(frames) >= max_frames:
                    break
                ret, frame = cap.read()
                if not ret:
                    break
                frames.append(frame)
        finally:
            cap.release()

        if not frames:
            raise RuntimeError(f"No frames decoded from: {video_path}")

        self.frames = frames
        logging.info(
            "SharedFrameStore: %d frames, %dx%d, %.1f fps",
            len(self.frames), self.width, self.height, self.fps,
        )

    def __len__(self) -> int:
        return len(self.frames)

    def get_bgr(self, idx: int) -> np.ndarray:
        """Return BGR frame. Caller must .copy() if mutating."""
        return self.frames[idx]

    def get_pil_rgb(self, idx: int) -> Image.Image:
        """Return PIL RGB Image for the given frame index."""
        bgr = self.frames[idx]
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        return Image.fromarray(rgb)

    def sam2_adapter(self, image_size: int) -> "SAM2FrameAdapter":
        """Factory for SAM2-compatible frame adapter. Returns same adapter for same size."""
        if image_size not in self._adapters:
            self._adapters[image_size] = SAM2FrameAdapter(self, image_size)
        return self._adapters[image_size]
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class SAM2FrameAdapter:
    """Drop-in replacement for SAM2's AsyncVideoFrameLoader.

    Matches the interface that SAM2's init_state / propagate_in_video expects:
    - __len__() -> number of frames
    - __getitem__(idx) -> normalized float32 tensor (3, H, W)
    - .images list (SAM2 accesses this directly in some paths)
    - .video_height, .video_width
    - .exception (AsyncVideoFrameLoader compat)

    Transform parity: uses PIL Image.resize() with BICUBIC (the default),
    matching SAM2's _load_img_as_tensor exactly.
    """

    def __init__(self, store: SharedFrameStore, image_size: int):
        # Backing store of decoded BGR frames; never mutated by this adapter.
        self._store = store
        # Square side length SAM2 resizes every frame to (e.g. 1024).
        self._image_size = image_size
        self.images = [None] * len(store)  # SAM2 accesses .images directly
        self.video_height = store.height
        self.video_width = store.width
        self.exception = None  # AsyncVideoFrameLoader compat

        # ImageNet normalization constants (must match SAM2's _load_img_as_tensor)
        self._mean = torch.tensor([0.485, 0.456, 0.406]).reshape(3, 1, 1)
        self._std = torch.tensor([0.229, 0.224, 0.225]).reshape(3, 1, 1)

    def __len__(self) -> int:
        return len(self._store)

    def __getitem__(self, idx: int) -> torch.Tensor:
        # Per-index memoization: the float32 tensor cache can grow to the full
        # video; SharedFrameStore's budget estimate accounts for this
        # (per_frame_adapter term), so no eviction is needed here.
        if self.images[idx] is not None:
            return self.images[idx]

        # TRANSFORM PARITY: Must match SAM2's _load_img_as_tensor exactly.
        # SAM2 does: PIL Image -> .convert("RGB") -> .resize((size, size)) -> /255 -> permute -> normalize
        # PIL.resize default = BICUBIC. We must use PIL resize, NOT cv2.resize.
        bgr = self._store.get_bgr(idx)
        pil_img = Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
        pil_resized = pil_img.resize(
            (self._image_size, self._image_size)
        )  # BICUBIC default
        img_np = np.array(pil_resized) / 255.0
        img = torch.from_numpy(img_np).permute(2, 0, 1).float()
        img = (img - self._mean) / self._std
        self.images[idx] = img
        return img
|