from __future__ import annotations

from dataclasses import dataclass

import numpy as np
from PIL import Image
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

from .ijepa_localization import IJepaPatchLocalizer
from .obstacle_dataset import YoloBox, load_obstacle_image, parse_yolo_boxes
from .prototypes import safe_crop


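@dataclass
class HeadGuess:
    """Small-head prediction for one YOLO box, compared against the YOLO label."""

    object_index: int  # 1-based position of the box in the input list
    yolo_label: str
    head_guess: str
    confidence: float  # probability the head assigns to head_guess
    agreement: bool  # True when head_guess equals yolo_label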


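@dataclass
class TrainedSmallHead:
    """Logistic-regression head fitted on localizer embeddings of box crops."""

    classifier: LogisticRegression
    classes: list[str]
    train_accuracy: float  # accuracy on the training crops, not a held-out set
    train_objects: int  # number of crops used for fitting
    parameter_count: int  # coefficients plus intercepts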


def train_small_head(
    dataset_name: str,
    split: str,
    rows,
    localizer: IJepaPatchLocalizer,
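) -> TrainedSmallHead | None:
    """Fit a logistic-regression head on embeddings of the labelled box crops.

    Returns None when the usable crops cover fewer than two classes or
    number fewer than four.
    """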
    embeddings = []
    labels = []
    for row in rows:
        image = load_obstacle_image(dataset_name, row, split)
        for box in parse_yolo_boxes(row):
            crop = safe_crop(image, box)
            if crop is None:
                continue
            embeddings.append(localizer.embed_image(crop))
            labels.append(box.class_name)

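    # Too little data: the classifier needs at least two classes, and fewer
    # than four crops is treated as not enough to train on.
    if len(set(labels)) < 2 or len(labels) < 4:
        return None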

    x = np.vstack(embeddings)
    y = np.asarray(labels)
    classifier = LogisticRegression(max_iter=1000, class_weight="balanced")
    classifier.fit(x, y)
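    # Accuracy is measured on the training crops themselves, so it is an
    # optimistic fit check rather than a held-out estimate.
    predictions = classifier.predict(x)
    return TrainedSmallHead(
        classifier=classifier,
        # Store plain Python strings rather than NumPy string scalars.
        classes=[str(name) for name in classifier.classes_],
        train_accuracy=float(accuracy_score(y, predictions)),
        train_objects=int(len(labels)),
        parameter_count=int(classifier.coef_.size + classifier.intercept_.size),
    )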


def guess_objects_with_head(
    image: Image.Image,
    yolo_boxes: list[YoloBox],
    localizer: IJepaPatchLocalizer,
    trained_head: TrainedSmallHead | None,
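) -> list[HeadGuess]:
    """Classify each YOLO box crop with the trained head.

    Returns an empty list when no head was trained. Boxes whose crop cannot
    be extracted are skipped.
    """
    if trained_head is None:
        return []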

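    guesses = []
    # Indices are 1-based and follow the input order, so a skipped crop leaves a gap.
    for index, box in enumerate(yolo_boxes, start=1):
        crop = safe_crop(image, box)
        if crop is None:
            continue
        embedding = localizer.embed_image(crop)
        # predict_proba expects a 2-D (n_samples, n_features) array; reshaping
        # accepts both flat and (1, n_features) embeddings, as np.vstack does in training.
        features = np.asarray(embedding).reshape(1, -1)
        probabilities = trained_head.classifier.predict_proba(features)[0]
        best_index = int(np.argmax(probabilities))
        guess = trained_head.classes[best_index]
        guesses.append(
            HeadGuess(
                object_index=index,
                yolo_label=box.class_name,
                head_guess=guess,
                confidence=float(probabilities[best_index]),
                agreement=bool(guess == box.class_name),
            )
        )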
    return guesses