"""Eye-state classifier backends (geometric passthrough, Ultralytics YOLO, custom EyeCNN).

Each backend maps a list of BGR eye crops to a single aggregate score via
``EyeClassifier.predict_score``. Use ``load_eye_classifier`` to construct one.
"""

from __future__ import annotations

import os
from abc import ABC, abstractmethod

import numpy as np


def _is_valid_crop(crop: np.ndarray | None) -> bool:
    """Return True when *crop* is a non-empty array that can be sent to a model."""
    return crop is not None and crop.size > 0


class EyeClassifier(ABC):
    """Interface for classifiers that score eye crops."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Short backend identifier (e.g. ``"yolo"``)."""

    @abstractmethod
    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        """Return one aggregate score for *crops_bgr*.

        The implementations in this module return a value where higher means
        more likely "open"/"attentive", and return 1.0 when there is nothing
        to classify.
        """


class GeometricOnlyClassifier(EyeClassifier):
    """No-op classifier that always reports a score of 1.0.

    Used when no model is configured; presumably the decision is then left to
    geometric cues computed elsewhere (not visible in this module).
    """

    @property
    def name(self) -> str:
        return "geometric"

    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        """Ignore the crops and return 1.0."""
        return 1.0


class YOLOv11Classifier(EyeClassifier):
    """Eye-state classifier backed by an Ultralytics YOLO *classification* model."""

    # Class names (compared case-insensitively) treated as the positive class.
    _ATTENTIVE_NAMES = frozenset({"open", "attentive"})

    def __init__(self, checkpoint_path: str, device: str = "cpu"):
        """Load the YOLO checkpoint and locate the "open"/"attentive" class.

        Raises:
            ImportError: if ``ultralytics`` is not installed.
            ValueError: if the model exposes no class names.
        """
        from ultralytics import YOLO  # lazy: optional heavy dependency

        self._model = YOLO(checkpoint_path)
        self._device = device
        names = self._model.names
        self._attentive_idx = self._resolve_attentive_idx(names)
        print(f"[YOLO] Classes: {names}, attentive_idx={self._attentive_idx}")

    @classmethod
    def _resolve_attentive_idx(cls, names: dict[int, str]) -> int:
        """Return the index of the "open"/"attentive" class in *names*.

        Matching is case-insensitive and the first match in iteration order
        wins. If nothing matches, the highest class index is used — a
        heuristic that should be confirmed for each checkpoint.

        Raises:
            ValueError: if *names* is empty.
        """
        if not names:
            raise ValueError("YOLO model exposes no class names")
        for idx, cls_name in names.items():
            if str(cls_name).strip().lower() in cls._ATTENTIVE_NAMES:
                return idx
        return max(names.keys())

    @property
    def name(self) -> str:
        return "yolo"

    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        """Return the mean attentive-class probability over *crops_bgr*.

        ``None`` or empty crops are not sent to the model and count as 1.0
        (same convention as ``EyeCNNClassifier``). An empty list returns 1.0.

        Raises:
            ValueError: if the checkpoint is not a classification model
                (its results carry no ``probs``).
        """
        if not crops_bgr:
            return 1.0
        valid = [crop for crop in crops_bgr if _is_valid_crop(crop)]
        # Every unusable crop contributes a neutral score of 1.0.
        total = float(len(crops_bgr) - len(valid))
        if valid:
            results = self._model.predict(valid, device=self._device, verbose=False)
            for result in results:
                probs = getattr(result, "probs", None)
                if probs is None:
                    raise ValueError(
                        "YOLO checkpoint did not return class probabilities; "
                        "a classification ('-cls') model is required"
                    )
                total += float(probs.data[self._attentive_idx])
        return total / len(crops_bgr)


