tarekziade's picture
tarekziade HF Staff
extract frame
c29be10
"""
End-to-end test: pulls images and videos from HF datasets, indexes them
with ragstudio, syncs the index to the team-0/ragstudio bucket, then runs
one image search plus one video search per query listed in the video
dataset's queries.txt.
Downloaded datasets and extracted video frames are cached under
tests/.cache/, and the FAISS index is written to <repo>/index_data/, so
re-runs reuse them. Pass --force to rebuild the index from scratch.
Usage:
python tests/e2e.py [--force]
"""
import argparse
import os
import re
import sys
from pathlib import Path
import cv2
from huggingface_hub import snapshot_download
# Repository root (the parent of tests/). It goes first on sys.path so the
# project's top-level modules (index, sync, searchers) are importable.
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT))

# Hugging Face dataset repos fetched by the test, plus the image query.
IMAGE_DATASET = "team-0/pytorch-conference"
VIDEO_DATASET = "team-0/test-videos"
IMAGE_QUERY = "a person presenting on stage"

# Local cache layout: tests/.cache/data and tests/.cache/frames.
CACHE_DIR = ROOT.joinpath("tests", ".cache")
DATA_DIR = CACHE_DIR.joinpath("data")
FRAMES_DIR = CACHE_DIR.joinpath("frames")

# Video meta entries are formatted by indexers/video.py as "<path> @ <t>s".
# The greedy `(.*)` splits at the *last* " @ ", so paths that themselves
# contain " @ " still parse. The timestamp must contain a decimal point.
VIDEO_META_RE = re.compile(r"^(.*) @ (\d+\.\d+)s$")
def extract_video_frame(meta: str, out_dir: Path) -> Path:
    """Save the frame referenced by a video search hit as a PNG and return its path.

    Args:
        meta: Video meta string of the form ``"<path> @ <t>s"`` (see VIDEO_META_RE).
        out_dir: Directory to write the PNG into; created if missing.

    Returns:
        Path of the written ``<video stem>_<t:.2f>s.png`` file.

    Raises:
        ValueError: if *meta* does not match the expected format.
        RuntimeError: if the video cannot be opened, the frame cannot be
            read, or the PNG cannot be written.
    """
    m = VIDEO_META_RE.match(meta)
    if not m:
        raise ValueError(f"unexpected video meta format: {meta!r}")
    vpath, t = m.group(1), float(m.group(2))

    cap = cv2.VideoCapture(vpath)
    try:
        # Report an unopenable file explicitly instead of letting it
        # surface as a confusing "failed to read frame" error below.
        if not cap.isOpened():
            raise RuntimeError(f"failed to open video {vpath}")
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps > 0:
            # Seek by frame index when FPS is known. The timestamp in *meta*
            # is a decimal rendering and may be rounded, so the computed index
            # can land one past the last frame; clamp it when the container
            # reports a frame count (NaN/0/negative counts skip the clamp).
            idx = int(round(t * fps))
            n_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
            if n_frames > 0:
                idx = min(idx, int(n_frames) - 1)
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        else:
            # No usable FPS: fall back to seeking by timestamp in milliseconds.
            cap.set(cv2.CAP_PROP_POS_MSEC, t * 1000)
        ok, frame = cap.read()
    finally:
        # Always free the capture handle, even if a seek/read call raises.
        cap.release()
    if not ok:
        raise RuntimeError(f"failed to read frame at {t}s from {vpath}")

    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / f"{Path(vpath).stem}_{t:.2f}s.png"
    # cv2.imwrite signals failure via its return value, not an exception.
    if not cv2.imwrite(str(out_path), frame):
        raise RuntimeError(f"failed to write frame to {out_path}")
    return out_path
def _download_datasets() -> Path:
    """Fetch both HF datasets into DATA_DIR and return the local videos directory."""
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    print(f"[1/4] Downloading datasets to {DATA_DIR}")
    snapshot_download(
        repo_id=IMAGE_DATASET,
        repo_type="dataset",
        local_dir=DATA_DIR / "images",
    )
    videos_dir = DATA_DIR / "videos"
    snapshot_download(
        repo_id=VIDEO_DATASET,
        repo_type="dataset",
        local_dir=videos_dir,
    )
    return videos_dir


def _read_video_queries(videos_dir: Path) -> list:
    """Return the stripped, non-blank lines of ``<videos_dir>/queries.txt``.

    Raises:
        AssertionError: if the file holds no queries. Without queries the
            video modality would go untested while the run still reports OK.
    """
    queries_path = videos_dir / "queries.txt"
    queries = [
        line.strip()
        for line in queries_path.read_text(encoding="utf-8").splitlines()
        if line.strip()
    ]
    if not queries:
        raise AssertionError(f"no video queries found in {queries_path}")
    return queries


def _search_images(search) -> None:
    """Run IMAGE_QUERY through *search*, print the top hits, and fail if there are none."""
    hits = search(IMAGE_QUERY, top_k=3)
    print(f"\n=== image: {IMAGE_QUERY!r} ===")
    for score, path in hits:
        print(f" {score:.3f} {path}")
    # Explicit raise (not `assert`) so the check survives `python -O`.
    if not hits:
        raise AssertionError("no image results")


def _search_videos(search, queries) -> None:
    """Run each query through *search*, extract a preview frame per hit, and fail on empty results."""
    for q in queries:
        hits = search(q, top_k=3)
        print(f"\n=== video: {q!r} ===")
        for score, meta in hits:
            frame_path = extract_video_frame(meta, FRAMES_DIR)
            print(f" {score:.3f} {meta} -> {frame_path}")
        # Explicit raise (not `assert`) so the check survives `python -O`.
        if not hits:
            raise AssertionError(f"no video results for {q!r}")


def main() -> None:
    """Download the datasets, build and sync the index, then search each modality.

    Raises:
        AssertionError: if the video dataset lists no queries, or any search
            returns no hits.
    """
    parser = argparse.ArgumentParser(
        # `__doc__` is None under `python -OO`; fall back to an empty description.
        description=(__doc__ or "").strip(),
        # Keep the docstring's line breaks (e.g. the Usage block) in --help.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--force", action="store_true", help="rebuild the index even if cached")
    args = parser.parse_args()

    videos_dir = _download_datasets()
    # Read queries before the (slow) index build so an empty list fails fast.
    video_queries = _read_video_queries(videos_dir)
    print(f" {len(video_queries)} video queries from dataset")

    # `index_data/` is resolved relative to cwd by index.py, sync.py, and
    # searchers/__init__.py — pin cwd to the repo root so the index persists
    # at <repo>/index_data/ instead of vanishing with a temp dir.
    os.chdir(ROOT)

    # Project modules are imported only after the chdir above (and after
    # ROOT was put on sys.path at module load).
    print(f"[2/4] Building index from {DATA_DIR}{' (force)' if args.force else ''}")
    from index import build_index
    build_index(DATA_DIR, force=args.force)

    print("[3/4] Syncing index to team-0/ragstudio bucket")
    from sync import sync
    sync()

    print("[4/4] Searching")
    from searchers import SEARCHERS
    _search_images(SEARCHERS["image"])
    _search_videos(SEARCHERS["video"], video_queries)
    print("\nOK")


if __name__ == "__main__":
    main()