| """Embedding helpers for image records. |
| |
| The default embedder is intentionally lightweight and deterministic. It gives the |
| project a testable local baseline while leaving room to plug in learned JEPA features. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
| from importlib import import_module |
| from pathlib import Path |
| from typing import Any, Protocol |
|
|
| from PIL import Image, ImageDraw, ImageStat |
|
|
|
|
@dataclass(frozen=True)
class PatchInterestMap:
    """Immutable grid of per-patch interest scores for a single image."""

    # Row-major grid: each inner tuple holds the scores of one row of patches.
    scores: tuple[tuple[float, ...], ...]
    # Size of the image the scores refer to; consumers in this module unpack
    # it as ``(width, height)``.
    image_size: tuple[int, int]

    @property
    def grid_size(self) -> tuple[int, int]:
        """Return the patch grid dimensions as ``(rows, columns)``.

        The column count is read from the first row only. A map without any
        rows reports ``(0, 0)``.
        """
        if not self.scores:
            return (0, 0)
        first_row = self.scores[0]
        return (len(self.scores), len(first_row))
|
|
|
|
class ImageEmbedder(Protocol):
    """Structural interface for objects that map an image file to a feature vector.

    Any object exposing a compatible ``embed_image`` method satisfies this
    protocol; no inheritance is required.
    """

    def embed_image(self, image_path: Path) -> tuple[float, ...]:
        """Return the embedding vector for the image stored at *image_path*."""
        ...
|
|
|
|
class ColorStatsEmbedder:
    """Deterministic embedder built from per-channel RGB colour statistics."""

    def embed_image(self, image_path: Path) -> tuple[float, ...]:
        """Return six colour-statistics features for the image at *image_path*.

        The image is converted to RGB. The result holds the three per-channel
        means followed by the three per-channel standard deviations, each
        divided by 255.
        """
        with Image.open(image_path) as source:
            channel_stats = ImageStat.Stat(source.convert("RGB"))

        scaled_means = [value / 255.0 for value in channel_stats.mean]
        scaled_stddevs = [value / 255.0 for value in channel_stats.stddev]
        return (*scaled_means, *scaled_stddevs)
|
|
|
|
class MissingImageError(RuntimeError):
    """Signal that a record has no image path, so nothing can be embedded."""
|
|
|
|
class JepaDependencyError(RuntimeError):
    """Signal that an optional package needed for JEPA embedding cannot be imported."""
|
|
|
|
class IJepaImageEmbedder:
    """Embed images with a Hugging Face I-JEPA vision encoder.

    Only the ``torch`` and ``transformers`` *imports* are deferred until an
    instance is created, so importing this module does not require the
    optional dependencies. The processor and model themselves are loaded
    eagerly by ``__init__``.
    """

    def __init__(
        self,
        *,
        model_id: str = "facebook/ijepa_vith14_1k",
        device: str | None = None,
    ) -> None:
        """Import the optional dependencies and load the processor and model.

        Args:
            model_id: Identifier passed to ``AutoProcessor.from_pretrained``
                and ``AutoModel.from_pretrained``.
            device: Device the model is moved to. When omitted (or empty),
                ``"cuda"`` is used if ``torch.cuda.is_available()`` and
                ``"cpu"`` otherwise.

        Raises:
            JepaDependencyError: If ``torch`` or ``transformers`` cannot be
                imported.
        """
        self.model_id = model_id
        self._torch = _import_optional("torch")
        transformers = _import_optional("transformers")
        _quiet_transformers_logging(transformers)

        self._processor = transformers.AutoProcessor.from_pretrained(model_id)
        self._model = transformers.AutoModel.from_pretrained(model_id)
        self._device = device or ("cuda" if self._torch.cuda.is_available() else "cpu")
        self._model.to(self._device)
        # Inference only: switch the module to evaluation mode.
        self._model.eval()

    def embed_image(self, image_path: Path) -> tuple[float, ...]:
        """Return a pooled I-JEPA feature vector for an image.

        The model's ``last_hidden_state`` is mean-pooled over its middle
        dimensions, the leading dimension is squeezed, and the result is
        returned as a tuple of Python floats.
        """
        rgb_image, outputs = self._encode_image(image_path)
        # Only the model outputs are needed here; release the pixels right away.
        rgb_image.close()

        pooled = _mean_pool_features(outputs.last_hidden_state)
        flat_values = pooled.squeeze(0).detach().cpu().tolist()
        return tuple(float(value) for value in flat_values)

    def patch_interest_map(self, image_path: Path) -> PatchInterestMap:
        """Return normalized patch-interest scores from I-JEPA token activations.

        Raises:
            RuntimeError: If the token count cannot be arranged as a square
                patch grid.
        """
        rgb_image, outputs = self._encode_image(image_path)
        # Read the size and release the image *before* converting the scores:
        # the conversion may raise, and the image must not stay open then.
        image_size = rgb_image.size
        rgb_image.close()

        scores = _tokens_to_patch_scores(outputs.last_hidden_state, self._torch)
        return PatchInterestMap(scores=scores, image_size=image_size)

    def render_patch_attention_overlay(
        self,
        image_path: Path,
        *,
        alpha: int = 135,
    ) -> Image.Image:
        """Return the source image with its patch-interest heatmap composited on top.

        Args:
            image_path: Image to analyse and to use as the overlay background.
            alpha: Alpha value used for every heatmap cell; forwarded to
                ``render_patch_interest_overlay``.
        """
        interest_map = self.patch_interest_map(image_path)
        return render_patch_interest_overlay(image_path, interest_map, alpha=alpha)

    def _encode_image(self, image_path: Path) -> tuple[Image.Image, Any]:
        """Run the model on one image and return ``(rgb_image, model_outputs)``.

        The caller owns the returned RGB image and is responsible for closing
        it. If preprocessing or the forward pass fails, the image is closed
        here before the exception propagates.
        """
        with Image.open(image_path) as image:
            rgb_image = image.convert("RGB")

        try:
            encoded = self._processor(rgb_image, return_tensors="pt").to(self._model.device)
            # Gradients are never needed for feature extraction.
            with self._torch.no_grad():
                outputs = self._model(**encoded)
        except BaseException:
            # Nobody else holds a reference yet, so release the pixels here.
            rgb_image.close()
            raise
        return rgb_image, outputs
|
|
|
|
def render_patch_interest_heatmap(
    interest_map: PatchInterestMap,
    *,
    alpha: int = 135,
) -> Image.Image:
    """Paint the patch scores as a red-to-yellow RGBA heatmap.

    The canvas has the size stored in ``interest_map.image_size`` and starts
    fully transparent. Each score is drawn as one filled rectangle whose edges
    are the rounded proportional positions of its grid cell; every rectangle
    uses *alpha* as its alpha channel. An empty grid yields the untouched
    transparent canvas.
    """
    width, height = interest_map.image_size
    rows, columns = interest_map.grid_size
    canvas = Image.new("RGBA", (width, height), (0, 0, 0, 0))
    if not (rows and columns):
        return canvas

    painter = ImageDraw.Draw(canvas, "RGBA")
    for row_index, row_scores in enumerate(interest_map.scores):
        # The vertical edges depend on the row only, so compute them once per row.
        top = round(row_index * height / rows)
        bottom = round((row_index + 1) * height / rows)
        for column_index, score in enumerate(row_scores):
            left = round(column_index * width / columns)
            right = round((column_index + 1) * width / columns)
            red, green, blue = _score_to_heat_color(score)
            painter.rectangle((left, top, right, bottom), fill=(red, green, blue, alpha))
    return canvas
|
|
|
|
def render_patch_interest_overlay(
    image_path: Path,
    interest_map: PatchInterestMap,
    *,
    alpha: int = 135,
) -> Image.Image:
    """Composite the patch-interest heatmap over the image at *image_path*.

    The source image is converted to RGBA and alpha-composited with the
    heatmap rendered from *interest_map*; *alpha* is forwarded to the heatmap
    renderer.
    """
    with Image.open(image_path) as source:
        backdrop = source.convert("RGBA")

    heat_layer = render_patch_interest_heatmap(interest_map, alpha=alpha)
    composited = Image.alpha_composite(backdrop, heat_layer)
    return composited
|
|
|
|
def _mean_pool_features(features: Any) -> Any:
    """Average away every dimension between the first and the last.

    Inputs with two or fewer dimensions are returned unchanged (the same
    object). For higher ranks the mean is taken over dimensions
    ``1 .. ndim - 2``, so the first and last dimensions are kept.
    """
    rank = features.ndim
    if rank > 2:
        pooled_axes = tuple(range(1, rank - 1))
        return features.mean(dim=pooled_axes)
    return features
|
|
|
|
def _tokens_to_patch_scores(features: Any, torch: Any) -> tuple[tuple[float, ...], ...]:
    """Turn token features into a min-max normalized square grid of scores.

    Each token is scored by the vector norm of its features along the last
    dimension. Scores are rescaled to ``[0, 1]``; when all scores are equal
    the grid is all zeros. The result is a tuple of rows of Python floats.

    Raises:
        RuntimeError: If the token count is not a perfect square, even after
            dropping the first token.
    """
    tokens = features.squeeze(0).detach()
    token_count = int(tokens.shape[0])

    # If removing the first token leaves a perfect square, treat that token as
    # a non-patch extra (presumably a class token; confirm against the model)
    # and drop it.
    if _is_square(token_count - 1):
        tokens = tokens[1:]
        token_count -= 1
    if not _is_square(token_count):
        msg = f"Cannot infer a square patch grid from {token_count} visual tokens."
        raise RuntimeError(msg)

    magnitudes = torch.linalg.vector_norm(tokens.float(), dim=-1)
    lowest = magnitudes.min()
    spread = magnitudes.max() - lowest
    is_flat = float(spread.detach().cpu()) == 0.0
    # A constant score field has no range to divide by; report all zeros.
    normalized = torch.zeros_like(magnitudes) if is_flat else (magnitudes - lowest) / spread

    side = int(token_count**0.5)
    grid_rows = normalized.reshape(side, side).detach().cpu().tolist()
    return tuple(tuple(float(cell) for cell in grid_row) for grid_row in grid_rows)
|
|
|
|
def _is_square(value: int) -> bool:
    """Return ``True`` when *value* is a positive perfect square.

    Zero and negative numbers are reported as non-square. The check uses the
    exact integer square root, so it stays correct for integers too large to
    be represented exactly as a float.
    """
    # Function-scope import keeps this helper self-contained.
    from math import isqrt

    if value <= 0:
        return False
    root = isqrt(value)
    return root * root == value
|
|
|
|
def _score_to_heat_color(score: float) -> tuple[int, int, int]:
    """Map a score to an ``(R, G, B)`` colour between red and yellow.

    The score is clamped to ``[0, 1]``. Red is always 255; green rises from
    40 to 230 and blue falls from 30 to 0 as the score grows, so 0 gives
    ``(255, 40, 30)`` and 1 gives ``(255, 230, 0)``.
    """
    intensity = max(0.0, min(1.0, score))
    return (
        255,
        int(round(40 + 190 * intensity)),
        int(round(30 * (1.0 - intensity))),
    )
|
|
|
|
def embed_record_image(image_path: Path | None, embedder: ImageEmbedder) -> tuple[float, ...]:
    """Embed the image of a record using *embedder*.

    Args:
        image_path: Location of the record's image, or ``None`` when the
            record has none.
        embedder: Object whose ``embed_image`` method produces the vector.

    Returns:
        Whatever ``embedder.embed_image(image_path)`` returns.

    Raises:
        MissingImageError: If *image_path* is ``None``.
    """
    if image_path is not None:
        return embedder.embed_image(image_path)
    raise MissingImageError("Record has no image path to embed.")
|
|
|
|
def _import_optional(module_name: str) -> Any:
    """Import *module_name* and return the module object.

    Raises:
        JepaDependencyError: If the import fails with ``ImportError``; the
            original error is chained as the cause and the message contains
            the install command.
    """
    try:
        module = import_module(module_name)
    except ImportError as error:
        raise JepaDependencyError(
            "I-JEPA dependencies are missing. Install them with "
            "`uv sync --extra ijepa --dev` (or `make sync-ijepa`)."
        ) from error
    return module
|
|
|
|
def _quiet_transformers_logging(transformers: Any) -> None:
    """Best-effort call to ``transformers.logging.set_verbosity_error()``.

    Any ``AttributeError`` raised while looking up or calling that function
    is swallowed, so a module object without this logging API is left alone.
    """
    try:
        silence = transformers.logging.set_verbosity_error
        silence()
    except AttributeError:
        # Deliberately ignored: quieter logging is optional, not required.
        pass
|
|