File size: 3,230 Bytes
3e67073
 
 
 
 
51bdf55
3e67073
 
 
 
 
 
 
 
 
51bdf55
3e67073
 
51bdf55
 
3e67073
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51bdf55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e67073
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""Tests for local image embedding helpers."""

from __future__ import annotations

from pathlib import Path
from typing import cast

import pytest
from PIL import Image

from mathvision_explorer.embeddings import (
    ColorStatsEmbedder,
    IJepaImageEmbedder,
    JepaDependencyError,
    MissingImageError,
    PatchInterestMap,
    _mean_pool_features,
    embed_record_image,
    render_patch_interest_heatmap,
    render_patch_interest_overlay,
)


def test_color_stats_embedder_returns_six_normalized_features(tmp_path: Path) -> None:
    """A solid red image embeds to six RGB mean / standard-deviation features.

    For a uniform (255, 0, 0) image, the red mean normalizes to 1.0. Every other
    statistic (the green and blue means and all standard deviations) is 0.0.
    """

    solid_red = Image.new("RGB", (4, 4), color=(255, 0, 0))
    red_file = tmp_path / "red.png"
    solid_red.save(red_file)

    expected = (1.0, 0.0, 0.0, 0.0, 0.0, 0.0)
    assert ColorStatsEmbedder().embed_image(red_file) == expected


def test_embed_record_image_requires_path() -> None:
    """A record without an image path must raise ``MissingImageError``."""

    with pytest.raises(MissingImageError):
        embedder = ColorStatsEmbedder()
        embed_record_image(None, embedder)


def test_ijepa_embedder_reports_missing_optional_dependencies() -> None:
    """Building the optional I-JEPA embedder either explains the install or is skipped.

    Outcomes of constructing the embedder:

    * ``JepaDependencyError``: the message must mention ``sync-ijepa``.
    * ``OSError``: dependencies are installed but model files are missing, so skip.
    * no exception: dependencies and model files are present, so skip.
    """

    try:
        IJepaImageEmbedder(model_id="unused")
    except JepaDependencyError as error:
        message = str(error)
        assert "sync-ijepa" in message
        return
    except OSError:
        pytest.skip("I-JEPA dependencies are installed but model files are unavailable.")

    # Reached only when construction succeeded without raising.
    pytest.skip("I-JEPA dependencies and model files are available in this environment.")


def test_mean_pool_features_preserves_last_dimension() -> None:
    """For 4-D features, pooling averages axes (1, 2) and keeps the trailing feature axis."""

    features = _FakeFeatures(ndim=4)

    result = _mean_pool_features(features)

    # ``_FakeFeatures.mean`` hands back the same object, so a correct pooling
    # call returns the fake itself with the reduced axes recorded on it.
    assert result is features
    assert features.mean_dim == (1, 2)


def test_patch_interest_heatmap_matches_image_size() -> None:
    """Rendering a grid of patch scores yields an RGBA overlay at the image's size."""

    grid = ((0.0, 1.0), (0.5, 0.75))
    interest_map = PatchInterestMap(scores=grid, image_size=(12, 8))

    heatmap = render_patch_interest_heatmap(interest_map)

    assert heatmap.mode == "RGBA"
    assert heatmap.size == (12, 8)
    # NOTE(review): if scores are indexed [row][col] and patches map linearly onto
    # the 12x8 image, (1, 1) lies in the patch scored 0.0. This assertion then
    # relies on the renderer giving zero-score patches nonzero alpha (or on
    # interpolation). Confirm this is the intended check.
    _, _, _, alpha = cast(tuple[int, int, int, int], heatmap.getpixel((1, 1)))
    assert alpha > 0


def test_patch_interest_overlay_preserves_source_size(tmp_path: Path) -> None:
    """Overlaying interest onto a source image keeps its size and visibly tints it.

    The result is RGBA, so Streamlit can display it directly.
    """

    source_color = (20, 30, 40)
    source_path = tmp_path / "source.png"
    Image.new("RGB", (10, 10), color=source_color).save(source_path)
    single_patch_map = PatchInterestMap(scores=((1.0,),), image_size=(10, 10))

    overlay = render_patch_interest_overlay(source_path, single_patch_map)

    assert overlay.mode == "RGBA"
    assert overlay.size == (10, 10)
    # A full-interest patch must change the centre pixel from the opaque source color.
    assert overlay.getpixel((5, 5)) != (*source_color, 255)


class _FakeFeatures:
    """Tiny tensor-like object for testing pooling dimensions without torch."""

    def __init__(self, *, ndim: int) -> None:
        self.ndim = ndim
        self.mean_dim: tuple[int, ...] | None = None

    def mean(self, *, dim: tuple[int, ...]) -> _FakeFeatures:
        self.mean_dim = dim
        return self