File size: 4,497 Bytes
f9306c2 3e67073 f9306c2 3e67073 f9306c2 3e67073 f9306c2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 | """Command-line interface for MathVision Explorer."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from mathvision_explorer.dataset import load_jsonl_records, summarize_records
from mathvision_explorer.demo import create_demo_dataset
from mathvision_explorer.embeddings import (
ColorStatsEmbedder,
IJepaImageEmbedder,
ImageEmbedder,
embed_record_image,
)
from mathvision_explorer.explorer import build_image_index
from mathvision_explorer.html import export_html
from mathvision_explorer.index import VectorIndex
def main() -> None:
"""Run the MathVision Explorer command-line interface."""
parser = argparse.ArgumentParser(prog="mathvision")
subparsers = parser.add_subparsers(dest="command", required=True)
inspect_parser = subparsers.add_parser("inspect", help="Inspect a MathVision-style JSONL file.")
inspect_parser.add_argument("jsonl", type=Path)
demo_parser = subparsers.add_parser("demo", help="Create a tiny local demo dataset.")
demo_parser.add_argument("output_dir", type=Path)
html_parser = subparsers.add_parser("export-html", help="Export records to a browser gallery.")
html_parser.add_argument("jsonl", type=Path)
html_parser.add_argument("output", type=Path)
index_parser = subparsers.add_parser("index", help="Build a local image-feature index.")
index_parser.add_argument("jsonl", type=Path)
index_parser.add_argument("output", type=Path)
_add_embedder_arguments(index_parser)
search_parser = subparsers.add_parser("search", help="Search similar indexed records.")
search_parser.add_argument("jsonl", type=Path)
search_parser.add_argument("index", type=Path)
search_parser.add_argument("query_id")
search_parser.add_argument("--limit", type=int, default=5)
_add_embedder_arguments(search_parser)
args = parser.parse_args()
if args.command == "inspect":
_inspect(args.jsonl)
elif args.command == "demo":
_demo(args.output_dir)
elif args.command == "export-html":
_export_html(args.jsonl, args.output)
elif args.command == "index":
_index(args.jsonl, args.output, embedder=_embedder_from_args(args))
elif args.command == "search":
_search(
args.jsonl,
args.index,
args.query_id,
limit=args.limit,
embedder=_embedder_from_args(args),
)
def _inspect(jsonl: Path) -> None:
    """Print a pretty, key-sorted JSON summary of the records in *jsonl*."""
    summary = summarize_records(load_jsonl_records(jsonl))
    rendered = json.dumps(summary, sort_keys=True, indent=2)
    print(rendered)
def _demo(output_dir: Path) -> None:
    """Create the demo dataset under *output_dir* and print the path it returns."""
    dataset_file = create_demo_dataset(output_dir)
    message = f"Wrote demo dataset to {dataset_file}"
    print(message)
def _export_html(jsonl: Path, output: Path) -> None:
    """Load the records in *jsonl*, export them as an HTML gallery to *output*, and report it."""
    export_html(load_jsonl_records(jsonl), output)
    print(f"Wrote gallery to {output}")
def _index(jsonl: Path, output: Path, *, embedder: ImageEmbedder) -> None:
    """Build an image index over *jsonl*'s records with *embedder* and save it as TSV.

    Prints how many vectors the saved index holds.
    """
    vector_index = build_image_index(load_jsonl_records(jsonl), embedder)
    vector_index.save_tsv(output)
    vector_count = len(vector_index)
    print(f"Wrote {vector_count} vectors to {output}")
def _search(
    jsonl: Path,
    index_path: Path,
    query_id: str,
    *,
    limit: int,
    embedder: ImageEmbedder,
) -> None:
    """Print the indexed records whose images are most similar to *query_id*'s.

    Each hit is printed as one tab-separated line: the score (4 decimals),
    the item id, and the matching record's question. When the hit has no
    record in *jsonl*, the item id is used as the label instead. The query
    record itself is excluded from the results.

    Raises:
        SystemExit: if *query_id* is not a ``problem_id`` present in *jsonl*.
    """
    records = load_jsonl_records(jsonl)
    record_by_id = {record.problem_id: record for record in records}
    query_record = record_by_id.get(query_id)
    if query_record is None:
        raise SystemExit(f"Unknown query id: {query_id}")
    # Load the index before embedding the query so a missing or unreadable
    # index fails fast rather than after a possibly expensive model forward pass.
    index = VectorIndex.load_tsv(index_path)
    query_vector = embed_record_image(query_record.image_path, embedder)
    for neighbor in index.search(query_vector, limit=limit, exclude_id=query_id):
        record = record_by_id.get(neighbor.item_id)
        label = record.question if record is not None else neighbor.item_id
        # Collapse embedded newlines/tabs so each hit stays on exactly one
        # tab-separated output line.
        label = " ".join(str(label).split())
        print(f"{neighbor.score:.4f}\t{neighbor.item_id}\t{label}")
def _add_embedder_arguments(parser: argparse.ArgumentParser) -> None:
parser.add_argument("--embedder", choices=["color", "ijepa"], default="color")
parser.add_argument("--jepa-model", default="facebook/ijepa_vith14_1k")
parser.add_argument("--jepa-device", default=None)
def _embedder_from_args(args: argparse.Namespace) -> ImageEmbedder:
    """Construct the image embedder selected by the parsed ``--embedder`` option.

    ``ijepa`` yields an ``IJepaImageEmbedder`` configured from ``--jepa-model``
    and ``--jepa-device``; any other value yields a ``ColorStatsEmbedder``.
    """
    if args.embedder != "ijepa":
        return ColorStatsEmbedder()
    return IJepaImageEmbedder(model_id=args.jepa_model, device=args.jepa_device)
# Run the CLI when this module is executed directly as a script.
if __name__ == "__main__":
    main()
|