from __future__ import annotations

from dataclasses import dataclass

import numpy as np
from PIL import Image
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

from .ijepa_localization import IJepaPatchLocalizer
from .obstacle_dataset import YoloBox, load_obstacle_image, parse_yolo_boxes
from .prototypes import safe_crop

# Training is skipped unless at least this many distinct labels and labelled
# crops are available.
_MIN_TRAIN_CLASSES = 2
_MIN_TRAIN_OBJECTS = 4


@dataclass
class HeadGuess:
    """Prediction of the small head for one YOLO box.

    ``object_index`` is the 1-based position of the box in the list handed to
    ``guess_objects_with_head``. ``agreement`` is True when ``head_guess``
    equals ``yolo_label``; ``confidence`` is the predicted probability of
    ``head_guess``.
    """

    object_index: int
    yolo_label: str
    head_guess: str
    confidence: float
    agreement: bool


@dataclass
class TrainedSmallHead:
    """A fitted logistic-regression head together with training statistics.

    ``classes`` follows the order of ``classifier.classes_`` and therefore the
    column order of ``classifier.predict_proba``. ``train_accuracy`` is
    measured on the training crops themselves. ``parameter_count`` is the
    number of coefficients plus intercepts.
    """

    classifier: LogisticRegression
    classes: list[str]
    train_accuracy: float
    train_objects: int
    parameter_count: int


def train_small_head(
    dataset_name: str,
    split: str,
    rows,
    localizer: IJepaPatchLocalizer,
) -> TrainedSmallHead | None:
    """Fit a logistic-regression classifier on embeddings of labelled crops.

    Every box parsed from every row is cropped out of the row's image,
    embedded with ``localizer.embed_image`` and paired with the box's
    ``class_name``. Boxes for which ``safe_crop`` returns ``None`` are
    skipped.

    Args:
        dataset_name: Forwarded to ``load_obstacle_image``.
        split: Forwarded to ``load_obstacle_image``.
        rows: Iterable of dataset rows accepted by ``load_obstacle_image``
            and ``parse_yolo_boxes``.
        localizer: Object used to embed each crop.

    Returns:
        The fitted head, or ``None`` when fewer than two distinct labels or
        fewer than four labelled crops were collected.
    """
    embeddings = []
    labels = []
    for row in rows:
        image = load_obstacle_image(dataset_name, row, split)
        for box in parse_yolo_boxes(row):
            crop = safe_crop(image, box)
            if crop is None:
                continue
            embeddings.append(localizer.embed_image(crop))
            labels.append(box.class_name)

    if len(set(labels)) < _MIN_TRAIN_CLASSES or len(labels) < _MIN_TRAIN_OBJECTS:
        return None

    x = np.vstack(embeddings)
    y = np.asarray(labels)
    classifier = LogisticRegression(max_iter=1000, class_weight="balanced")
    classifier.fit(x, y)

    # In-sample accuracy: evaluated on the same crops the model was fitted on.
    predictions = classifier.predict(x)
    return TrainedSmallHead(
        classifier=classifier,
        # tolist() yields native Python values rather than NumPy scalars, so
        # later comparisons produce plain Python results.
        classes=classifier.classes_.tolist(),
        train_accuracy=float(accuracy_score(y, predictions)),
        train_objects=len(labels),
        parameter_count=int(classifier.coef_.size + classifier.intercept_.size),
    )


def guess_objects_with_head(
    image: Image.Image,
    yolo_boxes: list[YoloBox],
    localizer: IJepaPatchLocalizer,
    trained_head: TrainedSmallHead | None,
) -> list[HeadGuess]:
    """Classify each YOLO box crop with the trained head.

    Args:
        image: Image the boxes refer to.
        yolo_boxes: Boxes to classify; each supplies ``class_name`` as the
            YOLO label to compare against.
        localizer: Object used to embed each crop.
        trained_head: Result of ``train_small_head``; ``None`` disables
            guessing.

    Returns:
        One ``HeadGuess`` per box that could be cropped, in input order.
        ``object_index`` keeps the 1-based position in ``yolo_boxes``, so
        skipped boxes leave gaps. Empty when ``trained_head`` is ``None`` or
        no box could be cropped.
    """
    if trained_head is None:
        return []

    kept = []
    embeddings = []
    for index, box in enumerate(yolo_boxes, start=1):
        crop = safe_crop(image, box)
        if crop is None:
            continue
        kept.append((index, box))
        embeddings.append(localizer.embed_image(crop))

    if not kept:
        return []

    # Stack the same way training does so that embeddings are accepted in the
    # same shapes, and score all crops with a single predict_proba call.
    probabilities = trained_head.classifier.predict_proba(np.vstack(embeddings))
    best_indices = np.argmax(probabilities, axis=1)

    guesses = []
    for (index, box), row, best in zip(kept, probabilities, best_indices):
        best_index = int(best)
        guess = trained_head.classes[best_index]
        guesses.append(
            HeadGuess(
                object_index=index,
                yolo_label=box.class_name,
                head_guess=guess,
                confidence=float(row[best_index]),
                # bool() keeps the field a plain Python bool even when the
                # stored class labels are NumPy scalars.
                agreement=bool(guess == box.class_name),
            )
        )
    return guesses