mathvision-jepa-explorer / tests /test_embeddings.py
ddebree's picture
Visualize
51bdf55
"""Tests for local image embedding helpers."""
from __future__ import annotations
from pathlib import Path
from typing import cast
import pytest
from PIL import Image
from mathvision_explorer.embeddings import (
ColorStatsEmbedder,
IJepaImageEmbedder,
JepaDependencyError,
MissingImageError,
PatchInterestMap,
_mean_pool_features,
embed_record_image,
render_patch_interest_heatmap,
render_patch_interest_overlay,
)
def test_color_stats_embedder_returns_six_normalized_features(tmp_path: Path) -> None:
    """A solid red image embeds to a six-value vector of RGB means and deviations.

    Every pixel is pure red, so the expected vector is 1.0 followed by five
    zeros.
    """
    red_png = tmp_path / "red.png"
    solid_red = Image.new("RGB", (4, 4), color=(255, 0, 0))
    solid_red.save(red_png)

    embedder = ColorStatsEmbedder()
    expected = (1.0, 0.0, 0.0, 0.0, 0.0, 0.0)

    assert embedder.embed_image(red_png) == expected
def test_embed_record_image_requires_path() -> None:
    """Passing ``None`` as the image path must raise ``MissingImageError``."""
    with pytest.raises(MissingImageError):
        # The embedder is built inside the guarded block, exactly where the
        # call arguments are evaluated.
        baseline_embedder = ColorStatsEmbedder()
        embed_record_image(None, baseline_embedder)
def test_ijepa_embedder_reports_missing_optional_dependencies() -> None:
    """Check the install hint raised when the optional I-JEPA stack is absent.

    The assertion only runs when constructing the embedder raises
    ``JepaDependencyError``. If construction raises ``OSError`` or succeeds,
    the test is skipped because the missing-dependency path cannot be
    exercised in that environment.
    """
    try:
        IJepaImageEmbedder(model_id="unused")
    except JepaDependencyError as error:
        message = str(error)
        assert "sync-ijepa" in message
        return
    except OSError:
        pytest.skip("I-JEPA dependencies are installed but model files are unavailable.")

    # Reached only when construction succeeded: pytest.skip raises, so the
    # OSError branch never falls through to here.
    pytest.skip("I-JEPA dependencies and model files are available in this environment.")
def test_mean_pool_features_preserves_last_dimension() -> None:
    """Pooling a rank-4 input must call ``mean`` over axes 1 and 2 only.

    A ``_FakeFeatures`` stub stands in for a tensor: its ``mean`` records the
    requested ``dim`` and returns the stub itself.
    """
    stub = _FakeFeatures(ndim=4)

    result = _mean_pool_features(stub)

    # The stub's mean() returns itself, so identity shows the helper handed
    # back what mean() produced.
    assert result is stub
    assert stub.mean_dim == (1, 2)
def test_patch_interest_heatmap_matches_image_size() -> None:
    """A 2x2 score grid renders to an RGBA heatmap sized like the target image."""
    grid = ((0.0, 1.0), (0.5, 0.75))
    target_size = (12, 8)

    rendered = render_patch_interest_heatmap(
        PatchInterestMap(scores=grid, image_size=target_size),
    )

    assert rendered.mode == "RGBA"
    assert rendered.size == target_size

    # The sampled pixel must carry some opacity (alpha is the fourth channel).
    sampled = cast(tuple[int, int, int, int], rendered.getpixel((1, 1)))
    alpha = sampled[3]
    assert alpha > 0
def test_patch_interest_overlay_preserves_source_size(tmp_path: Path) -> None:
    """The overlay is RGBA, keeps the source size, and changes the centre pixel."""
    base_rgb = (20, 30, 40)
    source_png = tmp_path / "source.png"
    Image.new("RGB", (10, 10), color=base_rgb).save(source_png)

    full_interest = PatchInterestMap(scores=((1.0,),), image_size=(10, 10))
    composite = render_patch_interest_overlay(source_png, full_interest)

    assert composite.mode == "RGBA"
    assert composite.size == (10, 10)

    # An unchanged pixel would be the source colour at full opacity.
    untouched = (*base_rgb, 255)
    assert composite.getpixel((5, 5)) != untouched
class _FakeFeatures:
"""Tiny tensor-like object for testing pooling dimensions without torch."""
def __init__(self, *, ndim: int) -> None:
self.ndim = ndim
self.mean_dim: tuple[int, ...] | None = None
def mean(self, *, dim: tuple[int, ...]) -> _FakeFeatures:
self.mean_dim = dim
return self