JEPA-demo / src /prototypes.py
ddebree's picture
Upload folder using huggingface_hub
2bc3168 verified
from __future__ import annotations
from collections import defaultdict
from dataclasses import dataclass
import numpy as np
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from .ijepa_localization import IJepaPatchLocalizer
from .obstacle_dataset import YoloBox, load_obstacle_image, parse_yolo_boxes
@dataclass
class PrototypeGuess:
    """Result of comparing one YOLO box against the class prototypes.

    Instances are produced by ``guess_objects_with_prototypes``; one record
    is emitted per box that could be cropped.
    """

    # 1-based position of the box in the list handed to the guesser.
    object_index: int
    # Class name carried by the YOLO box itself.
    yolo_label: str
    # Class whose prototype scored the highest cosine similarity.
    ijepa_guess: str
    # Cosine similarity between the crop embedding and the winning prototype.
    similarity: float
    # True when ``ijepa_guess`` equals ``yolo_label``.
    agreement: bool
def build_class_prototypes(
    dataset_name: str,
    split: str,
    rows,
    localizer: IJepaPatchLocalizer,
) -> dict[str, np.ndarray]:
    """Average the crop embeddings of every labelled box, per class.

    Args:
        dataset_name: Forwarded to ``load_obstacle_image``.
        split: Forwarded to ``load_obstacle_image``.
        rows: Iterable of dataset rows; each row is passed to both
            ``load_obstacle_image`` and ``parse_yolo_boxes``.
        localizer: Provides ``embed_image`` for each cropped region.

    Returns:
        Mapping from class name to the mean (over ``axis=0``) of the stacked
        embeddings collected for that class. Classes for which no box yielded
        a usable crop do not appear.
    """
    collected: dict[str, list[np.ndarray]] = defaultdict(list)

    for row in rows:
        picture = load_obstacle_image(dataset_name, row, split)
        for box in parse_yolo_boxes(row):
            region = safe_crop(picture, box)
            # safe_crop yields None for boxes too small to crop; skip those.
            if region is not None:
                collected[box.class_name].append(localizer.embed_image(region))

    prototypes: dict[str, np.ndarray] = {}
    for label, vectors in collected.items():
        if vectors:
            prototypes[label] = np.vstack(vectors).mean(axis=0)
    return prototypes
def guess_objects_with_prototypes(
    image: Image.Image,
    yolo_boxes: list[YoloBox],
    localizer: IJepaPatchLocalizer,
    prototypes: dict[str, np.ndarray],
) -> list[PrototypeGuess]:
    """Label each YOLO box with the class of its most similar prototype.

    Every box is cropped from ``image``, embedded with
    ``localizer.embed_image`` and compared to all prototypes by cosine
    similarity; the class with the highest score becomes the guess.

    Args:
        image: Image the boxes refer to.
        yolo_boxes: Boxes to classify; their ``class_name`` is the reference
            label the guess is compared with.
        localizer: Provides ``embed_image`` for each cropped region.
        prototypes: Mapping from class name to prototype vector.

    Returns:
        One ``PrototypeGuess`` per box that could be cropped, in input order.
        ``object_index`` is the 1-based position of the box in ``yolo_boxes``,
        so skipped boxes leave gaps in the numbering. The list is empty when
        ``prototypes`` is empty.
    """
    if not prototypes:
        return []

    # Keys and values of a dict iterate in the same order, so row ``i`` of the
    # matrix belongs to ``labels[i]``.
    labels = list(prototypes)
    matrix = np.vstack(list(prototypes.values()))

    results = []
    for position, box in enumerate(yolo_boxes, start=1):
        region = safe_crop(image, box)
        if region is None:
            continue

        vector = localizer.embed_image(region)
        # Wrapping in a list presents the embedding as a one-row batch
        # (assumes embed_image returns a 1-D vector -- TODO confirm).
        scores = cosine_similarity([vector], matrix)[0]
        winner = int(np.argmax(scores))
        predicted = labels[winner]

        results.append(
            PrototypeGuess(
                object_index=position,
                yolo_label=box.class_name,
                ijepa_guess=predicted,
                similarity=float(scores[winner]),
                agreement=predicted == box.class_name,
            )
        )
    return results
def safe_crop(image: Image.Image, box: YoloBox) -> Image.Image | None:
    """Crop ``box`` out of ``image`` after clamping it to the image bounds.

    Args:
        image: Source image; its ``size`` is read as ``(width, height)``.
        box: Box whose ``to_xyxy(width, height)`` yields ``x1, y1, x2, y2``.

    Returns:
        The cropped region, or ``None`` when the clamped box is narrower or
        shorter than 2 units.
    """
    left, top, right, bottom = box.to_xyxy(*image.size)
    width, height = image.size

    # The top-left corner is limited to the last valid index, while the
    # bottom-right corner may reach the full extent of the image.
    left = max(0, min(width - 1, left))
    top = max(0, min(height - 1, top))
    right = max(0, min(width, right))
    bottom = max(0, min(height, bottom))

    too_narrow = right - left < 2
    too_short = bottom - top < 2
    if too_narrow or too_short:
        return None
    return image.crop((left, top, right, bottom))