# JEPA-demo / src/ijepa_localization.py
# ddebree's picture
# Upload folder using huggingface_hub
# 2bc3168 verified
from __future__ import annotations
from dataclasses import dataclass
import numpy as np
import torch
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoImageProcessor, AutoModel
@dataclass
class LocalizationResult:
    """Bundle of values produced by one localization pass over an image.

    The field comments describe how ``IJepaPatchLocalizer.localize`` fills the
    fields in; the dataclass itself performs no validation.
    """

    # Primary bounding box as (x1, y1, x2, y2).
    box_xyxy: tuple[int, int, int, int]
    # Additional candidate boxes, each as (x1, y1, x2, y2).
    candidate_boxes_xyxy: list[tuple[int, int, int, int]]
    # 2-D grid of per-patch scores.
    heatmap: np.ndarray
    # Scalar score reported for the localization.
    score: float
    # Image-level embedding vector.
    image_embedding: np.ndarray
class IJepaPatchLocalizer:
    """Patch-similarity localizer for I-JEPA-style encoders.

    I-JEPA is not an object detector. This class uses its patch embeddings as a
    representation probe: patches most similar to the image-level embedding are
    treated as the likely salient object region.
    """

    def __init__(self, model_name: str = "facebook/ijepa_vith14_1k", device: str | None = None):
        """Load the image processor and encoder and move the encoder to ``device``.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                image processor and the model.
            device: Torch device string. When omitted, CUDA is used if
                available, otherwise the CPU.
        """
        self.model_name = model_name
        self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
        self.processor = AutoImageProcessor.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.to(self.device)
        self.model.eval()

    def localize(
        self,
        image: Image.Image,
        threshold_quantile: float = 0.85,
        max_boxes: int = 8,
    ) -> LocalizationResult:
        """Score every patch against the mean patch embedding and derive boxes.

        Args:
            image: Input image; it is converted to RGB before encoding.
            threshold_quantile: Quantile of the heatmap used as the cut-off for
                both the primary box and the connected-component candidates.
            max_boxes: Upper bound on the number of candidate boxes requested.

        Returns:
            A ``LocalizationResult`` holding the primary box, candidate boxes,
            the normalized heatmap, its maximum value as ``score`` and the mean
            patch embedding as ``image_embedding`` (float32).

        Raises:
            ValueError: If the patch tokens cannot be arranged in a square grid.
        """
        width, height = image.size
        patch_embeddings = self._encode_patches(image)

        grid_size = int(np.sqrt(len(patch_embeddings)))
        if grid_size * grid_size != len(patch_embeddings):
            raise ValueError(f"Cannot reshape {len(patch_embeddings)} patch tokens into a square grid.")

        # The image-level embedding is simply the mean over patch embeddings.
        image_embedding = patch_embeddings.mean(axis=0, keepdims=True)
        scores = cosine_similarity(patch_embeddings, image_embedding).reshape(grid_size, grid_size)
        heatmap = normalize(scores)

        box = heatmap_to_box(heatmap, width, height, threshold_quantile)
        candidate_boxes = heatmap_to_connected_boxes(
            heatmap,
            width,
            height,
            max_boxes=max_boxes,
            threshold_quantile=threshold_quantile,
        )
        if not candidate_boxes:
            # No connected region survived filtering: fall back to peak-based boxes.
            candidate_boxes = heatmap_to_candidate_boxes(heatmap, width, height, max_boxes=max_boxes)

        return LocalizationResult(
            box_xyxy=box,
            candidate_boxes_xyxy=candidate_boxes,
            heatmap=heatmap,
            score=float(heatmap.max()),
            image_embedding=image_embedding[0].astype(np.float32),
        )

    def embed_image(self, image: Image.Image) -> np.ndarray:
        """Return the mean patch embedding of ``image`` as a float32 vector."""
        return self._encode_patches(image).mean(axis=0).astype(np.float32)

    def _encode_patches(self, image: Image.Image) -> np.ndarray:
        """Run the encoder on one image and return its patch-token embeddings.

        Shared by ``localize`` and ``embed_image`` so both use the same
        preprocessing and token selection.
        """
        inputs = self.processor(images=image.convert("RGB"), return_tensors="pt")
        inputs = {key: value.to(self.device) for key, value in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs)
        # First (only) batch element, moved to the CPU as a float NumPy array.
        hidden = outputs.last_hidden_state[0].detach().cpu().float().numpy()
        return self._patch_tokens(hidden)

    @staticmethod
    def _patch_tokens(hidden: np.ndarray) -> np.ndarray:
        """Select the rows of ``hidden`` that are treated as patch tokens.

        Heuristic based only on the token count:

        * if ``count - 1`` is a positive perfect square, the first row is
          assumed to be an extra (e.g. CLS) token and is dropped;
        * otherwise, if ``count`` is a perfect square, every row is kept;
        * otherwise the first row is dropped.

        Args:
            hidden: Array whose first axis indexes tokens.

        Returns:
            The rows of ``hidden`` regarded as patch embeddings.

        Raises:
            ValueError: If ``hidden`` contains no tokens.
        """
        token_count = hidden.shape[0]
        if token_count == 0:
            raise ValueError("Cannot extract patch tokens from an empty token sequence.")

        def is_square(count: int) -> bool:
            root = int(np.sqrt(count))
            return root * root == count

        # Require at least one remaining patch: a single token is a 1x1 grid,
        # not an extra token followed by zero patches.
        if token_count > 1 and is_square(token_count - 1):
            return hidden[1:]
        if is_square(token_count):
            return hidden
        return hidden[1:]
