# File size: 1,568 Bytes
"""High-level workflows for MathVision exploration."""
from __future__ import annotations
from pathlib import Path
from mathvision_explorer.dataset import MathVisionRecord, filter_records, load_jsonl_records
from mathvision_explorer.embeddings import ImageEmbedder, embed_record_image
from mathvision_explorer.index import Neighbor, VectorIndex
def build_image_index(records: list[MathVisionRecord], embedder: ImageEmbedder) -> VectorIndex:
    """Embed the image of each record that has one and collect the vectors.

    Records whose ``image_path`` is ``None`` are skipped. Every other record
    is embedded via ``embed_record_image`` and passed to ``VectorIndex.add``
    under its ``problem_id``, in the order the records are given.

    Args:
        records: Records to index.
        embedder: Embedder forwarded to ``embed_record_image``.

    Returns:
        A freshly created ``VectorIndex`` containing the added vectors.
    """
    vector_index = VectorIndex()
    # Lazily filter so each record is checked, embedded and added in turn.
    with_images = (rec for rec in records if rec.image_path is not None)
    for rec in with_images:
        vector = embed_record_image(rec.image_path, embedder)
        vector_index.add(rec.problem_id, vector)
    return vector_index
def find_similar_records(
    records: list[MathVisionRecord],
    index: VectorIndex,
    query_id: str,
    query_vector: tuple[float, ...],
    *,
    limit: int = 5,
) -> list[tuple[MathVisionRecord, Neighbor]]:
    """Return the records whose indexed vectors are nearest ``query_vector``.

    ``index.search`` is asked for up to ``limit`` neighbors, excluding the
    item ``query_id``. Each neighbor is paired with the record whose
    ``problem_id`` equals the neighbor's ``item_id``. Neighbors with no
    matching record in ``records`` are dropped, so the result may contain
    fewer than ``limit`` pairs.

    Args:
        records: Candidate records; if several share a ``problem_id``, the
            last one wins.
        index: Index to query.
        query_id: Identifier excluded from the search results.
        query_vector: Vector to search around.
        limit: Maximum number of neighbors requested from the index.

    Returns:
        ``(record, neighbor)`` pairs in the order ``index.search`` yielded them.
    """
    id_to_record = {}
    for rec in records:
        id_to_record[rec.problem_id] = rec
    hits = index.search(query_vector, limit=limit, exclude_id=query_id)
    matched = []
    for hit in hits:
        if hit.item_id not in id_to_record:
            # Indexed item with no counterpart among the supplied records.
            continue
        matched.append((id_to_record[hit.item_id], hit))
    return matched
def load_filtered_records(
    path: Path, *, subject: str | None = None, level: int | None = None
) -> list[MathVisionRecord]:
    """Read records from the JSONL file at ``path`` and apply explorer filters.

    Args:
        path: File handed to ``load_jsonl_records``.
        subject: Optional subject filter forwarded to ``filter_records``.
        level: Optional level filter forwarded to ``filter_records``.
            For both filters, ``None`` (the default) presumably means "do not
            filter on this field"; confirm against ``filter_records``.

    Returns:
        The result of ``filter_records`` on the loaded records.
    """
    loaded = load_jsonl_records(path)
    return filter_records(loaded, subject=subject, level=level)