ddebree's picture
Visualize
51bdf55
"""Embedding helpers for image records.
The default embedder is intentionally lightweight and deterministic. It gives the
project a testable local baseline while leaving room to plug in learned JEPA features.
"""
from __future__ import annotations

import math
from dataclasses import dataclass
from importlib import import_module
from pathlib import Path
from typing import Any, Protocol

from PIL import Image, ImageDraw, ImageStat
@dataclass(frozen=True)
class PatchInterestMap:
    """Immutable grid of per-patch interest scores for a single image.

    Attributes:
        scores: Row-major grid of scores, one inner tuple per patch row.
        image_size: Size of the image the scores belong to, as stored by the
            producer of the map.
    """

    scores: tuple[tuple[float, ...], ...]
    image_size: tuple[int, int]

    @property
    def grid_size(self) -> tuple[int, int]:
        """Return ``(rows, columns)``; the column count is taken from the first row."""
        row_count = len(self.scores)
        if row_count == 0:
            # An empty grid has no first row to measure.
            return (0, 0)
        return (row_count, len(self.scores[0]))
class ImageEmbedder(Protocol):
    """Structural interface for anything that maps an image file to a vector."""

    def embed_image(self, image_path: Path) -> tuple[float, ...]:
        """Compute and return the embedding vector for the image at ``image_path``."""
        ...
class ColorStatsEmbedder:
    """Embedder built from per-channel RGB mean and standard deviation."""

    def embed_image(self, image_path: Path) -> tuple[float, ...]:
        """Return six features: the RGB means followed by the RGB standard deviations.

        Every statistic is divided by 255.0 after the image is converted to RGB.
        """
        with Image.open(image_path) as source:
            channel_stats = ImageStat.Stat(source.convert("RGB"))
            raw_features = (*channel_stats.mean, *channel_stats.stddev)
        return tuple(value / 255.0 for value in raw_features)
class MissingImageError(RuntimeError):
    """Signals that a record has no image path, so nothing can be embedded."""
class JepaDependencyError(RuntimeError):
    """Signals that the optional JEPA dependencies could not be imported."""
class IJepaImageEmbedder:
    """Embed images with a Hugging Face I-JEPA vision encoder.

    The processor and model are loaded with ``from_pretrained`` while the
    embedder is being constructed, not on first use.
    """

    def __init__(
        self,
        *,
        model_id: str = "facebook/ijepa_vith14_1k",
        device: str | None = None,
    ) -> None:
        """Load the I-JEPA processor and model when the embedder is constructed.

        Args:
            model_id: Identifier passed to ``AutoProcessor.from_pretrained`` and
                ``AutoModel.from_pretrained``.
            device: Device the model is moved to. When omitted (or empty),
                ``"cuda"`` is used if ``torch.cuda.is_available()`` reports
                true and ``"cpu"`` otherwise.

        Raises:
            JepaDependencyError: If ``torch`` or ``transformers`` cannot be
                imported.
        """
        self.model_id = model_id
        self._torch = _import_optional("torch")
        transformers = _import_optional("transformers")
        _quiet_transformers_logging(transformers)
        self._processor = transformers.AutoProcessor.from_pretrained(model_id)
        self._model = transformers.AutoModel.from_pretrained(model_id)
        self._device = device or ("cuda" if self._torch.cuda.is_available() else "cpu")
        self._model.to(self._device)
        self._model.eval()

    def embed_image(self, image_path: Path) -> tuple[float, ...]:
        """Return a pooled I-JEPA feature vector for an image."""
        rgb_image, outputs = self._encode_image(image_path)
        # Only the model outputs are needed from here on.
        rgb_image.close()
        pooled = _mean_pool_features(outputs.last_hidden_state)
        return tuple(float(value) for value in pooled.squeeze(0).detach().cpu().tolist())

    def patch_interest_map(self, image_path: Path) -> PatchInterestMap:
        """Return normalized patch-interest scores from I-JEPA token activations.

        Raises:
            RuntimeError: If the token count cannot be arranged as a square
                patch grid.
        """
        rgb_image, outputs = self._encode_image(image_path)
        try:
            image_size = rgb_image.size
            scores = _tokens_to_patch_scores(outputs.last_hidden_state, self._torch)
        finally:
            # Release the converted image even when score extraction raises.
            rgb_image.close()
        return PatchInterestMap(scores=scores, image_size=image_size)

    def render_patch_attention_overlay(
        self,
        image_path: Path,
        *,
        alpha: int = 135,
    ) -> Image.Image:
        """Render a heatmap overlay for the patches with strongest activations.

        Args:
            image_path: Image to score and to use as the overlay background.
            alpha: Alpha value forwarded to the heatmap renderer.
        """
        interest_map = self.patch_interest_map(image_path)
        return render_patch_interest_overlay(image_path, interest_map, alpha=alpha)

    def _encode_image(self, image_path: Path) -> tuple[Image.Image, Any]:
        """Run the model on an image and return ``(rgb_image, model_outputs)``.

        The caller owns the returned RGB image and is responsible for closing
        it. If preprocessing or the forward pass fails, the image is closed
        here before the exception propagates.
        """
        with Image.open(image_path) as image:
            rgb_image = image.convert("RGB")
        try:
            encoded = self._processor(rgb_image, return_tensors="pt").to(self._model.device)
            # Inference only: skip gradient tracking.
            with self._torch.no_grad():
                outputs = self._model(**encoded)
        except BaseException:
            # No caller will ever receive the image, so release it now.
            rgb_image.close()
            raise
        return rgb_image, outputs
