File size: 8,993 Bytes
2bc3168
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
from __future__ import annotations

from dataclasses import dataclass

import numpy as np
import torch
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoImageProcessor, AutoModel


@dataclass
class LocalizationResult:
    """Bundle of outputs returned by ``IJepaPatchLocalizer.localize``."""

    # Primary box as (x1, y1, x2, y2) in pixel coordinates of the input image.
    box_xyxy: tuple[int, int, int, int]
    # Additional boxes in the same (x1, y1, x2, y2) pixel format.
    candidate_boxes_xyxy: list[tuple[int, int, int, int]]
    # Square patch-grid similarity map, min-max scaled by ``normalize``.
    heatmap: np.ndarray
    # ``localize`` stores ``float(heatmap.max())`` here.
    score: float
    # Mean of the patch embeddings, stored as float32 by ``localize``.
    image_embedding: np.ndarray


class IJepaPatchLocalizer:
    """Patch-similarity localizer for I-JEPA-style encoders.

    I-JEPA is not an object detector. This class uses its patch embeddings as a
    representation probe: patches most similar to the image-level embedding are
    treated as the likely salient object region.
    """

    def __init__(self, model_name: str = "facebook/ijepa_vith14_1k", device: str | None = None):
        """Load the processor and encoder and put the encoder in eval mode.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                image processor and the model.
            device: Torch device string. When falsy, CUDA is used if
                ``torch.cuda.is_available()`` and CPU otherwise.
        """
        self.model_name = model_name
        self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
        self.processor = AutoImageProcessor.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.to(self.device)
        self.model.eval()

    def localize(
        self,
        image: Image.Image,
        threshold_quantile: float = 0.85,
        max_boxes: int = 8,
    ) -> LocalizationResult:
        """Score every patch against the mean embedding and derive boxes.

        Args:
            image: Input image; converted to RGB before preprocessing.
            threshold_quantile: Quantile in ``[0, 1]`` of the heatmap used as
                the cut-off when turning the heatmap into boxes.
            max_boxes: Upper bound forwarded to the candidate-box helpers.

        Returns:
            A ``LocalizationResult`` holding the primary box, candidate boxes,
            the normalized heatmap, its maximum and the mean patch embedding.

        Raises:
            ValueError: If ``threshold_quantile`` is outside ``[0, 1]`` or the
                patch tokens cannot be arranged into a square grid.
        """
        # Fail fast: np.quantile would reject this anyway, but only after the
        # (expensive) forward pass has already been paid for.
        if not 0.0 <= threshold_quantile <= 1.0:
            raise ValueError(
                f"threshold_quantile must be within [0, 1], got {threshold_quantile!r}."
            )

        width, height = image.size
        patch_embeddings = self._encode_patches(image)
        grid_size = int(np.sqrt(len(patch_embeddings)))
        if grid_size * grid_size != len(patch_embeddings):
            raise ValueError(f"Cannot reshape {len(patch_embeddings)} patch tokens into a square grid.")

        # The image-level embedding is simply the mean over all patch tokens;
        # keepdims keeps it 2-D so it can be fed to cosine_similarity directly.
        image_embedding = patch_embeddings.mean(axis=0, keepdims=True)
        scores = cosine_similarity(patch_embeddings, image_embedding).reshape(grid_size, grid_size)
        heatmap = normalize(scores)
        box = heatmap_to_box(heatmap, width, height, threshold_quantile)
        candidate_boxes = heatmap_to_connected_boxes(
            heatmap,
            width,
            height,
            max_boxes=max_boxes,
            threshold_quantile=threshold_quantile,
        )
        if not candidate_boxes:
            # No connected region survived filtering: fall back to peak boxes.
            candidate_boxes = heatmap_to_candidate_boxes(heatmap, width, height, max_boxes=max_boxes)
        return LocalizationResult(
            box_xyxy=box,
            candidate_boxes_xyxy=candidate_boxes,
            heatmap=heatmap,
            # NOTE(review): the heatmap is min-max normalized, so this is 1.0
            # whenever the patch scores are not all identical (0.0 otherwise).
            score=float(heatmap.max()),
            image_embedding=image_embedding[0].astype(np.float32),
        )

    def embed_image(self, image: Image.Image) -> np.ndarray:
        """Return the mean patch embedding of ``image`` as a float32 vector."""
        return self._encode_patches(image).mean(axis=0).astype(np.float32)

    def _encode_patches(self, image: Image.Image) -> np.ndarray:
        """Run the encoder on ``image`` and return its patch-token embeddings."""
        inputs = self.processor(images=image.convert("RGB"), return_tensors="pt")
        inputs = {key: value.to(self.device) for key, value in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Only the first batch entry is used; rows are tokens from here on.
        hidden = outputs.last_hidden_state[0].detach().cpu().float().numpy()
        return self._patch_tokens(hidden)

    @staticmethod
    def _patch_tokens(hidden: np.ndarray) -> np.ndarray:
        """Drop a leading extra token when the token count suggests one.

        If ``token_count - 1`` is a perfect square the first row is treated as
        a non-patch token and removed. If ``token_count`` itself is a perfect
        square the rows are returned unchanged. Otherwise the first row is
        dropped as a best effort; ``localize`` rejects the result when it does
        not form a square grid.
        """
        token_count = hidden.shape[0]
        grid_with_cls = int(np.sqrt(token_count - 1))
        if grid_with_cls * grid_with_cls == token_count - 1:
            return hidden[1:]
        grid_without_cls = int(np.sqrt(token_count))
        if grid_without_cls * grid_without_cls == token_count:
            return hidden
        return hidden[1:]


def normalize(values: np.ndarray) -> np.ndarray:
    """Min-max scale ``values`` into ``[0, 1]`` and return it as float32.

    When every entry is equal there is no range to scale by, so an all-zero
    float32 array of the same shape is returned instead.
    """
    lowest = float(values.min())
    highest = float(values.max())
    if highest != lowest:
        shifted = values - lowest
        scaled = shifted / (highest - lowest)
        return scaled.astype(np.float32)
    return np.zeros_like(values, dtype=np.float32)


def heatmap_to_box(
    heatmap: np.ndarray,
    image_width: int,
    image_height: int,
    threshold_quantile: float,
) -> tuple[int, int, int, int]:
    """Return one pixel-space box enclosing every cell at or above a quantile.

    The cut-off is the ``threshold_quantile`` quantile of ``heatmap``. The
    bounding extent of all cells ``>=`` that value is mapped from grid
    coordinates to pixels, with the far edges clamped to ``image_width - 1``
    and ``image_height - 1``.
    """
    cutoff = float(np.quantile(heatmap, threshold_quantile))
    rows, cols = np.nonzero(heatmap >= cutoff)
    if len(cols) == 0 or len(rows) == 0:
        # Nothing compared >= the cut-off (for instance a NaN cut-off):
        # fall back to the single strongest cell.
        peak_row, peak_col = np.unravel_index(np.argmax(heatmap), heatmap.shape)
        cols = np.array([peak_col])
        rows = np.array([peak_row])

    n_rows, n_cols = heatmap.shape
    left = int(cols.min() / n_cols * image_width)
    top = int(rows.min() / n_rows * image_height)
    # +1 so the box reaches the far edge of the last selected cell.
    right = int((cols.max() + 1) / n_cols * image_width)
    bottom = int((rows.max() + 1) / n_rows * image_height)
    right = min(image_width - 1, right)
    bottom = min(image_height - 1, bottom)
    return left, top, right, bottom


def heatmap_to_candidate_boxes(
    heatmap: np.ndarray,
    image_width: int,
    image_height: int,
    max_boxes: int = 8,
    min_distance: int = 2,
) -> list[tuple[int, int, int, int]]:
    """Pick well-separated heatmap peaks and return a box around each one.

    Cells are visited from highest to lowest value. A cell becomes a peak
    unless an already chosen peak lies within ``min_distance`` cells of it on
    both axes. Each peak yields a box covering the peak cell and its immediate
    neighbours (a 3x3 cell window), clamped to the image bounds.

    Args:
        heatmap: 2-D grid of scores.
        image_width: Width in pixels of the image the grid maps onto.
        image_height: Height in pixels of the image the grid maps onto.
        max_boxes: Maximum number of boxes to return; values ``<= 0`` yield
            an empty list.
        min_distance: Minimum separation, in cells, between two peaks.

    Returns:
        Up to ``max_boxes`` boxes as ``(x1, y1, x2, y2)`` pixel tuples, ordered
        by descending peak value.
    """
    # The limit used to be checked only after the first peak was appended, so
    # max_boxes <= 0 still produced a box; an empty grid divided by zero.
    if max_boxes <= 0 or heatmap.size == 0:
        return []

    grid_h, grid_w = heatmap.shape
    flat_indices = np.argsort(heatmap.reshape(-1))[::-1]
    peaks: list[tuple[int, int]] = []

    for flat_index in flat_indices:
        y, x = np.unravel_index(flat_index, heatmap.shape)
        # Suppress cells inside the square neighbourhood of a chosen peak.
        if any(abs(y - py) <= min_distance and abs(x - px) <= min_distance for py, px in peaks):
            continue
        peaks.append((int(y), int(x)))
        if len(peaks) >= max_boxes:
            break

    cell_w = image_width / grid_w
    cell_h = image_height / grid_h
    boxes = []
    for y, x in peaks:
        # One cell before the peak up to the far edge of the cell after it.
        x1 = int(max(0, (x - 1) * cell_w))
        y1 = int(max(0, (y - 1) * cell_h))
        x2 = int(min(image_width - 1, (x + 2) * cell_w))
        y2 = int(min(image_height - 1, (y + 2) * cell_h))
        boxes.append((x1, y1, x2, y2))
    return boxes


def heatmap_to_connected_boxes(
    heatmap: np.ndarray,
    image_width: int,
    image_height: int,
    max_boxes: int = 8,
    threshold_quantile: float = 0.82,
    min_cells: int = 2,
) -> list[tuple[int, int, int, int]]:
    """Turn connected hot regions of ``heatmap`` into pixel-space boxes.

    Cells strictly above the cut-off are grouped by ``connected_components``.
    Regions with fewer than ``min_cells`` cells, or whose box covers less than
    0.5% of the image area, are dropped. The rest are ordered by mean heatmap
    value (highest first), de-duplicated with ``non_max_suppression`` at an
    IoU of 0.25, and truncated to ``max_boxes``.
    """
    # The quantile cut-off is floored at 0.55.
    cutoff = max(float(np.quantile(heatmap, threshold_quantile)), 0.55)
    regions = connected_components(heatmap > cutoff)
    n_rows, n_cols = heatmap.shape
    min_area = image_width * image_height * 0.005

    scored = []
    for region in regions:
        if len(region) < min_cells:
            continue
        cells = np.array(region)
        rows = cells[:, 0]
        cols = cells[:, 1]
        mean_heat = float(heatmap[rows, cols].mean())
        left = int(cols.min() / n_cols * image_width)
        top = int(rows.min() / n_rows * image_height)
        right = int((cols.max() + 1) / n_cols * image_width)
        bottom = int((rows.max() + 1) / n_rows * image_height)
        region_box = (left, top, min(image_width - 1, right), min(image_height - 1, bottom))
        if box_area(region_box) < min_area:
            continue
        scored.append((mean_heat, region_box))

    ranked = sorted(scored, key=lambda pair: pair[0], reverse=True)
    survivors = non_max_suppression([region_box for _, region_box in ranked], iou_threshold=0.25)
    return survivors[:max_boxes]


def connected_components(mask: np.ndarray) -> list[list[tuple[int, int]]]:
    """Group truthy cells of a 2-D ``mask`` into 8-connected regions.

    The grid is scanned row by row. Each unvisited truthy cell seeds a
    depth-first flood fill over its 3x3 neighbourhood. Every region is
    returned as a list of ``(row, col)`` tuples in visit order.
    """
    n_rows, n_cols = mask.shape
    seen = np.zeros(mask.shape, dtype=bool)
    regions: list[list[tuple[int, int]]] = []

    for seed_row in range(n_rows):
        for seed_col in range(n_cols):
            if not mask[seed_row, seed_col] or seen[seed_row, seed_col]:
                continue
            seen[seed_row, seed_col] = True
            frontier = [(seed_row, seed_col)]
            region = []
            while frontier:
                row, col = frontier.pop()
                region.append((row, col))
                # 3x3 window around the current cell, clipped to the grid.
                row_span = range(max(0, row - 1), min(n_rows, row + 2))
                col_span = range(max(0, col - 1), min(n_cols, col + 2))
                for near_row in row_span:
                    for near_col in col_span:
                        if mask[near_row, near_col] and not seen[near_row, near_col]:
                            seen[near_row, near_col] = True
                            frontier.append((near_row, near_col))
            regions.append(region)
    return regions


def non_max_suppression(
    boxes: list[tuple[int, int, int, int]],
    iou_threshold: float,
) -> list[tuple[int, int, int, int]]:
    """Greedily keep boxes whose IoU with every kept box is below the threshold.

    Boxes are considered in the order given, so callers should pass them
    best-first. The returned list preserves that order.
    """
    survivors: list[tuple[int, int, int, int]] = []
    for candidate in boxes:
        for accepted in survivors:
            if not iou(candidate, accepted) < iou_threshold:
                # Overlaps an already accepted box too much: drop it.
                break
        else:
            survivors.append(candidate)
    return survivors


def box_area(box: tuple[int, int, int, int]) -> int:
    """Return the area of an ``(x1, y1, x2, y2)`` box; inverted sides count as 0."""
    left, top, right, bottom = box
    span_x = right - left
    span_y = bottom - top
    clipped_x = span_x if span_x > 0 else 0
    clipped_y = span_y if span_y > 0 else 0
    return clipped_x * clipped_y


def iou(box_a: tuple[int, int, int, int], box_b: tuple[int, int, int, int]) -> float:
    """Return the intersection-over-union of two ``(x1, y1, x2, y2)`` boxes.

    Negative extents are treated as zero, and a non-positive union yields 0.0
    rather than dividing by zero.
    """
    a_left, a_top, a_right, a_bottom = box_a
    b_left, b_top, b_right, b_bottom = box_b

    overlap_w = max(0, min(a_right, b_right) - max(a_left, b_left))
    overlap_h = max(0, min(a_bottom, b_bottom) - max(a_top, b_top))
    intersection = overlap_w * overlap_h

    size_a = max(0, a_right - a_left) * max(0, a_bottom - a_top)
    size_b = max(0, b_right - b_left) * max(0, b_bottom - b_top)
    combined = size_a + size_b - intersection
    if combined <= 0:
        return 0.0
    return float(intersection / combined)