from __future__ import annotations from dataclasses import dataclass import numpy as np import torch from PIL import Image from sklearn.metrics.pairwise import cosine_similarity from transformers import AutoImageProcessor, AutoModel @dataclass class LocalizationResult: box_xyxy: tuple[int, int, int, int] candidate_boxes_xyxy: list[tuple[int, int, int, int]] heatmap: np.ndarray score: float image_embedding: np.ndarray class IJepaPatchLocalizer: """Patch-similarity localizer for I-JEPA-style encoders. I-JEPA is not an object detector. This class uses its patch embeddings as a representation probe: patches most similar to the image-level embedding are treated as the likely salient object region. """ def __init__(self, model_name: str = "facebook/ijepa_vith14_1k", device: str | None = None): self.model_name = model_name self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu")) self.processor = AutoImageProcessor.from_pretrained(model_name) self.model = AutoModel.from_pretrained(model_name) self.model.to(self.device) self.model.eval() def localize( self, image: Image.Image, threshold_quantile: float = 0.85, max_boxes: int = 8, ) -> LocalizationResult: width, height = image.size inputs = self.processor(images=image.convert("RGB"), return_tensors="pt") inputs = {key: value.to(self.device) for key, value in inputs.items()} with torch.no_grad(): outputs = self.model(**inputs) hidden = outputs.last_hidden_state[0].detach().cpu().float().numpy() patch_embeddings = self._patch_tokens(hidden) grid_size = int(np.sqrt(len(patch_embeddings))) if grid_size * grid_size != len(patch_embeddings): raise ValueError(f"Cannot reshape {len(patch_embeddings)} patch tokens into a square grid.") image_embedding = patch_embeddings.mean(axis=0, keepdims=True) scores = cosine_similarity(patch_embeddings, image_embedding).reshape(grid_size, grid_size) heatmap = normalize(scores) box = heatmap_to_box(heatmap, width, height, threshold_quantile) candidate_boxes = heatmap_to_connected_boxes( heatmap, width, height, max_boxes=max_boxes, threshold_quantile=threshold_quantile, ) if not candidate_boxes: candidate_boxes = heatmap_to_candidate_boxes(heatmap, width, height, max_boxes=max_boxes) return LocalizationResult( box_xyxy=box, candidate_boxes_xyxy=candidate_boxes, heatmap=heatmap, score=float(heatmap.max()), image_embedding=image_embedding[0].astype(np.float32), ) def embed_image(self, image: Image.Image) -> np.ndarray: inputs = self.processor(images=image.convert("RGB"), return_tensors="pt") inputs = {key: value.to(self.device) for key, value in inputs.items()} with torch.no_grad(): outputs = self.model(**inputs) hidden = outputs.last_hidden_state[0].detach().cpu().float().numpy() patch_embeddings = self._patch_tokens(hidden) return patch_embeddings.mean(axis=0).astype(np.float32) @staticmethod def _patch_tokens(hidden: np.ndarray) -> np.ndarray: token_count = hidden.shape[0] grid_with_cls = int(np.sqrt(token_count - 1)) if grid_with_cls * grid_with_cls == token_count - 1: return hidden[1:] grid_without_cls = int(np.sqrt(token_count)) if grid_without_cls * grid_without_cls == token_count: return hidden return hidden[1:] def normalize(values: np.ndarray) -> np.ndarray: min_value = float(values.min()) max_value = float(values.max()) if max_value == min_value: return np.zeros_like(values, dtype=np.float32) return ((values - min_value) / (max_value - min_value)).astype(np.float32) def heatmap_to_box( heatmap: np.ndarray, image_width: int, image_height: int, threshold_quantile: float, ) -> tuple[int, int, int, int]: threshold = float(np.quantile(heatmap, threshold_quantile)) ys, xs = np.where(heatmap >= threshold) if len(xs) == 0 or len(ys) == 0: best_y, best_x = np.unravel_index(np.argmax(heatmap), heatmap.shape) xs = np.array([best_x]) ys = np.array([best_y]) grid_h, grid_w = heatmap.shape x1 = int(xs.min() / grid_w * image_width) y1 = int(ys.min() / grid_h * image_height) x2 = int((xs.max() + 1) / grid_w * image_width) y2 = int((ys.max() + 1) / grid_h * image_height) return x1, y1, min(image_width - 1, x2), min(image_height - 1, y2) def heatmap_to_candidate_boxes( heatmap: np.ndarray, image_width: int, image_height: int, max_boxes: int = 8, min_distance: int = 2, ) -> list[tuple[int, int, int, int]]: grid_h, grid_w = heatmap.shape flat_indices = np.argsort(heatmap.reshape(-1))[::-1] peaks: list[tuple[int, int]] = [] for flat_index in flat_indices: y, x = np.unravel_index(flat_index, heatmap.shape) if any(abs(y - py) <= min_distance and abs(x - px) <= min_distance for py, px in peaks): continue peaks.append((int(y), int(x))) if len(peaks) >= max_boxes: break cell_w = image_width / grid_w cell_h = image_height / grid_h boxes = [] for y, x in peaks: x1 = int(max(0, (x - 1) * cell_w)) y1 = int(max(0, (y - 1) * cell_h)) x2 = int(min(image_width - 1, (x + 2) * cell_w)) y2 = int(min(image_height - 1, (y + 2) * cell_h)) boxes.append((x1, y1, x2, y2)) return boxes def heatmap_to_connected_boxes( heatmap: np.ndarray, image_width: int, image_height: int, max_boxes: int = 8, threshold_quantile: float = 0.82, min_cells: int = 2, ) -> list[tuple[int, int, int, int]]: threshold = max(float(np.quantile(heatmap, threshold_quantile)), 0.55) mask = heatmap > threshold components = connected_components(mask) grid_h, grid_w = heatmap.shape candidates = [] for component in components: if len(component) < min_cells: continue ys = np.array([cell[0] for cell in component]) xs = np.array([cell[1] for cell in component]) score = float(heatmap[ys, xs].mean()) x1 = int(xs.min() / grid_w * image_width) y1 = int(ys.min() / grid_h * image_height) x2 = int((xs.max() + 1) / grid_w * image_width) y2 = int((ys.max() + 1) / grid_h * image_height) box = (x1, y1, min(image_width - 1, x2), min(image_height - 1, y2)) if box_area(box) < image_width * image_height * 0.005: continue candidates.append((score, box)) candidates.sort(key=lambda item: item[0], reverse=True) boxes = non_max_suppression([box for _, box in candidates], iou_threshold=0.25) return boxes[:max_boxes] def connected_components(mask: np.ndarray) -> list[list[tuple[int, int]]]: visited = np.zeros(mask.shape, dtype=bool) components: list[list[tuple[int, int]]] = [] height, width = mask.shape for y in range(height): for x in range(width): if visited[y, x] or not mask[y, x]: continue stack = [(y, x)] visited[y, x] = True component = [] while stack: cy, cx = stack.pop() component.append((cy, cx)) for ny in range(max(0, cy - 1), min(height, cy + 2)): for nx in range(max(0, cx - 1), min(width, cx + 2)): if visited[ny, nx] or not mask[ny, nx]: continue visited[ny, nx] = True stack.append((ny, nx)) components.append(component) return components def non_max_suppression( boxes: list[tuple[int, int, int, int]], iou_threshold: float, ) -> list[tuple[int, int, int, int]]: kept: list[tuple[int, int, int, int]] = [] for box in boxes: if all(iou(box, kept_box) < iou_threshold for kept_box in kept): kept.append(box) return kept def box_area(box: tuple[int, int, int, int]) -> int: x1, y1, x2, y2 = box return max(0, x2 - x1) * max(0, y2 - y1) def iou(box_a: tuple[int, int, int, int], box_b: tuple[int, int, int, int]) -> float: ax1, ay1, ax2, ay2 = box_a bx1, by1, bx2, by2 = box_b inter_x1 = max(ax1, bx1) inter_y1 = max(ay1, by1) inter_x2 = min(ax2, bx2) inter_y2 = min(ay2, by2) inter_area = max(0, inter_x2 - inter_x1) * max(0, inter_y2 - inter_y1) area_a = max(0, ax2 - ax1) * max(0, ay2 - ay1) area_b = max(0, bx2 - bx1) * max(0, by2 - by1) union = area_a + area_b - inter_area if union <= 0: return 0.0 return float(inter_area / union)