File size: 1,568 Bytes
f9306c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
"""High-level workflows for MathVision exploration."""

from __future__ import annotations

from pathlib import Path

from mathvision_explorer.dataset import MathVisionRecord, filter_records, load_jsonl_records
from mathvision_explorer.embeddings import ImageEmbedder, embed_record_image
from mathvision_explorer.index import Neighbor, VectorIndex


def build_image_index(records: list[MathVisionRecord], embedder: ImageEmbedder) -> VectorIndex:
    """Embed each record's image and collect the vectors in a new index.

    Records whose ``image_path`` is ``None`` are skipped. Each remaining record
    is embedded with ``embed_record_image(record.image_path, embedder)``. The
    resulting vector is added to the index under the record's ``problem_id``.
    Records are processed in the order given.

    Returns:
        A newly constructed ``VectorIndex`` holding one ``add`` call per record
        that has an image path.
    """

    vector_index = VectorIndex()
    with_images = (rec for rec in records if rec.image_path is not None)
    for rec in with_images:
        vector_index.add(rec.problem_id, embed_record_image(rec.image_path, embedder))
    return vector_index


def find_similar_records(
    records: list[MathVisionRecord],
    index: VectorIndex,
    query_id: str,
    query_vector: tuple[float, ...],
    *,
    limit: int = 5,
) -> list[tuple[MathVisionRecord, Neighbor]]:
    """Pair the nearest index hits for ``query_vector`` with their records.

    The search is delegated to ``index.search`` with ``limit=limit`` and
    ``exclude_id=query_id``. Each returned neighbor is matched to the record
    whose ``problem_id`` equals the neighbor's ``item_id``, keeping the order the
    index returned. Neighbors with no matching record in ``records`` are
    silently dropped, so the result can be shorter than the search result.
    If several records share a ``problem_id``, the last one in ``records`` is used.

    Returns:
        A list of ``(record, neighbor)`` tuples.
    """

    # Build the id -> record lookup; later duplicates overwrite earlier ones.
    lookup = {}
    for rec in records:
        lookup[rec.problem_id] = rec

    hits = index.search(query_vector, limit=limit, exclude_id=query_id)

    paired = []
    for hit in hits:
        if hit.item_id not in lookup:
            # The index may know ids that are absent from this record list.
            continue
        paired.append((lookup[hit.item_id], hit))
    return paired


def load_filtered_records(
    path: Path, *, subject: str | None = None, level: int | None = None
) -> list[MathVisionRecord]:
    """Read records from ``path`` and narrow them with the explorer filters.

    Records are loaded with ``load_jsonl_records(path)``. ``subject`` and
    ``level`` are forwarded unchanged to ``filter_records``. Both default to
    ``None``, which presumably means "do not filter on this field"; confirm
    against ``filter_records``.

    Returns:
        Whatever ``filter_records`` returns for the loaded records.
    """

    loaded = load_jsonl_records(path)
    narrowed = filter_records(loaded, subject=subject, level=level)
    return narrowed