File size: 4,497 Bytes
f9306c2
 
 
 
 
 
 
 
 
 
 
 
3e67073
f9306c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e67073
 
f9306c2
 
 
 
3e67073
 
f9306c2
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""Command-line interface for MathVision Explorer."""

from __future__ import annotations

import argparse
import json
from pathlib import Path

from mathvision_explorer.dataset import load_jsonl_records, summarize_records
from mathvision_explorer.demo import create_demo_dataset
from mathvision_explorer.embeddings import (
    ColorStatsEmbedder,
    IJepaImageEmbedder,
    ImageEmbedder,
    embed_record_image,
)
from mathvision_explorer.explorer import build_image_index
from mathvision_explorer.html import export_html
from mathvision_explorer.index import VectorIndex


def _build_parser() -> argparse.ArgumentParser:
    """Create the ``mathvision`` argument parser with all of its subcommands."""

    parser = argparse.ArgumentParser(prog="mathvision")
    subparsers = parser.add_subparsers(dest="command", required=True)

    inspect_parser = subparsers.add_parser("inspect", help="Inspect a MathVision-style JSONL file.")
    inspect_parser.add_argument("jsonl", type=Path)

    demo_parser = subparsers.add_parser("demo", help="Create a tiny local demo dataset.")
    demo_parser.add_argument("output_dir", type=Path)

    html_parser = subparsers.add_parser("export-html", help="Export records to a browser gallery.")
    html_parser.add_argument("jsonl", type=Path)
    html_parser.add_argument("output", type=Path)

    index_parser = subparsers.add_parser("index", help="Build a local image-feature index.")
    index_parser.add_argument("jsonl", type=Path)
    index_parser.add_argument("output", type=Path)
    _add_embedder_arguments(index_parser)

    search_parser = subparsers.add_parser("search", help="Search similar indexed records.")
    search_parser.add_argument("jsonl", type=Path)
    search_parser.add_argument("index", type=Path)
    search_parser.add_argument("query_id")
    search_parser.add_argument("--limit", type=int, default=5)
    _add_embedder_arguments(search_parser)

    return parser


def main(argv: list[str] | None = None) -> None:
    """Run the MathVision Explorer command-line interface.

    Args:
        argv: Command-line arguments to parse, without the program name.
            When ``None`` (the default) argparse falls back to
            ``sys.argv[1:]``, so callers that invoke ``main()`` with no
            arguments behave exactly as before.

    Raises:
        SystemExit: Raised by argparse for invalid arguments or ``--help``,
            and by the ``search`` command when the query id is unknown.
    """

    args = _build_parser().parse_args(argv)

    # ``required=True`` on the subparsers guarantees ``args.command`` is one
    # of the registered subcommand names below.
    if args.command == "inspect":
        _inspect(args.jsonl)
    elif args.command == "demo":
        _demo(args.output_dir)
    elif args.command == "export-html":
        _export_html(args.jsonl, args.output)
    elif args.command == "index":
        _index(args.jsonl, args.output, embedder=_embedder_from_args(args))
    elif args.command == "search":
        _search(
            args.jsonl,
            args.index,
            args.query_id,
            limit=args.limit,
            embedder=_embedder_from_args(args),
        )


def _inspect(jsonl: Path) -> None:
    """Print a JSON summary of the records stored in ``jsonl``.

    The value returned by ``summarize_records`` is written to stdout as JSON
    with a two-space indent and sorted keys.
    """
    summary = summarize_records(load_jsonl_records(jsonl))
    rendered = json.dumps(summary, sort_keys=True, indent=2)
    print(rendered)


def _demo(output_dir: Path) -> None:
    """Create the demo dataset in ``output_dir`` and report where it was written.

    The path printed is whatever ``create_demo_dataset`` returns.
    """
    jsonl_path = create_demo_dataset(output_dir)
    message = f"Wrote demo dataset to {jsonl_path}"
    print(message)


def _export_html(jsonl: Path, output: Path) -> None:
    """Load the records in ``jsonl`` and hand them to ``export_html``.

    Prints a confirmation naming ``output`` once the export call returns.
    """
    export_html(load_jsonl_records(jsonl), output)
    confirmation = f"Wrote gallery to {output}"
    print(confirmation)


def _index(jsonl: Path, output: Path, *, embedder: ImageEmbedder) -> None:
    """Build an image index for the records in ``jsonl`` and save it as TSV.

    Args:
        jsonl: Dataset file read with ``load_jsonl_records``.
        output: Destination passed to the index's ``save_tsv`` method.
        embedder: Embedder forwarded to ``build_image_index``.
    """
    index = build_image_index(load_jsonl_records(jsonl), embedder)
    index.save_tsv(output)
    confirmation = f"Wrote {len(index)} vectors to {output}"
    print(confirmation)


def _search(
    jsonl: Path,
    index_path: Path,
    query_id: str,
    *,
    limit: int,
    embedder: ImageEmbedder,
) -> None:
    """Print the indexed neighbors most similar to one record's image.

    Each result is written to stdout as a tab-separated line: the score with
    four decimal places, the neighbor's item id, and a label. The label is the
    matching record's question, or the item id itself when the index mentions
    an id that is not present in ``jsonl``.

    Args:
        jsonl: Dataset file whose records are looked up by ``problem_id``.
        index_path: TSV file loaded with ``VectorIndex.load_tsv``.
        query_id: ``problem_id`` of the record whose image is the query; it is
            also passed to the search as ``exclude_id``.
        limit: Forwarded to ``VectorIndex.search`` as ``limit``.
        embedder: Embedder used to turn the query image into a vector.

    Raises:
        SystemExit: If no record in ``jsonl`` has the id ``query_id``.
    """
    # Later records overwrite earlier ones that share a problem id.
    record_by_id = {}
    for loaded in load_jsonl_records(jsonl):
        record_by_id[loaded.problem_id] = loaded

    if query_id not in record_by_id:
        raise SystemExit(f"Unknown query id: {query_id}")
    query_record = record_by_id[query_id]

    # The query is embedded before the index file is read, as in the original.
    query_vector = embed_record_image(query_record.image_path, embedder)
    neighbors = VectorIndex.load_tsv(index_path).search(
        query_vector,
        limit=limit,
        exclude_id=query_id,
    )
    for neighbor in neighbors:
        matched = record_by_id.get(neighbor.item_id)
        if matched is None:
            label = neighbor.item_id
        else:
            label = matched.question
        print(f"{neighbor.score:.4f}\t{neighbor.item_id}\t{label}")


def _add_embedder_arguments(parser: argparse.ArgumentParser) -> None:
    """Register the embedder-selection options on ``parser``.

    Adds ``--embedder`` (``color`` or ``ijepa``, default ``color``),
    ``--jepa-model`` (default ``facebook/ijepa_vith14_1k``) and
    ``--jepa-device`` (default ``None``).
    """
    option_specs = (
        ("--embedder", {"choices": ["color", "ijepa"], "default": "color"}),
        ("--jepa-model", {"default": "facebook/ijepa_vith14_1k"}),
        ("--jepa-device", {"default": None}),
    )
    for flag, settings in option_specs:
        parser.add_argument(flag, **settings)


def _embedder_from_args(args: argparse.Namespace) -> ImageEmbedder:
    """Construct the image embedder selected by the parsed CLI options.

    When ``args.embedder`` is ``"ijepa"`` an ``IJepaImageEmbedder`` is built
    from ``args.jepa_model`` and ``args.jepa_device``; any other value yields
    a ``ColorStatsEmbedder``.
    """
    if args.embedder != "ijepa":
        return ColorStatsEmbedder()

    jepa_options = {"model_id": args.jepa_model, "device": args.jepa_device}
    return IJepaImageEmbedder(**jepa_options)


# Run the CLI when this module is executed directly as a script.
if __name__ == "__main__":
    main()