"""Tests for local image embedding helpers.""" from __future__ import annotations from pathlib import Path from typing import cast import pytest from PIL import Image from mathvision_explorer.embeddings import ( ColorStatsEmbedder, IJepaImageEmbedder, JepaDependencyError, MissingImageError, PatchInterestMap, _mean_pool_features, embed_record_image, render_patch_interest_heatmap, render_patch_interest_overlay, ) def test_color_stats_embedder_returns_six_normalized_features(tmp_path: Path) -> None: """The baseline image embedder returns RGB means and standard deviations.""" image_path = tmp_path / "red.png" Image.new("RGB", (4, 4), color=(255, 0, 0)).save(image_path) vector = ColorStatsEmbedder().embed_image(image_path) assert vector == (1.0, 0.0, 0.0, 0.0, 0.0, 0.0) def test_embed_record_image_requires_path() -> None: """Embedding a record without an image path fails loudly.""" with pytest.raises(MissingImageError): embed_record_image(None, ColorStatsEmbedder()) def test_ijepa_embedder_reports_missing_optional_dependencies() -> None: """I-JEPA is optional and explains how to install its dependencies.""" try: IJepaImageEmbedder(model_id="unused") except JepaDependencyError as error: assert "sync-ijepa" in str(error) except OSError: pytest.skip("I-JEPA dependencies are installed but model files are unavailable.") else: pytest.skip("I-JEPA dependencies and model files are available in this environment.") def test_mean_pool_features_preserves_last_dimension() -> None: """Pooling should average token axes but keep hidden features.""" fake_features = _FakeFeatures(ndim=4) pooled = _mean_pool_features(fake_features) assert pooled is fake_features assert fake_features.mean_dim == (1, 2) def test_patch_interest_heatmap_matches_image_size() -> None: """Patch interest scores render into a transparent overlay at image resolution.""" interest_map = PatchInterestMap( scores=((0.0, 1.0), (0.5, 0.75)), image_size=(12, 8), ) heatmap = render_patch_interest_heatmap(interest_map) assert heatmap.mode == "RGBA" assert heatmap.size == (12, 8) pixel = cast(tuple[int, int, int, int], heatmap.getpixel((1, 1))) assert pixel[3] > 0 def test_patch_interest_overlay_preserves_source_size(tmp_path: Path) -> None: """The rendered overlay can be shown directly by Streamlit.""" image_path = tmp_path / "source.png" Image.new("RGB", (10, 10), color=(20, 30, 40)).save(image_path) interest_map = PatchInterestMap(scores=((1.0,),), image_size=(10, 10)) overlay = render_patch_interest_overlay(image_path, interest_map) assert overlay.mode == "RGBA" assert overlay.size == (10, 10) assert overlay.getpixel((5, 5)) != (20, 30, 40, 255) class _FakeFeatures: """Tiny tensor-like object for testing pooling dimensions without torch.""" def __init__(self, *, ndim: int) -> None: self.ndim = ndim self.mean_dim: tuple[int, ...] | None = None def mean(self, *, dim: tuple[int, ...]) -> _FakeFeatures: self.mean_dim = dim return self