File size: 2,191 Bytes
3e67073
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""Tests for high-level explorer workflows."""

from __future__ import annotations

from pathlib import Path

from PIL import Image

from mathvision_explorer.dataset import MathVisionRecord
from mathvision_explorer.embeddings import ColorStatsEmbedder
from mathvision_explorer.explorer import build_image_index, find_similar_records


def test_build_image_index_skips_records_without_images(tmp_path: Path) -> None:
    """A record created without an image path does not contribute an index entry."""

    # (problem id, file name, solid fill colour) for each record that has an image.
    swatches = [
        ("red", "red.png", (255, 0, 0)),
        ("blue", "blue.png", (0, 0, 255)),
    ]
    records = []
    for problem_id, filename, rgb in swatches:
        image_file = tmp_path / filename
        Image.new("RGB", (3, 3), color=rgb).save(image_file)
        records.append(
            MathVisionRecord(
                problem_id=problem_id,
                question="Q",
                answer="A",
                image_path=image_file,
            )
        )
    # The final record is built without image_path on purpose.
    records.append(MathVisionRecord(problem_id="missing", question="Q", answer="A"))

    index = build_image_index(records, ColorStatsEmbedder())

    # Three records went in; only the two with images are expected in the index.
    assert len(index) == 2


def test_find_similar_records_returns_record_metadata(tmp_path: Path) -> None:
    """The top match carries the source record, so it can be identified by problem id."""

    # (problem id, question text, file name, solid fill colour) per record.
    fixtures = [
        ("red", "Red", "red.png", (255, 0, 0)),
        ("near-red", "Near red", "near-red.png", (240, 10, 10)),
        ("blue", "Blue", "blue.png", (0, 0, 255)),
    ]
    image_files = {}
    records = []
    for problem_id, question, filename, rgb in fixtures:
        image_file = tmp_path / filename
        Image.new("RGB", (3, 3), color=rgb).save(image_file)
        image_files[problem_id] = image_file
        records.append(
            MathVisionRecord(
                problem_id=problem_id,
                question=question,
                answer="A",
                image_path=image_file,
            )
        )

    embedder = ColorStatsEmbedder()
    index = build_image_index(records, embedder)
    query_embedding = embedder.embed_image(image_files["red"])

    # NOTE(review): the "red" argument is presumably the query record's own id so that
    # it is not returned as its own neighbour — confirm against find_similar_records.
    matches = find_similar_records(records, index, "red", query_embedding, limit=1)

    # Each match is indexed positionally; element 0 is the dataset record.
    best_record = matches[0][0]
    assert best_record.problem_id == "near-red"