File size: 3,670 Bytes
10f01b1
 
15d3835
 
10f01b1
c29be10
 
 
10f01b1
c29be10
10f01b1
 
c29be10
10f01b1
c29be10
10f01b1
 
 
c29be10
10f01b1
 
 
 
 
 
 
 
 
 
c29be10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10f01b1
 
 
c29be10
 
 
 
10f01b1
 
15d3835
10f01b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15d3835
 
 
 
10f01b1
c29be10
15d3835
fed0900
c29be10
10f01b1
15d3835
 
 
10f01b1
15d3835
 
 
c29be10
15d3835
 
 
 
 
10f01b1
15d3835
 
 
c29be10
 
 
15d3835
10f01b1
c29be10
15d3835
10f01b1
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""
End-to-end test: pulls images and videos from HF datasets, indexes them
with ragstudio, syncs the index to the team-0/ragstudio bucket, then runs
a sample search over each modality.

The downloaded datasets (and frames extracted from video hits) are cached
under tests/.cache/, and the FAISS index is written to <repo>/index_data/,
so re-runs reuse both. Pass --force to rebuild the index from scratch.

Usage:
    python tests/e2e.py [--force]
"""

import argparse
import os
import re
import sys
from pathlib import Path

import cv2
from huggingface_hub import snapshot_download

# Repository root (this script lives at <repo>/tests/e2e.py). It is put on
# sys.path so the project modules imported later (index, sync, searchers)
# resolve regardless of how the script is launched.
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT))

# Hugging Face dataset repos pulled by the test.
IMAGE_DATASET = "team-0/pytorch-conference"
VIDEO_DATASET = "team-0/test-videos"

# Fixed text query used for the image-modality search.
IMAGE_QUERY = "a person presenting on stage"

# Local cache layout, all rooted at <repo>/tests/.cache/.
CACHE_DIR = ROOT.joinpath("tests", ".cache")
DATA_DIR = CACHE_DIR.joinpath("data")
FRAMES_DIR = CACHE_DIR.joinpath("frames")

# Video meta entries are formatted by indexers/video.py as "<path> @ <t>s".
# Group 1 captures the video path, group 2 the timestamp in seconds.
VIDEO_META_RE = re.compile(r"^(.*) @ (\d+\.\d+)s$")


def extract_video_frame(meta: str, out_dir: Path) -> Path:
    """Save the frame referenced by a video search hit as a PNG.

    Args:
        meta: A video meta entry of the form ``"<path> @ <t>s"`` (matched by
            ``VIDEO_META_RE``), where ``<t>`` is a timestamp in seconds.
        out_dir: Directory to write the frame into; created if missing.

    Returns:
        Path of the written PNG, named ``<video stem>_<t>s.png`` with ``t``
        formatted to two decimal places.

    Raises:
        ValueError: if ``meta`` does not match ``VIDEO_META_RE``.
        RuntimeError: if the video cannot be opened, the frame at ``t``
            cannot be read, or the PNG cannot be written.
    """
    m = VIDEO_META_RE.match(meta)
    if not m:
        raise ValueError(f"unexpected video meta format: {meta!r}")
    vpath, t = m.group(1), float(m.group(2))

    cap = cv2.VideoCapture(vpath)
    try:
        # Report an unopenable file directly instead of letting it surface
        # later as a confusing "failed to read frame" error.
        if not cap.isOpened():
            raise RuntimeError(f"failed to open video {vpath}")
        # Seek by frame index when the container reports an FPS; otherwise
        # (FPS reported as 0) fall back to seeking by milliseconds.
        fps = cap.get(cv2.CAP_PROP_FPS) or 0
        if fps > 0:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(round(t * fps)))
        else:
            cap.set(cv2.CAP_PROP_POS_MSEC, t * 1000)
        ok, frame = cap.read()
    finally:
        # Release the capture even if seeking or reading raises.
        cap.release()
    if not ok:
        raise RuntimeError(f"failed to read frame at {t}s from {vpath}")

    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / f"{Path(vpath).stem}_{t:.2f}s.png"
    # cv2.imwrite reports failure through its return value rather than by
    # raising; without this check a missing file path would be returned.
    if not cv2.imwrite(str(out_path), frame):
        raise RuntimeError(f"failed to write frame to {out_path}")
    return out_path


def main() -> None:
    """Run the end-to-end pipeline: download, index, sync, then search.

    Downloads the image and video datasets, builds the index (rebuilding it
    when ``--force`` is passed), syncs it to the team-0/ragstudio bucket, and
    runs one image query plus every query listed in the video dataset's
    ``queries.txt``. For each video hit, the referenced frame is extracted
    as a PNG.

    Raises:
        AssertionError: if ``queries.txt`` contains no queries, or if any
            search returns no results. These are raised explicitly rather
            than via ``assert`` so the checks still run under ``python -O``.
    """
    parser = argparse.ArgumentParser(
        # `__doc__` is None under `python -OO`; keep argument parsing working.
        description=(__doc__ or "").strip(),
        # Keep the docstring's line breaks (e.g. the Usage block) in --help.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--force", action="store_true", help="rebuild the index even if cached")
    args = parser.parse_args()

    DATA_DIR.mkdir(parents=True, exist_ok=True)

    print(f"[1/4] Downloading datasets to {DATA_DIR}")
    snapshot_download(
        repo_id=IMAGE_DATASET,
        repo_type="dataset",
        local_dir=DATA_DIR / "images",
    )
    videos_dir = DATA_DIR / "videos"
    snapshot_download(
        repo_id=VIDEO_DATASET,
        repo_type="dataset",
        local_dir=videos_dir,
    )
    queries_path = videos_dir / "queries.txt"
    # Explicit encoding: the platform default (e.g. cp1252 on Windows) would
    # mis-decode non-ASCII queries.
    video_queries = [
        line.strip()
        for line in queries_path.read_text(encoding="utf-8").splitlines()
        if line.strip()
    ]
    if not video_queries:
        # Without queries the video modality would be skipped and the run
        # would still report OK.
        raise AssertionError(f"no video queries found in {queries_path}")
    print(f"  {len(video_queries)} video queries from dataset")

    # `index_data/` is resolved relative to cwd by index.py, sync.py, and
    # searchers/__init__.py — pin cwd to the repo root so the index persists
    # at <repo>/index_data/ instead of vanishing with a temp dir.
    os.chdir(ROOT)

    print(f"[2/4] Building index from {DATA_DIR}{' (force)' if args.force else ''}")
    from index import build_index

    build_index(DATA_DIR, force=args.force)

    print("[3/4] Syncing index to team-0/ragstudio bucket")
    from sync import sync

    sync()

    print("[4/4] Searching")
    from searchers import SEARCHERS

    image_hits = SEARCHERS["image"](IMAGE_QUERY, top_k=3)
    print(f"\n=== image: {IMAGE_QUERY!r} ===")
    for s, p in image_hits:
        print(f"  {s:.3f}  {p}")
    if not image_hits:
        raise AssertionError("no image results")

    for q in video_queries:
        video_hits = SEARCHERS["video"](q, top_k=3)
        print(f"\n=== video: {q!r} ===")
        for s, meta in video_hits:
            frame_path = extract_video_frame(meta, FRAMES_DIR)
            print(f"  {s:.3f}  {meta}  ->  {frame_path}")
        if not video_hits:
            raise AssertionError(f"no video results for {q!r}")

    print("\nOK")


if __name__ == "__main__":
    main()