File size: 8,031 Bytes
f9306c2
 
 
3e67073
f9306c2
 
 
 
51bdf55
f9306c2
 
 
 
51bdf55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f9306c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e67073
f9306c2
 
3e67073
 
f9306c2
 
 
 
3e67073
f9306c2
 
3e67073
f9306c2
 
 
 
 
 
3e67073
 
f9306c2
 
 
 
 
3e67073
f9306c2
51bdf55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f9306c2
 
 
 
 
 
3e67073
51bdf55
f9306c2
51bdf55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f9306c2
 
 
 
 
 
 
 
 
 
51bdf55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f9306c2
 
 
 
 
 
 
 
 
 
 
 
 
3e67073
 
f9306c2
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
"""Embedding helpers for image records.

The default embedder is intentionally lightweight and deterministic. It gives the
project a testable local baseline while leaving room to plug in learned JEPA features.
"""

from __future__ import annotations

import math
from dataclasses import dataclass
from importlib import import_module
from pathlib import Path
from typing import Any, Protocol

from PIL import Image, ImageDraw, ImageStat


@dataclass(frozen=True)
class PatchInterestMap:
    """Immutable grid of per-patch interest scores plus the image size it describes."""

    # Row-major score grid: one inner tuple per patch row.
    scores: tuple[tuple[float, ...], ...]
    # Size of the described image, unpacked as ``(width, height)`` by the renderers.
    image_size: tuple[int, int]

    @property
    def grid_size(self) -> tuple[int, int]:
        """Return the patch grid dimensions as ``(rows, columns)``.

        The column count is taken from the first row; an empty grid reports
        ``(0, 0)``.
        """

        row_count = len(self.scores)
        column_count = len(self.scores[0]) if row_count else 0
        return (row_count, column_count)


class ImageEmbedder(Protocol):
    """Structural interface for anything that maps an image file to a numeric vector."""

    def embed_image(self, image_path: Path) -> tuple[float, ...]:
        """Compute and return the embedding vector for the image at ``image_path``."""
        ...


class ColorStatsEmbedder:
    """Describe an image by its per-channel RGB mean and standard deviation."""

    def embed_image(self, image_path: Path) -> tuple[float, ...]:
        """Return six color features scaled by ``1 / 255``.

        The image is converted to RGB first; the result holds the three channel
        means followed by the three channel standard deviations.
        """

        with Image.open(image_path) as source:
            channel_stats = ImageStat.Stat(source.convert("RGB"))
        raw_features = [*channel_stats.mean, *channel_stats.stddev]
        return tuple(value / 255.0 for value in raw_features)


class MissingImageError(RuntimeError):
    """Signals that a record has no image path, so it cannot be embedded."""


class JepaDependencyError(RuntimeError):
    """Signals that an optional JEPA dependency could not be imported."""


class IJepaImageEmbedder:
    """Embed images with a Hugging Face I-JEPA vision encoder.

    The optional ``torch`` and ``transformers`` packages are imported when an
    instance is constructed rather than at module import time, so this module
    can be imported without them.
    """

    def __init__(
        self,
        *,
        model_id: str = "facebook/ijepa_vith14_1k",
        device: str | None = None,
    ) -> None:
        """Import the optional dependencies and load the processor and model.

        Loading happens eagerly, here in the constructor; only the *import* of
        the optional packages is deferred until an embedder is created.

        Args:
            model_id: Identifier passed to ``AutoProcessor.from_pretrained`` and
                ``AutoModel.from_pretrained``.
            device: Device the model is moved to. When omitted (or empty),
                ``"cuda"`` is used if ``torch.cuda.is_available()`` is true and
                ``"cpu"`` otherwise.

        Raises:
            JepaDependencyError: If ``torch`` or ``transformers`` cannot be imported.
        """

        self.model_id = model_id
        self._torch = _import_optional("torch")
        transformers = _import_optional("transformers")
        _quiet_transformers_logging(transformers)

        self._processor = transformers.AutoProcessor.from_pretrained(model_id)
        self._model = transformers.AutoModel.from_pretrained(model_id)
        self._device = device or ("cuda" if self._torch.cuda.is_available() else "cpu")
        self._model.to(self._device)
        self._model.eval()

    def embed_image(self, image_path: Path) -> tuple[float, ...]:
        """Return a pooled I-JEPA feature vector for an image."""

        rgb_image, outputs = self._encode_image(image_path)
        # Only the model outputs are needed from here on.
        rgb_image.close()

        pooled = _mean_pool_features(outputs.last_hidden_state)

        return tuple(float(value) for value in pooled.squeeze(0).detach().cpu().tolist())

    def patch_interest_map(self, image_path: Path) -> PatchInterestMap:
        """Return normalized patch-interest scores from I-JEPA token activations.

        Raises:
            RuntimeError: Propagated from the score conversion when the token
                count cannot be arranged as a square grid.
        """

        rgb_image, outputs = self._encode_image(image_path)
        # Read the size and release the image *before* scoring, so a failure
        # in the score conversion cannot leave the image open.
        image_size = rgb_image.size
        rgb_image.close()

        scores = _tokens_to_patch_scores(outputs.last_hidden_state, self._torch)
        return PatchInterestMap(scores=scores, image_size=image_size)

    def render_patch_attention_overlay(
        self,
        image_path: Path,
        *,
        alpha: int = 135,
    ) -> Image.Image:
        """Render a heatmap overlay for the patches with strongest activations.

        Args:
            image_path: Image to score and to draw the overlay on.
            alpha: Alpha value forwarded to the overlay renderer.
        """

        interest_map = self.patch_interest_map(image_path)
        return render_patch_interest_overlay(image_path, interest_map, alpha=alpha)

    def _encode_image(self, image_path: Path) -> tuple[Image.Image, Any]:
        """Run the model on an image and return ``(rgb_image, outputs)``.

        The returned RGB image stays open; the caller is responsible for
        closing it. If preprocessing or the forward pass fails, the image is
        closed here before the exception propagates.
        """

        with Image.open(image_path) as image:
            rgb_image = image.convert("RGB")

        try:
            encoded = self._processor(rgb_image, return_tensors="pt").to(self._model.device)

            # Inference only: no gradients are needed.
            with self._torch.no_grad():
                outputs = self._model(**encoded)
        except Exception:
            rgb_image.close()
            raise
        return rgb_image, outputs


