| """Tests for local image embedding helpers.""" |
|
|
| from __future__ import annotations |
|
|
| from pathlib import Path |
| from typing import cast |
|
|
| import pytest |
| from PIL import Image |
|
|
| from mathvision_explorer.embeddings import ( |
| ColorStatsEmbedder, |
| IJepaImageEmbedder, |
| JepaDependencyError, |
| MissingImageError, |
| PatchInterestMap, |
| _mean_pool_features, |
| embed_record_image, |
| render_patch_interest_heatmap, |
| render_patch_interest_overlay, |
| ) |
|
|
|
|
def test_color_stats_embedder_returns_six_normalized_features(tmp_path: Path) -> None:
    """The baseline image embedder returns normalized RGB means and standard deviations.

    A solid red image must produce six features where only the red-channel
    mean is 1.0 and every other value is 0.0.
    """
    image_path = tmp_path / "red.png"
    Image.new("RGB", (4, 4), color=(255, 0, 0)).save(image_path)

    vector = ColorStatsEmbedder().embed_image(image_path)

    # The statistics come out of floating-point arithmetic, so compare with a
    # tolerance instead of exact equality. Otherwise harmless rounding (e.g. a
    # standard deviation of 1e-17 instead of 0.0) would make the test flaky.
    # pytest.approx also checks that there are exactly six elements.
    assert vector == pytest.approx((1.0, 0.0, 0.0, 0.0, 0.0, 0.0))
|
|
|
|
def test_embed_record_image_requires_path() -> None:
    """Passing ``None`` as the image path raises ``MissingImageError``."""
    embedder = ColorStatsEmbedder()

    with pytest.raises(MissingImageError):
        embed_record_image(None, embedder)
|
|
|
|
def test_ijepa_embedder_reports_missing_optional_dependencies() -> None:
    """Constructing the optional I-JEPA embedder without its extras gives install guidance.

    The test is skipped when the dependencies are present, whether or not the
    model files can be loaded.
    """
    caught: JepaDependencyError | None = None
    try:
        IJepaImageEmbedder(model_id="unused")
    except JepaDependencyError as error:
        caught = error
    except OSError:
        pytest.skip("I-JEPA dependencies are installed but model files are unavailable.")

    if caught is None:
        # Construction succeeded, so there is no missing-dependency error to inspect.
        pytest.skip("I-JEPA dependencies and model files are available in this environment.")

    # The message must point users at the install target for the optional extras.
    assert "sync-ijepa" in str(caught)
|
|
|
|
def test_mean_pool_features_preserves_last_dimension() -> None:
    """For a 4-D input, pooling averages axes 1 and 2 and keeps the last (hidden) axis."""
    features = _FakeFeatures(ndim=4)

    result = _mean_pool_features(features)

    # The fake's ``mean`` returns the fake itself, so identity plus the
    # recorded axes show the helper pooled via ``mean(dim=(1, 2))``.
    assert result is features
    assert features.mean_dim == (1, 2)
|
|
|
|
def test_patch_interest_heatmap_matches_image_size() -> None:
    """A 2x2 grid of patch scores renders to an RGBA heatmap at the map's image size."""
    grid = ((0.0, 1.0), (0.5, 0.75))
    interest_map = PatchInterestMap(scores=grid, image_size=(12, 8))

    heatmap = render_patch_interest_heatmap(interest_map)

    assert heatmap.mode == "RGBA"
    assert heatmap.size == (12, 8)
    rgba = cast(tuple[int, int, int, int], heatmap.getpixel((1, 1)))
    _red, _green, _blue, alpha = rgba
    # The pixel at (1, 1) must not be fully transparent.
    assert alpha > 0
|
|
|
|
def test_patch_interest_overlay_preserves_source_size(tmp_path: Path) -> None:
    """Overlaying a single-patch map keeps the source resolution and alters the pixels.

    The result is an RGBA image, a form that can be handed straight to Streamlit.
    """
    fill = (20, 30, 40)
    image_path = tmp_path / "source.png"
    Image.new("RGB", (10, 10), color=fill).save(image_path)
    single_patch_map = PatchInterestMap(scores=((1.0,),), image_size=(10, 10))

    overlay = render_patch_interest_overlay(image_path, single_patch_map)

    assert overlay.mode == "RGBA"
    assert overlay.size == (10, 10)
    # The centre pixel must differ from the untouched, fully opaque source colour.
    assert overlay.getpixel((5, 5)) != (*fill, 255)
|
|
|
|
| class _FakeFeatures: |
| """Tiny tensor-like object for testing pooling dimensions without torch.""" |
|
|
| def __init__(self, *, ndim: int) -> None: |
| self.ndim = ndim |
| self.mean_dim: tuple[int, ...] | None = None |
|
|
| def mean(self, *, dim: tuple[int, ...]) -> _FakeFeatures: |
| self.mean_dim = dim |
| return self |
|
|