| from __future__ import annotations |
|
|
| from collections import defaultdict |
| from dataclasses import dataclass |
|
|
| import numpy as np |
| from PIL import Image |
| from sklearn.metrics.pairwise import cosine_similarity |
|
|
| from .ijepa_localization import IJepaPatchLocalizer |
| from .obstacle_dataset import YoloBox, load_obstacle_image, parse_yolo_boxes |
|
|
|
|
| @dataclass |
| class PrototypeGuess: |
| object_index: int |
| yolo_label: str |
| ijepa_guess: str |
| similarity: float |
| agreement: bool |
|
|
|
|
| def build_class_prototypes( |
| dataset_name: str, |
| split: str, |
| rows, |
| localizer: IJepaPatchLocalizer, |
| ) -> dict[str, np.ndarray]: |
| embeddings_by_class: dict[str, list[np.ndarray]] = defaultdict(list) |
| for row in rows: |
| image = load_obstacle_image(dataset_name, row, split) |
| for box in parse_yolo_boxes(row): |
| crop = safe_crop(image, box) |
| if crop is None: |
| continue |
| embeddings_by_class[box.class_name].append(localizer.embed_image(crop)) |
|
|
| return { |
| class_name: np.vstack(embeddings).mean(axis=0) |
| for class_name, embeddings in embeddings_by_class.items() |
| if embeddings |
| } |
|
|
|
|
| def guess_objects_with_prototypes( |
| image: Image.Image, |
| yolo_boxes: list[YoloBox], |
| localizer: IJepaPatchLocalizer, |
| prototypes: dict[str, np.ndarray], |
| ) -> list[PrototypeGuess]: |
| if not prototypes: |
| return [] |
|
|
| class_names = list(prototypes) |
| prototype_matrix = np.vstack([prototypes[class_name] for class_name in class_names]) |
| guesses = [] |
| for index, box in enumerate(yolo_boxes, start=1): |
| crop = safe_crop(image, box) |
| if crop is None: |
| continue |
| embedding = localizer.embed_image(crop) |
| similarities = cosine_similarity([embedding], prototype_matrix)[0] |
| best_index = int(np.argmax(similarities)) |
| guess = class_names[best_index] |
| guesses.append( |
| PrototypeGuess( |
| object_index=index, |
| yolo_label=box.class_name, |
| ijepa_guess=guess, |
| similarity=float(similarities[best_index]), |
| agreement=guess == box.class_name, |
| ) |
| ) |
| return guesses |
|
|
|
|
| def safe_crop(image: Image.Image, box: YoloBox) -> Image.Image | None: |
| x1, y1, x2, y2 = box.to_xyxy(*image.size) |
| image_width, image_height = image.size |
| x1 = max(0, min(image_width - 1, x1)) |
| x2 = max(0, min(image_width, x2)) |
| y1 = max(0, min(image_height - 1, y1)) |
| y2 = max(0, min(image_height, y2)) |
| if x2 - x1 < 2 or y2 - y1 < 2: |
| return None |
| return image.crop((x1, y1, x2, y2)) |
|
|