def normalize(values: np.ndarray) -> np.ndarray:
    """Min-max scale ``values`` into [0, 1] and return the result as float32.

    When every element has the same value there is no range to scale by, so an
    all-zero float32 array of the same shape is returned instead.
    """
    lowest = float(values.min())
    highest = float(values.max())
    if highest != lowest:
        shifted = values - lowest
        return (shifted / (highest - lowest)).astype(np.float32)
    return np.zeros_like(values, dtype=np.float32)
def heatmap_to_box(
    heatmap: np.ndarray,
    image_width: int,
    image_height: int,
    threshold_quantile: float,
) -> tuple[int, int, int, int]:
    """Return one pixel-space box enclosing every cell at or above a quantile.

    Args:
        heatmap: 2-D grid of scores.
        image_width: Width of the target image in pixels.
        image_height: Height of the target image in pixels.
        threshold_quantile: Quantile of ``heatmap`` used as the cut-off.

    Returns:
        ``(x1, y1, x2, y2)`` with ``x2`` and ``y2`` capped at
        ``image_width - 1`` and ``image_height - 1``.
    """
    cutoff = float(np.quantile(heatmap, threshold_quantile))
    rows, cols = np.where(heatmap >= cutoff)
    if len(cols) == 0 or len(rows) == 0:
        # Nothing passed the cut-off: fall back to the single strongest cell.
        peak_row, peak_col = np.unravel_index(np.argmax(heatmap), heatmap.shape)
        cols = np.array([peak_col])
        rows = np.array([peak_row])

    n_rows, n_cols = heatmap.shape
    # Scale grid coordinates to pixels; the far edge is one cell past the max.
    left = int(cols.min() / n_cols * image_width)
    top = int(rows.min() / n_rows * image_height)
    right = int((cols.max() + 1) / n_cols * image_width)
    bottom = int((rows.max() + 1) / n_rows * image_height)
    right = min(image_width - 1, right)
    bottom = min(image_height - 1, bottom)
    return left, top, right, bottom
def heatmap_to_candidate_boxes(
    heatmap: np.ndarray,
    image_width: int,
    image_height: int,
    max_boxes: int = 8,
    min_distance: int = 2,
) -> list[tuple[int, int, int, int]]:
    """Build boxes around the strongest, mutually separated heatmap peaks.

    Cells are visited from highest to lowest score. A cell becomes a peak
    unless it lies within ``min_distance`` grid cells (in both row and column)
    of an already accepted peak. Each peak yields a box spanning the peak cell
    plus one cell on every side, clipped to the image.

    Args:
        heatmap: 2-D grid of scores.
        image_width: Width of the target image in pixels.
        image_height: Height of the target image in pixels.
        max_boxes: Maximum number of boxes to return; values ``<= 0`` yield
            an empty list.
        min_distance: Minimum separation between peaks, in grid cells.

    Returns:
        Up to ``max_boxes`` boxes as ``(x1, y1, x2, y2)``, strongest peak first.
    """
    # Nothing requested or nothing to search: avoid returning a stray box and
    # avoid dividing by a zero-sized grid below.
    if max_boxes <= 0 or heatmap.size == 0:
        return []

    grid_h, grid_w = heatmap.shape
    flat_indices = np.argsort(heatmap.reshape(-1))[::-1]

    peaks: list[tuple[int, int]] = []
    for flat_index in flat_indices:
        if len(peaks) >= max_boxes:
            break
        y, x = np.unravel_index(flat_index, heatmap.shape)
        too_close = any(
            abs(y - py) <= min_distance and abs(x - px) <= min_distance
            for py, px in peaks
        )
        if not too_close:
            peaks.append((int(y), int(x)))

    cell_w = image_width / grid_w
    cell_h = image_height / grid_h
    boxes = []
    for y, x in peaks:
        # Three-by-three cell window centred on the peak, clipped to the image.
        x1 = int(max(0, (x - 1) * cell_w))
        y1 = int(max(0, (y - 1) * cell_h))
        x2 = int(min(image_width - 1, (x + 2) * cell_w))
        y2 = int(min(image_height - 1, (y + 2) * cell_h))
        boxes.append((x1, y1, x2, y2))
    return boxes
