| """Command-line interface for MathVision Explorer.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| from pathlib import Path |
|
|
| from mathvision_explorer.dataset import load_jsonl_records, summarize_records |
| from mathvision_explorer.demo import create_demo_dataset |
| from mathvision_explorer.embeddings import ( |
| ColorStatsEmbedder, |
| IJepaImageEmbedder, |
| ImageEmbedder, |
| embed_record_image, |
| ) |
| from mathvision_explorer.explorer import build_image_index |
| from mathvision_explorer.html import export_html |
| from mathvision_explorer.index import VectorIndex |
|
|
|
|
def _positive_int(text: str) -> int:
    """Parse *text* as an integer >= 1; used as the argparse ``type`` for ``--limit``.

    Raises:
        argparse.ArgumentTypeError: If *text* is not an integer or is < 1, so
            argparse reports a clean usage error instead of passing a
            nonsensical limit on to the index search.
    """
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {text!r}"
        ) from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {value}")
    return value


def main(argv: list[str] | None = None) -> None:
    """Run the MathVision Explorer command-line interface.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, so calling ``main()``
            with no arguments behaves exactly as before; passing a list allows
            programmatic use and testing.
    """
    jsonl_help = "Path to the MathVision-style JSONL records file."

    parser = argparse.ArgumentParser(prog="mathvision")
    subparsers = parser.add_subparsers(dest="command", required=True)

    inspect_parser = subparsers.add_parser("inspect", help="Inspect a MathVision-style JSONL file.")
    inspect_parser.add_argument("jsonl", type=Path, help=jsonl_help)

    demo_parser = subparsers.add_parser("demo", help="Create a tiny local demo dataset.")
    demo_parser.add_argument(
        "output_dir", type=Path, help="Directory to write the demo dataset into."
    )

    html_parser = subparsers.add_parser("export-html", help="Export records to a browser gallery.")
    html_parser.add_argument("jsonl", type=Path, help=jsonl_help)
    html_parser.add_argument("output", type=Path, help="Output path for the HTML gallery.")

    index_parser = subparsers.add_parser("index", help="Build a local image-feature index.")
    index_parser.add_argument("jsonl", type=Path, help=jsonl_help)
    index_parser.add_argument("output", type=Path, help="Path of the TSV index file to write.")
    _add_embedder_arguments(index_parser)

    search_parser = subparsers.add_parser("search", help="Search similar indexed records.")
    search_parser.add_argument("jsonl", type=Path, help=jsonl_help)
    search_parser.add_argument(
        "index", type=Path, help="TSV index file written by the 'index' command."
    )
    search_parser.add_argument("query_id", help="problem_id of the record to query with.")
    search_parser.add_argument(
        "--limit",
        type=_positive_int,  # reject 0 / negative limits at parse time
        default=5,
        help="Maximum number of neighbors to print (default: %(default)s).",
    )
    _add_embedder_arguments(search_parser)

    args = parser.parse_args(argv)
    if args.command == "inspect":
        _inspect(args.jsonl)
    elif args.command == "demo":
        _demo(args.output_dir)
    elif args.command == "export-html":
        _export_html(args.jsonl, args.output)
    elif args.command == "index":
        _index(args.jsonl, args.output, embedder=_embedder_from_args(args))
    elif args.command == "search":
        _search(
            args.jsonl,
            args.index,
            args.query_id,
            limit=args.limit,
            embedder=_embedder_from_args(args),
        )
|
|
|
|
def _inspect(jsonl: Path) -> None:
    """Print a JSON summary of the records in ``jsonl`` (2-space indent, sorted keys)."""
    summary = summarize_records(load_jsonl_records(jsonl))
    rendered = json.dumps(summary, sort_keys=True, indent=2)
    print(rendered)
|
|
|
|
def _demo(output_dir: Path) -> None:
    """Create the demo dataset under ``output_dir`` and print where its JSONL landed."""
    print(f"Wrote demo dataset to {create_demo_dataset(output_dir)}")
|
|
|
|
def _export_html(jsonl: Path, output: Path) -> None:
    """Render the records loaded from ``jsonl`` as an HTML gallery at ``output``."""
    export_html(load_jsonl_records(jsonl), output)
    print(f"Wrote gallery to {output}")
|
|
|
|
def _index(jsonl: Path, output: Path, *, embedder: ImageEmbedder) -> None:
    """Build an image index over the records in ``jsonl`` and save it as TSV.

    Args:
        jsonl: Records file to index.
        output: Destination of the TSV index.
        embedder: Embedder passed to ``build_image_index``.
    """
    built = build_image_index(load_jsonl_records(jsonl), embedder)
    built.save_tsv(output)
    vector_count = len(built)
    print(f"Wrote {vector_count} vectors to {output}")
|
|
|
|
def _search(
    jsonl: Path,
    index_path: Path,
    query_id: str,
    *,
    limit: int,
    embedder: ImageEmbedder,
) -> None:
    """Print the indexed records most similar to the record ``query_id``.

    Each result is one tab-separated line ``score<TAB>item_id<TAB>label``.
    ``label`` is the matching record's question (whitespace collapsed so the
    line stays intact), or the item id if the index references a record
    missing from ``jsonl``.

    Args:
        jsonl: Records file; used to resolve the query and label results.
        index_path: TSV index loaded with ``VectorIndex.load_tsv``.
        query_id: ``problem_id`` of the query record; it is excluded from results.
        limit: Passed to ``VectorIndex.search`` as the result limit.
        embedder: Used to embed the query image.

    Raises:
        SystemExit: If ``query_id`` is not a ``problem_id`` in ``jsonl``.
    """
    records = load_jsonl_records(jsonl)
    record_by_id = {record.problem_id: record for record in records}
    query_record = record_by_id.get(query_id)
    if query_record is None:
        raise SystemExit(f"Unknown query id: {query_id}")

    # Load the index first so a bad index path fails before any (possibly
    # expensive, e.g. I-JEPA) embedding work is done.
    index = VectorIndex.load_tsv(index_path)
    # NOTE(review): the query must be embedded with the same embedder that
    # built the index; nothing here verifies that.
    query_vector = embed_record_image(query_record.image_path, embedder)
    for neighbor in index.search(query_vector, limit=limit, exclude_id=query_id):
        record = record_by_id.get(neighbor.item_id)
        label = record.question if record is not None else neighbor.item_id
        print(f"{neighbor.score:.4f}\t{neighbor.item_id}\t{_single_line(label)}")


def _single_line(text: object) -> str:
    """Return ``str(text)`` with every whitespace run (tabs, newlines) collapsed to one space.

    Question text may span several lines; without this a single search
    result would be split across output lines and break the TSV layout.
    """
    return " ".join(str(text).split())
|
|
|
|
def _add_embedder_arguments(parser: argparse.ArgumentParser) -> None:
    """Register the image-embedder options on ``parser``.

    Adds ``--embedder`` (``color`` by default, or ``ijepa``) and the
    ``--jepa-model`` / ``--jepa-device`` options. ``_embedder_from_args``
    reads the ``--jepa-*`` values only when ``--embedder ijepa`` is chosen.
    Defaults are unchanged; only ``--help`` output gains descriptions.
    """
    parser.add_argument(
        "--embedder",
        choices=["color", "ijepa"],
        default="color",
        help="Image embedder to use (default: %(default)s).",
    )
    parser.add_argument(
        "--jepa-model",
        default="facebook/ijepa_vith14_1k",
        help="Model id for the I-JEPA embedder (default: %(default)s).",
    )
    parser.add_argument(
        "--jepa-device",
        default=None,
        help="Device for the I-JEPA embedder; only used with --embedder ijepa.",
    )
|
|
|
|
def _embedder_from_args(args: argparse.Namespace) -> ImageEmbedder:
    """Instantiate the embedder selected by the parsed ``--embedder`` option.

    ``"ijepa"`` builds an ``IJepaImageEmbedder`` from ``--jepa-model`` and
    ``--jepa-device``; any other value yields a ``ColorStatsEmbedder``.
    """
    if args.embedder != "ijepa":
        return ColorStatsEmbedder()
    return IJepaImageEmbedder(model_id=args.jepa_model, device=args.jepa_device)
|
|
|
|
if __name__ == "__main__":
    # Script entry point: main() parses sys.argv and dispatches the subcommand.
    main()
|
|