"""Command-line interface for MathVision Explorer.""" from __future__ import annotations import argparse import json from pathlib import Path from mathvision_explorer.dataset import load_jsonl_records, summarize_records from mathvision_explorer.demo import create_demo_dataset from mathvision_explorer.embeddings import ( ColorStatsEmbedder, IJepaImageEmbedder, ImageEmbedder, embed_record_image, ) from mathvision_explorer.explorer import build_image_index from mathvision_explorer.html import export_html from mathvision_explorer.index import VectorIndex def main() -> None: """Run the MathVision Explorer command-line interface.""" parser = argparse.ArgumentParser(prog="mathvision") subparsers = parser.add_subparsers(dest="command", required=True) inspect_parser = subparsers.add_parser("inspect", help="Inspect a MathVision-style JSONL file.") inspect_parser.add_argument("jsonl", type=Path) demo_parser = subparsers.add_parser("demo", help="Create a tiny local demo dataset.") demo_parser.add_argument("output_dir", type=Path) html_parser = subparsers.add_parser("export-html", help="Export records to a browser gallery.") html_parser.add_argument("jsonl", type=Path) html_parser.add_argument("output", type=Path) index_parser = subparsers.add_parser("index", help="Build a local image-feature index.") index_parser.add_argument("jsonl", type=Path) index_parser.add_argument("output", type=Path) _add_embedder_arguments(index_parser) search_parser = subparsers.add_parser("search", help="Search similar indexed records.") search_parser.add_argument("jsonl", type=Path) search_parser.add_argument("index", type=Path) search_parser.add_argument("query_id") search_parser.add_argument("--limit", type=int, default=5) _add_embedder_arguments(search_parser) args = parser.parse_args() if args.command == "inspect": _inspect(args.jsonl) elif args.command == "demo": _demo(args.output_dir) elif args.command == "export-html": _export_html(args.jsonl, args.output) elif args.command == "index": _index(args.jsonl, args.output, embedder=_embedder_from_args(args)) elif args.command == "search": _search( args.jsonl, args.index, args.query_id, limit=args.limit, embedder=_embedder_from_args(args), ) def _inspect(jsonl: Path) -> None: records = load_jsonl_records(jsonl) print(json.dumps(summarize_records(records), indent=2, sort_keys=True)) def _demo(output_dir: Path) -> None: jsonl_path = create_demo_dataset(output_dir) print(f"Wrote demo dataset to {jsonl_path}") def _export_html(jsonl: Path, output: Path) -> None: records = load_jsonl_records(jsonl) export_html(records, output) print(f"Wrote gallery to {output}") def _index(jsonl: Path, output: Path, *, embedder: ImageEmbedder) -> None: records = load_jsonl_records(jsonl) index = build_image_index(records, embedder) index.save_tsv(output) print(f"Wrote {len(index)} vectors to {output}") def _search( jsonl: Path, index_path: Path, query_id: str, *, limit: int, embedder: ImageEmbedder, ) -> None: records = load_jsonl_records(jsonl) record_by_id = {record.problem_id: record for record in records} query_record = record_by_id.get(query_id) if query_record is None: raise SystemExit(f"Unknown query id: {query_id}") query_vector = embed_record_image(query_record.image_path, embedder) index = VectorIndex.load_tsv(index_path) for neighbor in index.search(query_vector, limit=limit, exclude_id=query_id): record = record_by_id.get(neighbor.item_id) label = record.question if record is not None else neighbor.item_id print(f"{neighbor.score:.4f}\t{neighbor.item_id}\t{label}") def _add_embedder_arguments(parser: argparse.ArgumentParser) -> None: parser.add_argument("--embedder", choices=["color", "ijepa"], default="color") parser.add_argument("--jepa-model", default="facebook/ijepa_vith14_1k") parser.add_argument("--jepa-device", default=None) def _embedder_from_args(args: argparse.Namespace) -> ImageEmbedder: if args.embedder == "ijepa": return IJepaImageEmbedder( model_id=args.jepa_model, device=args.jepa_device, ) return ColorStatsEmbedder() if __name__ == "__main__": main()