File size: 2,660 Bytes
2bc3168 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 | from __future__ import annotations
from dataclasses import dataclass
import numpy as np
from PIL import Image
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from .ijepa_localization import IJepaPatchLocalizer
from .obstacle_dataset import YoloBox, load_obstacle_image, parse_yolo_boxes
from .prototypes import safe_crop
@dataclass()
class HeadGuess:
    """Outcome of classifying a single YOLO box crop with the small head.

    Instances are produced by ``guess_objects_with_head`` in this module.
    """

    # 1-based position of the box in the list handed to the guesser
    # (positions of skipped crops are not reused).
    object_index: int
    # Class name carried by the YOLO box (``box.class_name``).
    yolo_label: str
    # Class name with the highest probability according to the head.
    head_guess: str
    # Probability the head assigned to ``head_guess``.
    confidence: float
    # True when ``head_guess`` equals ``yolo_label``.
    agreement: bool
@dataclass()
class TrainedSmallHead:
    """A fitted logistic-regression head plus summary statistics of its training.

    Instances are produced by ``train_small_head`` in this module.
    """

    # The fitted scikit-learn classifier.
    classifier: LogisticRegression
    # Class labels in the classifier's column order (``classifier.classes_``).
    classes: list[str]
    # Accuracy of the classifier on its own training crops.
    train_accuracy: float
    # Number of labelled crops the classifier was fitted on.
    train_objects: int
    # Total number of coefficients plus intercepts in the classifier.
    parameter_count: int
def train_small_head(
    dataset_name: str,
    split: str,
    rows,
    localizer: IJepaPatchLocalizer,
) -> TrainedSmallHead | None:
    """Fit a small logistic-regression head on localizer embeddings of box crops.

    For every row, the image is loaded and each YOLO box parsed from the row is
    cropped with ``safe_crop``; crops that come back as ``None`` are skipped.
    Each remaining crop is embedded with ``localizer.embed_image`` and labelled
    with the box's ``class_name``.

    Args:
        dataset_name: Passed through to ``load_obstacle_image``.
        split: Passed through to ``load_obstacle_image``.
        rows: Iterable of dataset rows; each is handed to both
            ``load_obstacle_image`` and ``parse_yolo_boxes``.
        localizer: Provides ``embed_image`` for each crop.

    Returns:
        A ``TrainedSmallHead``, or ``None`` when fewer than 4 crops or fewer
        than 2 distinct class names were collected. ``train_accuracy`` is
        measured on the same crops the classifier was fitted on.
    """
    samples = []  # (embedding, class_name) pairs, in encounter order
    for row in rows:
        image = load_obstacle_image(dataset_name, row, split)
        for box in parse_yolo_boxes(row):
            crop = safe_crop(image, box)
            if crop is not None:
                samples.append((localizer.embed_image(crop), box.class_name))

    class_names = [name for _, name in samples]
    # LogisticRegression needs at least two classes; also refuse tiny sets.
    if len(samples) < 4 or len(set(class_names)) < 2:
        return None

    features = np.vstack([vector for vector, _ in samples])
    targets = np.asarray(class_names)

    head = LogisticRegression(max_iter=1000, class_weight="balanced")
    head.fit(features, targets)
    fitted_accuracy = accuracy_score(targets, head.predict(features))
    weight_count = head.coef_.size + head.intercept_.size

    return TrainedSmallHead(
        classifier=head,
        classes=list(head.classes_),
        train_accuracy=float(fitted_accuracy),
        train_objects=int(len(class_names)),
        parameter_count=int(weight_count),
    )
def guess_objects_with_head(
    image: Image.Image,
    yolo_boxes: list[YoloBox],
    localizer: IJepaPatchLocalizer,
    trained_head: TrainedSmallHead | None,
) -> list[HeadGuess]:
    """Classify each YOLO box crop with the trained head and compare to YOLO's label.

    Args:
        image: Image the boxes refer to; each box is cropped with ``safe_crop``.
        yolo_boxes: Boxes to classify; their ``class_name`` is the YOLO label.
        localizer: Provides ``embed_image`` for each crop.
        trained_head: Head returned by ``train_small_head``, or ``None``.

    Returns:
        One ``HeadGuess`` per box whose crop is not ``None``, in box order.
        ``object_index`` is the box's 1-based position in ``yolo_boxes``, so
        skipped boxes leave gaps. Returns ``[]`` when ``trained_head`` is
        ``None`` or no box yields a crop.
    """
    if trained_head is None:
        return []

    # Embed all usable crops first, remembering each box's 1-based position.
    kept = []
    embeddings = []
    for index, box in enumerate(yolo_boxes, start=1):
        crop = safe_crop(image, box)
        if crop is None:
            continue
        kept.append((index, box))
        embeddings.append(localizer.embed_image(crop))

    if not embeddings:
        return []

    # np.vstack (the same call used to build the training matrix) turns either
    # a (d,) or a (1, d) embedding into one row of a 2-D feature matrix; the
    # old predict_proba([embedding]) produced a 3-D array for (1, d) inputs.
    # Scoring all rows in one call also avoids a predict_proba per box.
    features = np.vstack(embeddings)
    probabilities = trained_head.classifier.predict_proba(features)
    best_columns = np.argmax(probabilities, axis=1)

    guesses = []
    for (index, box), row_probs, best in zip(kept, probabilities, best_columns):
        best = int(best)
        guess = trained_head.classes[best]
        guesses.append(
            HeadGuess(
                object_index=index,
                yolo_label=box.class_name,
                head_guess=guess,
                confidence=float(row_probs[best]),
                agreement=guess == box.class_name,
            )
        )
    return guesses
|