def heatmap_to_connected_boxes(
    heatmap: np.ndarray,
    image_width: int,
    image_height: int,
    max_boxes: int = 8,
    threshold_quantile: float = 0.82,
    min_cells: int = 2,
) -> list[tuple[int, int, int, int]]:
    """Turn connected high-score regions of the heatmap into pixel boxes.

    Cells strictly above the cut-off (the larger of the requested quantile and
    0.55) are grouped with ``connected_components``. Regions with fewer than
    ``min_cells`` cells, or whose box covers less than 0.5% of the image area,
    are dropped. The rest are ordered by mean score, de-duplicated with
    ``non_max_suppression`` (IoU threshold 0.25) and truncated to ``max_boxes``.

    Args:
        heatmap: 2-D grid of scores.
        image_width: Width of the target image in pixels.
        image_height: Height of the target image in pixels.
        max_boxes: Maximum number of boxes to return.
        threshold_quantile: Quantile of ``heatmap`` used for the cut-off.
        min_cells: Minimum number of cells a region must contain.

    Returns:
        Boxes as ``(x1, y1, x2, y2)``, highest mean score first.
    """
    cutoff = max(float(np.quantile(heatmap, threshold_quantile)), 0.55)
    n_rows, n_cols = heatmap.shape
    min_area = image_width * image_height * 0.005

    scored = []
    for cells in connected_components(heatmap > cutoff):
        if len(cells) < min_cells:
            continue
        rows = np.array([row for row, _ in cells])
        cols = np.array([col for _, col in cells])
        mean_score = float(heatmap[rows, cols].mean())

        left = int(cols.min() / n_cols * image_width)
        top = int(rows.min() / n_rows * image_height)
        right = int((cols.max() + 1) / n_cols * image_width)
        bottom = int((rows.max() + 1) / n_rows * image_height)
        region = (left, top, min(image_width - 1, right), min(image_height - 1, bottom))

        # Skip regions that are negligible relative to the whole image.
        if box_area(region) < min_area:
            continue
        scored.append((mean_score, region))

    ranked = sorted(scored, key=lambda pair: pair[0], reverse=True)
    survivors = non_max_suppression([region for _, region in ranked], iou_threshold=0.25)
    return survivors[:max_boxes]
def connected_components(mask: np.ndarray) -> list[list[tuple[int, int]]]:
    """Group the truthy cells of a 2-D mask into 8-connected components.

    The grid is scanned row by row; each unvisited truthy cell seeds a
    stack-based flood fill over its 3x3 neighbourhood (diagonals included).

    Args:
        mask: 2-D array whose truthy cells are to be grouped.

    Returns:
        One list of ``(row, col)`` tuples per component, in discovery order.
    """
    n_rows, n_cols = mask.shape
    seen = np.zeros(mask.shape, dtype=bool)
    found: list[list[tuple[int, int]]] = []

    for row in range(n_rows):
        for col in range(n_cols):
            if not mask[row, col] or seen[row, col]:
                continue

            # Mark cells when they are pushed so none is queued twice.
            seen[row, col] = True
            pending = [(row, col)]
            members = []
            while pending:
                cur_row, cur_col = pending.pop()
                members.append((cur_row, cur_col))
                row_span = range(max(0, cur_row - 1), min(n_rows, cur_row + 2))
                col_span = range(max(0, cur_col - 1), min(n_cols, cur_col + 2))
                for near_row in row_span:
                    for near_col in col_span:
                        if mask[near_row, near_col] and not seen[near_row, near_col]:
                            seen[near_row, near_col] = True
                            pending.append((near_row, near_col))
            found.append(members)

    return found
def non_max_suppression(
    boxes: list[tuple[int, int, int, int]],
    iou_threshold: float,
) -> list[tuple[int, int, int, int]]:
    """Greedily keep boxes that do not overlap an earlier kept box too much.

    Boxes are processed in the given order, so callers should pass them
    best-first. A box is kept only if its IoU with every previously kept box
    is below ``iou_threshold``.
    """
    survivors: list[tuple[int, int, int, int]] = []
    for candidate in boxes:
        for survivor in survivors:
            if not iou(candidate, survivor) < iou_threshold:
                # Overlaps an already kept box: discard this candidate.
                break
        else:
            survivors.append(candidate)
    return survivors
def box_area(box: tuple[int, int, int, int]) -> int:
    """Return the area of an ``(x1, y1, x2, y2)`` box, or 0 if it is inverted."""
    left, top, right, bottom = box
    span_x = max(0, right - left)
    span_y = max(0, bottom - top)
    return span_x * span_y
def iou(box_a: tuple[int, int, int, int], box_b: tuple[int, int, int, int]) -> float:
    """Return the intersection-over-union of two ``(x1, y1, x2, y2)`` boxes.

    Inverted or empty extents count as zero area, and the result is 0.0 when
    the union has no area.
    """
    ax1, ay1, ax2, ay2 = box_a
    bx1, by1, bx2, by2 = box_b

    overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0, min(ay2, by2) - max(ay1, by1))
    inter_area = overlap_w * overlap_h

    area_a = max(0, ax2 - ax1) * max(0, ay2 - ay1)
    area_b = max(0, bx2 - bx1) * max(0, by2 - by1)
    union = area_a + area_b - inter_area
    if union <= 0:
        return 0.0
    return float(inter_area / union)