def render_patch_interest_heatmap(
    interest_map: PatchInterestMap,
    *,
    alpha: int = 135,
) -> Image.Image:
    """Paint each patch score as a colored cell on a transparent RGBA canvas.

    The canvas has the size stored in ``interest_map.image_size``. Each cell is
    filled with the heat color of its score combined with ``alpha``. An empty
    grid yields the fully transparent canvas unchanged.
    """

    width, height = interest_map.image_size
    rows, columns = interest_map.grid_size
    canvas = Image.new("RGBA", (width, height), (0, 0, 0, 0))
    if not (rows and columns):
        return canvas

    painter = ImageDraw.Draw(canvas, "RGBA")
    for row_number, patch_row in enumerate(interest_map.scores):
        # Vertical extent is shared by every cell in this row.
        top = round(row_number * height / rows)
        bottom = round((row_number + 1) * height / rows)
        for column_number, score in enumerate(patch_row):
            left = round(column_number * width / columns)
            right = round((column_number + 1) * width / columns)
            heat_rgb = _score_to_heat_color(score)
            painter.rectangle((left, top, right, bottom), fill=(*heat_rgb, alpha))
    return canvas


def render_patch_interest_overlay(
    image_path: Path,
    interest_map: PatchInterestMap,
    *,
    alpha: int = 135,
) -> Image.Image:
    """Composite the patch-interest heatmap over the image at ``image_path``.

    The source image is converted to RGBA and alpha-composited with the heatmap
    rendered from ``interest_map`` using the given ``alpha``.
    """

    with Image.open(image_path) as source:
        background = source.convert("RGBA")

    return Image.alpha_composite(
        background,
        render_patch_interest_heatmap(interest_map, alpha=alpha),
    )


def _mean_pool_features(features: Any) -> Any:
    """Pool token/time dimensions while preserving the final feature dimension."""

    if features.ndim <= 2:
        return features
    return features.mean(dim=tuple(range(1, features.ndim - 1)))


def _tokens_to_patch_scores(features: Any, torch: Any) -> tuple[tuple[float, ...], ...]:
    """Turn token features into a min-max normalized square grid of scores.

    Each token is scored by ``torch.linalg.vector_norm`` over its last
    dimension, then the scores are rescaled to ``[0, 1]`` (all zeros when every
    score is equal) and reshaped into a square grid.

    Raises:
        RuntimeError: If the token count is not a perfect square.
    """

    tokens = features.squeeze(0).detach()
    count = int(tokens.shape[0])
    # If dropping the first token leaves a perfect square, treat that token as
    # a non-patch token and discard it (presumably a class token - TODO confirm).
    if _is_square(count - 1):
        tokens, count = tokens[1:], count - 1
    if not _is_square(count):
        msg = f"Cannot infer a square patch grid from {count} visual tokens."
        raise RuntimeError(msg)

    magnitudes = torch.linalg.vector_norm(tokens.float(), dim=-1)
    lowest = magnitudes.min()
    spread = magnitudes.max() - lowest
    is_flat = float(spread.detach().cpu()) == 0.0
    # Avoid dividing by zero when every token has the same magnitude.
    normalized = torch.zeros_like(magnitudes) if is_flat else (magnitudes - lowest) / spread

    side = int(count**0.5)
    grid = normalized.reshape(side, side).detach().cpu().tolist()
    return tuple(tuple(float(cell) for cell in grid_row) for grid_row in grid)


def _is_square(value: int) -> bool:
    if value <= 0:
        return False
    root = int(value**0.5)
    return root * root == value


def _score_to_heat_color(score: float) -> tuple[int, int, int]:
    clamped = max(0.0, min(1.0, score))
    red = 255
    green = int(round(40 + 190 * clamped))
    blue = int(round(30 * (1.0 - clamped)))
    return (red, green, blue)


def embed_record_image(image_path: Path | None, embedder: ImageEmbedder) -> tuple[float, ...]:
    """Delegate to ``embedder.embed_image`` for a record's image.

    Raises:
        MissingImageError: If ``image_path`` is ``None``.
    """

    if image_path is not None:
        return embedder.embed_image(image_path)
    raise MissingImageError("Record has no image path to embed.")


def _import_optional(module_name: str) -> Any:
    try:
        return import_module(module_name)
    except ImportError as error:
        msg = (
            "I-JEPA dependencies are missing. Install them with "
            "`uv sync --extra ijepa --dev` (or `make sync-ijepa`)."
        )
        raise JepaDependencyError(msg) from error


def _quiet_transformers_logging(transformers: Any) -> None:
    """Reduce noisy dev-version Transformers compatibility logging."""

    try:
        transformers.logging.set_verbosity_error()
    except AttributeError:
        return