class EyeCNNClassifier(EyeClassifier):
    """Loader for the custom PyTorch EyeCNN (trained on Kaggle eye crops)."""

    def __init__(self, checkpoint_path: str, device: str = "cpu"):
        """Rebuild the EyeCNN architecture and load weights from *checkpoint_path*.

        Accepts either a ``{"model_state_dict": ..., "config": {...}}`` bundle
        or a bare state_dict.
        """
        import torch
        import torch.nn as nn

        class EyeCNN(nn.Module):
            # Attribute names (conv_layers / fc_layers) and layer order define
            # the state_dict keys, so they must match the training code exactly.
            def __init__(self, num_classes=2, dropout_rate=0.3):
                super().__init__()
                self.conv_layers = nn.Sequential(
                    nn.Conv2d(3, 32, 3, 1, 1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2, 2),
                    nn.Conv2d(32, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2, 2),
                    nn.Conv2d(64, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU(), nn.MaxPool2d(2, 2),
                    nn.Conv2d(128, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(), nn.MaxPool2d(2, 2),
                )
                self.fc_layers = nn.Sequential(
                    nn.AdaptiveAvgPool2d((1, 1)),
                    nn.Flatten(),
                    nn.Linear(256, 512),
                    nn.ReLU(),
                    nn.Dropout(dropout_rate),
                    nn.Linear(512, num_classes),
                )

            def forward(self, x):
                return self.fc_layers(self.conv_layers(x))

        self._device = torch.device(device)
        # SECURITY: weights_only=False unpickles arbitrary Python objects.
        # Only load checkpoints from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=self._device, weights_only=False)
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
            config = checkpoint.get("config") or {}
        else:
            state_dict = checkpoint  # bare state_dict
            config = {}
        # Dropout is inactive after eval(); this only mirrors the training config.
        dropout_rate = config.get("dropout_rate", 0.35)
        self._model = EyeCNN(num_classes=2, dropout_rate=dropout_rate)
        self._model.load_state_dict(state_dict)
        self._model.to(self._device)
        self._model.eval()
        self._transform = None  # built lazily

    def _get_transform(self):
        """Return (building once) the RGB-ndarray -> normalized-tensor pipeline."""
        if self._transform is None:
            from torchvision import transforms

            self._transform = transforms.Compose([
                transforms.ToPILImage(),
                # Presumably the training input resolution — keep in sync with training.
                transforms.Resize((96, 96)),
                transforms.ToTensor(),
                # ImageNet mean/std normalization.
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ])
        return self._transform

    @staticmethod
    def _to_rgb(crop: np.ndarray) -> np.ndarray:
        """Convert a BGR, BGRA or single-channel crop to 3-channel RGB."""
        import cv2

        if crop.ndim == 2 or (crop.ndim == 3 and crop.shape[2] == 1):
            return cv2.cvtColor(crop, cv2.COLOR_GRAY2RGB)
        if crop.ndim == 3 and crop.shape[2] == 4:
            return cv2.cvtColor(crop, cv2.COLOR_BGRA2RGB)
        return cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)

    @property
    def name(self) -> str:
        return "eye_cnn"

    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        """Return the mean "open" probability over *crops_bgr*.

        ``None`` or empty crops count as 1.0. An empty list returns 1.0.
        All valid crops are classified in a single batched forward pass.
        """
        if not crops_bgr:
            return 1.0
        valid = [crop for crop in crops_bgr if _is_valid_crop(crop)]
        n_invalid = len(crops_bgr) - len(valid)
        if not valid:
            return 1.0

        import torch

        transform = self._get_transform()
        # Every tensor is resized to the same shape, so the crops can be stacked.
        batch = torch.stack([transform(self._to_rgb(crop)) for crop in valid]).to(self._device)
        with torch.no_grad():
            logits = self._model(batch)
        # Column 1 is taken as the "open" class (training label order — confirm).
        open_probs = torch.softmax(logits, dim=1)[:, 1].tolist()
        return (sum(open_probs) + n_invalid) / len(crops_bgr)


# A recognised checkpoint extension overrides the requested backend.
_EXT_TO_BACKEND = {".pth": "cnn", ".pt": "yolo"}


def load_eye_classifier(
    path: str | None = None,
    backend: str = "yolo",
    device: str = "cpu",
) -> EyeClassifier:
    """Construct an eye classifier.

    Args:
        path: Model checkpoint. A recognised extension overrides *backend*
            (``.pth`` -> ``"cnn"``, ``.pt`` -> ``"yolo"``).
        backend: ``"yolo"``, ``"cnn"`` or ``"geometric"`` (case-insensitive).
        device: Device string forwarded to the model backend.

    Returns:
        The requested classifier. Returns ``GeometricOnlyClassifier`` when
        *backend* is ``"geometric"`` or when *path* is None; the latter check
        happens before backend validation.

    Raises:
        ValueError: if the backend (after extension inference) is unknown.
        ImportError: if ``ultralytics`` is missing for the YOLO backend.
    """
    if isinstance(backend, str):
        backend = backend.strip().lower()
    if backend == "geometric":
        return GeometricOnlyClassifier()
    if path is None:
        print(f"[CLASSIFIER] No model path for backend {backend!r}, falling back to geometric")
        return GeometricOnlyClassifier()

    ext = os.path.splitext(path)[1].lower()
    inferred = _EXT_TO_BACKEND.get(ext)
    if inferred and inferred != backend:
        print(f"[CLASSIFIER] File extension {ext!r} implies backend {inferred!r}, "
              f"overriding requested {backend!r}")
        backend = inferred

    print(f"[CLASSIFIER] backend={backend!r}, path={path!r}")
    if backend == "cnn":
        return EyeCNNClassifier(path, device=device)
    if backend == "yolo":
        try:
            return YOLOv11Classifier(path, device=device)
        except ImportError:
            print("[CLASSIFIER] ultralytics required for YOLO. pip install ultralytics")
            raise
    raise ValueError(
        f"Unknown eye backend {backend!r}. Choose from: yolo, cnn, geometric"
    )