"""Embedding helpers for image records. The default embedder is intentionally lightweight and deterministic. It gives the project a testable local baseline while leaving room to plug in learned JEPA features. """ from __future__ import annotations from dataclasses import dataclass from importlib import import_module from pathlib import Path from typing import Any, Protocol from PIL import Image, ImageDraw, ImageStat @dataclass(frozen=True) class PatchInterestMap: """Patch-level model interest scores for an image.""" scores: tuple[tuple[float, ...], ...] image_size: tuple[int, int] @property def grid_size(self) -> tuple[int, int]: """Return ``(rows, columns)`` for the patch grid.""" return (len(self.scores), len(self.scores[0]) if self.scores else 0) class ImageEmbedder(Protocol): """Protocol for objects that turn image paths into numeric vectors.""" def embed_image(self, image_path: Path) -> tuple[float, ...]: """Return an embedding vector for an image file.""" class ColorStatsEmbedder: """Embed images with normalized RGB mean and standard deviation features.""" def embed_image(self, image_path: Path) -> tuple[float, ...]: """Return six normalized color-statistics features for an image.""" with Image.open(image_path) as image: rgb_image = image.convert("RGB") stat = ImageStat.Stat(rgb_image) means = tuple(channel / 255.0 for channel in stat.mean) stddevs = tuple(channel / 255.0 for channel in stat.stddev) return means + stddevs class MissingImageError(RuntimeError): """Raised when a record cannot be embedded because no image path is available.""" class JepaDependencyError(RuntimeError): """Raised when optional JEPA dependencies are not installed.""" class IJepaImageEmbedder: """Embed images with a Hugging Face I-JEPA vision encoder.""" def __init__( self, *, model_id: str = "facebook/ijepa_vith14_1k", device: str | None = None, ) -> None: """Load the I-JEPA processor and model lazily at embedder construction time.""" self.model_id = model_id self._torch = _import_optional("torch") transformers = _import_optional("transformers") _quiet_transformers_logging(transformers) self._processor = transformers.AutoProcessor.from_pretrained(model_id) self._model = transformers.AutoModel.from_pretrained(model_id) self._device = device or ("cuda" if self._torch.cuda.is_available() else "cpu") self._model.to(self._device) self._model.eval() def embed_image(self, image_path: Path) -> tuple[float, ...]: """Return a pooled I-JEPA feature vector for an image.""" rgb_image, outputs = self._encode_image(image_path) rgb_image.close() pooled = _mean_pool_features(outputs.last_hidden_state) return tuple(float(value) for value in pooled.squeeze(0).detach().cpu().tolist()) def patch_interest_map(self, image_path: Path) -> PatchInterestMap: """Return normalized patch-interest scores from I-JEPA token activations.""" rgb_image, outputs = self._encode_image(image_path) scores = _tokens_to_patch_scores(outputs.last_hidden_state, self._torch) image_size = rgb_image.size rgb_image.close() return PatchInterestMap(scores=scores, image_size=image_size) def render_patch_attention_overlay( self, image_path: Path, *, alpha: int = 135, ) -> Image.Image: """Render a heatmap overlay for the patches with strongest activations.""" interest_map = self.patch_interest_map(image_path) return render_patch_interest_overlay(image_path, interest_map, alpha=alpha) def _encode_image(self, image_path: Path) -> tuple[Image.Image, Any]: with Image.open(image_path) as image: rgb_image = image.convert("RGB") encoded = self._processor(rgb_image, return_tensors="pt").to(self._model.device) with self._torch.no_grad(): outputs = self._model(**encoded) return rgb_image, outputs def render_patch_interest_heatmap( interest_map: PatchInterestMap, *, alpha: int = 135, ) -> Image.Image: """Render patch scores as a transparent red/yellow heatmap.""" width, height = interest_map.image_size rows, columns = interest_map.grid_size heatmap = Image.new("RGBA", (width, height), (0, 0, 0, 0)) if rows == 0 or columns == 0: return heatmap draw = ImageDraw.Draw(heatmap, "RGBA") for row_index, row in enumerate(interest_map.scores): for column_index, score in enumerate(row): x0 = round(column_index * width / columns) x1 = round((column_index + 1) * width / columns) y0 = round(row_index * height / rows) y1 = round((row_index + 1) * height / rows) draw.rectangle((x0, y0, x1, y1), fill=(*_score_to_heat_color(score), alpha)) return heatmap def render_patch_interest_overlay( image_path: Path, interest_map: PatchInterestMap, *, alpha: int = 135, ) -> Image.Image: """Overlay a patch-interest heatmap on top of the source image.""" with Image.open(image_path) as image: base = image.convert("RGBA") heatmap = render_patch_interest_heatmap(interest_map, alpha=alpha) return Image.alpha_composite(base, heatmap) def _mean_pool_features(features: Any) -> Any: """Pool token/time dimensions while preserving the final feature dimension.""" if features.ndim <= 2: return features return features.mean(dim=tuple(range(1, features.ndim - 1))) def _tokens_to_patch_scores(features: Any, torch: Any) -> tuple[tuple[float, ...], ...]: """Convert model token features to a normalized square patch-score grid.""" token_features = features.squeeze(0).detach() token_count = int(token_features.shape[0]) if _is_square(token_count - 1): token_features = token_features[1:] token_count -= 1 if not _is_square(token_count): msg = f"Cannot infer a square patch grid from {token_count} visual tokens." raise RuntimeError(msg) scores = torch.linalg.vector_norm(token_features.float(), dim=-1) min_score = scores.min() score_range = scores.max() - min_score if float(score_range.detach().cpu()) == 0.0: normalized = torch.zeros_like(scores) else: normalized = (scores - min_score) / score_range grid_width = int(token_count**0.5) values = normalized.reshape(grid_width, grid_width).detach().cpu().tolist() return tuple(tuple(float(value) for value in row) for row in values) def _is_square(value: int) -> bool: if value <= 0: return False root = int(value**0.5) return root * root == value def _score_to_heat_color(score: float) -> tuple[int, int, int]: clamped = max(0.0, min(1.0, score)) red = 255 green = int(round(40 + 190 * clamped)) blue = int(round(30 * (1.0 - clamped))) return (red, green, blue) def embed_record_image(image_path: Path | None, embedder: ImageEmbedder) -> tuple[float, ...]: """Embed a record image or raise a clear error when the path is missing.""" if image_path is None: raise MissingImageError("Record has no image path to embed.") return embedder.embed_image(image_path) def _import_optional(module_name: str) -> Any: try: return import_module(module_name) except ImportError as error: msg = ( "I-JEPA dependencies are missing. Install them with " "`uv sync --extra ijepa --dev` (or `make sync-ijepa`)." ) raise JepaDependencyError(msg) from error def _quiet_transformers_logging(transformers: Any) -> None: """Reduce noisy dev-version Transformers compatibility logging.""" try: transformers.logging.set_verbosity_error() except AttributeError: return