File size: 2,622 Bytes
2bc3168
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass

import numpy as np
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity

from .ijepa_localization import IJepaPatchLocalizer
from .obstacle_dataset import YoloBox, load_obstacle_image, parse_yolo_boxes


@dataclass
class PrototypeGuess:
    """Outcome of comparing one YOLO box's label with its nearest class prototype.

    Instances are produced by ``guess_objects_with_prototypes``, one per box
    for which a usable crop could be taken.
    """

    # 1-based position of the box in the ``yolo_boxes`` list that was passed in;
    # boxes skipped for lack of a usable crop still consume an index.
    object_index: int
    # Class name carried by the YOLO box itself (``box.class_name``).
    yolo_label: str
    # Class name of the prototype with the highest cosine similarity to the crop.
    ijepa_guess: str
    # Cosine similarity between the crop embedding and that best prototype.
    similarity: float
    # True when ``ijepa_guess`` equals ``yolo_label``.
    agreement: bool


def build_class_prototypes(
    dataset_name: str,
    split: str,
    rows,
    localizer: IJepaPatchLocalizer,
) -> dict[str, np.ndarray]:
    """Compute one mean embedding ("prototype") per annotated class.

    Every row's image is loaded, each YOLO box in the row is cropped out, and
    the crop is embedded with ``localizer.embed_image``. Embeddings are
    grouped by ``box.class_name`` and averaged along axis 0.

    Args:
        dataset_name: Forwarded to ``load_obstacle_image``.
        split: Forwarded to ``load_obstacle_image``.
        rows: Iterable of dataset rows; each is passed to both
            ``load_obstacle_image`` and ``parse_yolo_boxes``.
        localizer: Object whose ``embed_image`` turns a crop into an embedding.
            NOTE(review): assumes all embeddings share the same length so they
            can be stacked - confirm against the localizer.

    Returns:
        Mapping from class name to its averaged embedding, in first-seen
        order. A class whose boxes never produced a usable crop is absent.
    """
    grouped: dict[str, list[np.ndarray]] = {}
    for row in rows:
        picture = load_obstacle_image(dataset_name, row, split)
        for annotation in parse_yolo_boxes(row):
            patch = safe_crop(picture, annotation)
            if patch is not None:
                # Boxes too small to crop are skipped and contribute nothing.
                grouped.setdefault(annotation.class_name, []).append(
                    localizer.embed_image(patch)
                )

    prototypes: dict[str, np.ndarray] = {}
    for label, vectors in grouped.items():
        if not vectors:
            continue
        prototypes[label] = np.vstack(vectors).mean(axis=0)
    return prototypes


def guess_objects_with_prototypes(
    image: Image.Image,
    yolo_boxes: list[YoloBox],
    localizer: IJepaPatchLocalizer,
    prototypes: dict[str, np.ndarray],
) -> list[PrototypeGuess]:
    """Label each YOLO box with the class prototype closest to its crop.

    Args:
        image: Image the boxes refer to.
        yolo_boxes: Boxes to classify; each supplies ``class_name`` and is
            cropped via ``safe_crop``.
        localizer: Object whose ``embed_image`` turns a crop into an embedding.
        prototypes: Mapping from class name to prototype embedding.

    Returns:
        One ``PrototypeGuess`` per box that yielded a usable crop, in input
        order. The list is empty when ``prototypes`` is empty.
    """
    if not prototypes:
        return []

    # Dict iteration order is stable, so row i of the matrix belongs to labels[i].
    labels = list(prototypes.keys())
    stacked = np.vstack(list(prototypes.values()))

    results: list[PrototypeGuess] = []
    for position, box in enumerate(yolo_boxes):
        patch = safe_crop(image, box)
        if patch is None:
            # Index is still consumed so object_index matches the input list.
            continue
        vector = localizer.embed_image(patch)
        scores = cosine_similarity([vector], stacked)[0]
        winner = int(np.argmax(scores))
        predicted = labels[winner]
        record = PrototypeGuess(
            object_index=position + 1,
            yolo_label=box.class_name,
            ijepa_guess=predicted,
            similarity=float(scores[winner]),
            agreement=predicted == box.class_name,
        )
        results.append(record)
    return results


def safe_crop(image: Image.Image, box: YoloBox) -> Image.Image | None:
    """Crop ``box`` out of ``image`` after clamping it to the image bounds.

    The box is converted to corner coordinates with ``box.to_xyxy`` using the
    image's width and height. The left/top edges are clamped to
    ``[0, size - 1]`` and the right/bottom edges to ``[0, size]``.

    Returns:
        The result of ``image.crop`` on the clamped rectangle, or ``None``
        when the clamped rectangle is less than 2 units wide or tall.
    """
    width, height = image.size

    def clamp(upper, value):
        # Argument order mirrors max(0, min(upper, value)) so ties resolve the same way.
        return max(0, min(upper, value))

    raw_left, raw_top, raw_right, raw_bottom = box.to_xyxy(width, height)
    left = clamp(width - 1, raw_left)
    right = clamp(width, raw_right)
    top = clamp(height - 1, raw_top)
    bottom = clamp(height, raw_bottom)

    degenerate = (right - left) < 2 or (bottom - top) < 2
    return None if degenerate else image.crop((left, top, right, bottom))