def render_patch_interest_heatmap(
    interest_map: PatchInterestMap,
    *,
    alpha: int = 135,
) -> Image.Image:
    """Paint every patch score as a translucent red-to-yellow cell.

    Args:
        interest_map: Patch scores plus the pixel size of the canvas to create.
        alpha: Alpha component used for every painted cell.

    Returns:
        An RGBA image of ``interest_map.image_size``; it stays fully
        transparent when the score grid has no rows or no columns.
    """
    canvas_width, canvas_height = interest_map.image_size
    row_count, column_count = interest_map.grid_size
    canvas = Image.new("RGBA", (canvas_width, canvas_height), (0, 0, 0, 0))
    if 0 in (row_count, column_count):
        return canvas

    painter = ImageDraw.Draw(canvas, "RGBA")
    for r, row_scores in enumerate(interest_map.scores):
        # Vertical bounds are shared by every cell in this row.
        top = round(r * canvas_height / row_count)
        bottom = round((r + 1) * canvas_height / row_count)
        for c, score in enumerate(row_scores):
            left = round(c * canvas_width / column_count)
            right = round((c + 1) * canvas_width / column_count)
            cell_fill = _score_to_heat_color(score) + (alpha,)
            painter.rectangle((left, top, right, bottom), fill=cell_fill)
    return canvas
def render_patch_interest_overlay(
    image_path: Path,
    interest_map: PatchInterestMap,
    *,
    alpha: int = 135,
) -> Image.Image:
    """Composite the patch-interest heatmap over the image stored at ``image_path``.

    Args:
        image_path: Source image; it is converted to RGBA to act as the background.
        interest_map: Scores to render; the heatmap is sized from
            ``interest_map.image_size``.
        alpha: Alpha value forwarded to the heatmap renderer.
    """
    with Image.open(image_path) as source:
        background = source.convert("RGBA")
    overlay_layer = render_patch_interest_heatmap(interest_map, alpha=alpha)
    # NOTE(review): alpha_composite is expected to need equally sized layers —
    # confirm that interest_map was produced from this same image.
    return Image.alpha_composite(background, overlay_layer)
def _mean_pool_features(features: Any) -> Any:
    """Average over every axis between the first and the last one.

    Inputs with two or fewer dimensions have no middle axes and are returned
    unchanged.
    """
    rank = features.ndim
    if rank > 2:
        middle_axes = tuple(range(1, rank - 1))
        return features.mean(dim=middle_axes)
    return features
def _tokens_to_patch_scores(features: Any, torch: Any) -> tuple[tuple[float, ...], ...]:
    """Turn token features into a square grid of min-max normalized norms.

    Args:
        features: Token features; the leading axis is squeezed away when it
            has size one (assumes a single-image batch — TODO confirm).
        torch: The imported ``torch`` module, passed in because it is an
            optional dependency.

    Returns:
        Row-major nested tuples of scores; all zeros when every token has the
        same norm.

    Raises:
        RuntimeError: If the token count is not a perfect square.
    """
    tokens = features.squeeze(0).detach()
    count = int(tokens.shape[0])

    # Drop the first token when that leaves a square number of tokens
    # (presumably a non-patch summary token — verify against the model).
    if _is_square(count - 1):
        tokens, count = tokens[1:], count - 1
    if not _is_square(count):
        msg = f"Cannot infer a square patch grid from {count} visual tokens."
        raise RuntimeError(msg)

    magnitudes = torch.linalg.vector_norm(tokens.float(), dim=-1)
    lowest = magnitudes.min()
    spread = magnitudes.max() - lowest
    if float(spread.detach().cpu()) != 0.0:
        scaled = (magnitudes - lowest) / spread
    else:
        # Every token has the same norm; avoid dividing by zero.
        scaled = torch.zeros_like(magnitudes)

    side = int(count**0.5)
    grid = scaled.reshape(side, side).detach().cpu().tolist()
    return tuple(tuple(float(cell) for cell in grid_row) for grid_row in grid)
def _is_square(value: int) -> bool:
    """Return whether ``value`` is a positive perfect square.

    Zero and negative values are reported as not square.
    """
    if value <= 0:
        return False
    # math.isqrt is exact for integers of any size, whereas int(value**0.5)
    # goes through float and can round wrongly or overflow for large values.
    root = math.isqrt(value)
    return root * root == value
def _score_to_heat_color(score: float) -> tuple[int, int, int]:
    """Map a score to an RGB color running from red (0.0) to yellow (1.0).

    Scores outside ``[0.0, 1.0]`` are clamped to that range first.
    """
    level = min(1.0, score)
    level = max(0.0, level)
    return (
        255,
        int(round(40 + 190 * level)),
        int(round(30 * (1.0 - level))),
    )
def embed_record_image(image_path: Path | None, embedder: ImageEmbedder) -> tuple[float, ...]:
    """Embed the record's image with ``embedder``.

    Raises:
        MissingImageError: If ``image_path`` is ``None``.
    """
    if image_path is not None:
        return embedder.embed_image(image_path)
    raise MissingImageError("Record has no image path to embed.")
def _import_optional(module_name: str) -> Any:
    """Import ``module_name`` and return the module.

    Raises:
        JepaDependencyError: If the import fails with ``ImportError``; the
            original error is chained as the cause.
    """
    try:
        module = import_module(module_name)
    except ImportError as error:
        raise JepaDependencyError(
            "I-JEPA dependencies are missing. Install them with "
            "`uv sync --extra ijepa --dev` (or `make sync-ijepa`)."
        ) from error
    return module
def _quiet_transformers_logging(transformers: Any) -> None:
    """Ask Transformers to log errors only, as a best-effort step.

    An ``AttributeError`` raised while doing so (for example when the logging
    helper is absent) is ignored.
    """
    try:
        transformers.logging.set_verbosity_error()
    except AttributeError:
        # Best effort only: leave logging untouched.
        pass