ddebree's picture
Prepare Hugging Face Space deploy
3e67073
"""Command-line interface for MathVision Explorer."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from mathvision_explorer.dataset import load_jsonl_records, summarize_records
from mathvision_explorer.demo import create_demo_dataset
from mathvision_explorer.embeddings import (
ColorStatsEmbedder,
IJepaImageEmbedder,
ImageEmbedder,
embed_record_image,
)
from mathvision_explorer.explorer import build_image_index
from mathvision_explorer.html import export_html
from mathvision_explorer.index import VectorIndex
def main() -> None:
    """Parse the command line and dispatch to the chosen ``mathvision`` subcommand.

    Subcommands: ``inspect``, ``demo``, ``export-html``, ``index`` and ``search``.
    ``index`` and ``search`` additionally accept the shared embedder options.
    """
    parser = argparse.ArgumentParser(prog="mathvision")
    commands = parser.add_subparsers(dest="command", required=True)

    inspect_cmd = commands.add_parser(
        "inspect", help="Inspect a MathVision-style JSONL file."
    )
    inspect_cmd.add_argument("jsonl", type=Path)

    demo_cmd = commands.add_parser("demo", help="Create a tiny local demo dataset.")
    demo_cmd.add_argument("output_dir", type=Path)

    export_cmd = commands.add_parser(
        "export-html", help="Export records to a browser gallery."
    )
    for positional in ("jsonl", "output"):
        export_cmd.add_argument(positional, type=Path)

    index_cmd = commands.add_parser(
        "index", help="Build a local image-feature index."
    )
    for positional in ("jsonl", "output"):
        index_cmd.add_argument(positional, type=Path)
    _add_embedder_arguments(index_cmd)

    search_cmd = commands.add_parser(
        "search", help="Search similar indexed records."
    )
    search_cmd.add_argument("jsonl", type=Path)
    search_cmd.add_argument("index", type=Path)
    search_cmd.add_argument("query_id")
    search_cmd.add_argument("--limit", type=int, default=5)
    _add_embedder_arguments(search_cmd)

    args = parser.parse_args()

    # Handlers are lambdas so the embedder is only constructed for the
    # subcommands that actually need one.
    handlers = {
        "inspect": lambda: _inspect(args.jsonl),
        "demo": lambda: _demo(args.output_dir),
        "export-html": lambda: _export_html(args.jsonl, args.output),
        "index": lambda: _index(
            args.jsonl,
            args.output,
            embedder=_embedder_from_args(args),
        ),
        "search": lambda: _search(
            args.jsonl,
            args.index,
            args.query_id,
            limit=args.limit,
            embedder=_embedder_from_args(args),
        ),
    }
    handler = handlers.get(args.command)
    if handler is not None:
        handler()
def _inspect(jsonl: Path) -> None:
    """Print a pretty, key-sorted JSON summary of the records in *jsonl*."""
    summary = summarize_records(load_jsonl_records(jsonl))
    rendered = json.dumps(summary, indent=2, sort_keys=True)
    print(rendered)
def _demo(output_dir: Path) -> None:
    """Create the demo dataset under *output_dir* and print the JSONL path written."""
    print(f"Wrote demo dataset to {create_demo_dataset(output_dir)}")
def _export_html(jsonl: Path, output: Path) -> None:
    """Render the records in *jsonl* as an HTML gallery at *output*, then report it."""
    export_html(load_jsonl_records(jsonl), output)
    print(f"Wrote gallery to {output}")
def _index(jsonl: Path, output: Path, *, embedder: ImageEmbedder) -> None:
    """Embed every record image in *jsonl* with *embedder* and save the index as TSV.

    Prints how many vectors were written and where.
    """
    vector_index = build_image_index(load_jsonl_records(jsonl), embedder)
    vector_index.save_tsv(output)
    count = len(vector_index)
    print(f"Wrote {count} vectors to {output}")
def _search(
    jsonl: Path,
    index_path: Path,
    query_id: str,
    *,
    limit: int,
    embedder: ImageEmbedder,
) -> None:
    """Print the indexed records most similar to the record ``query_id``.

    Each hit is printed as one tab-separated line: the score (4 decimals), the
    item id, and that record's question text. When the hit's id is not found in
    *jsonl*, the item id is printed in place of the question.

    Args:
        jsonl: JSONL file with the records; looked up by ``problem_id``.
        index_path: TSV vector index to search, as written by ``index``.
        query_id: ``problem_id`` of the record whose image is the query.
        limit: Maximum number of hits, forwarded to ``VectorIndex.search``.
        embedder: Embedder used for the query image. It should be the same kind
            that built the index, or the vectors are not comparable.

    Raises:
        SystemExit: If ``query_id`` is not a ``problem_id`` in *jsonl*.
    """
    records = load_jsonl_records(jsonl)
    record_by_id = {record.problem_id: record for record in records}
    query_record = record_by_id.get(query_id)
    if query_record is None:
        raise SystemExit(f"Unknown query id: {query_id}")
    # Load the index before embedding, so a bad index path fails before the
    # (possibly expensive) embedder does any work.
    index = VectorIndex.load_tsv(index_path)
    query_vector = embed_record_image(query_record.image_path, embedder)
    # exclude_id=query_id is passed so the query record itself is presumably
    # left out of its own neighbour list.
    for neighbor in index.search(query_vector, limit=limit, exclude_id=query_id):
        record = record_by_id.get(neighbor.item_id)
        label = record.question if record is not None else neighbor.item_id
        # Question text can contain newlines/tabs, which would break the
        # one-line-per-hit, tab-separated output. Collapse whitespace runs.
        label = " ".join(str(label).split())
        print(f"{neighbor.score:.4f}\t{neighbor.item_id}\t{label}")
def _add_embedder_arguments(parser: argparse.ArgumentParser) -> None:
    """Register the image-embedder options shared by ``index`` and ``search``.

    ``_embedder_from_args`` turns these options into an embedder. The
    ``--jepa-*`` options are only read when ``--embedder ijepa`` is selected.

    Args:
        parser: The subcommand parser that receives the options.
    """
    # index and search should use the same embedder settings; otherwise the
    # query vector is not comparable with the stored vectors.
    parser.add_argument(
        "--embedder",
        choices=["color", "ijepa"],
        default="color",
        help=(
            "Image embedder: 'color' (colour-statistics features) or "
            "'ijepa' (I-JEPA model). Default: %(default)s."
        ),
    )
    parser.add_argument(
        "--jepa-model",
        default="facebook/ijepa_vith14_1k",
        help="Model id for the I-JEPA embedder (used with --embedder ijepa). "
        "Default: %(default)s.",
    )
    parser.add_argument(
        "--jepa-device",
        default=None,
        help="Device passed to the I-JEPA embedder (used with --embedder ijepa). "
        "If omitted, the embedder's own default is used.",
    )
def _embedder_from_args(args: argparse.Namespace) -> ImageEmbedder:
    """Build the image embedder selected by the parsed ``--embedder`` options.

    Any choice other than ``"ijepa"`` yields a ``ColorStatsEmbedder``. For
    ``"ijepa"``, ``--jepa-model`` and ``--jepa-device`` are forwarded to
    ``IJepaImageEmbedder``.
    """
    if args.embedder != "ijepa":
        return ColorStatsEmbedder()
    return IJepaImageEmbedder(model_id=args.jepa_model, device=args.jepa_device)
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()