Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import os | |
| from abc import ABC, abstractmethod | |
| import numpy as np | |
class EyeClassifier(ABC):
    """Common interface for eye-state scorers.

    Concrete backends must implement both methods; instantiating this base
    class (or a subclass missing either method) raises ``TypeError``.
    """

    @abstractmethod
    def name(self) -> str:
        """Return a short identifier for the backend (used for logging)."""

    @abstractmethod
    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        """Score a list of BGR eye crops.

        Args:
            crops_bgr: Image crops in OpenCV BGR channel order.

        Returns:
            A single float aggregated over all crops; higher means the eyes
            look more open/attentive.
        """
class GeometricOnlyClassifier(EyeClassifier):
    """No-model backend that never penalises the eye state.

    ``load_eye_classifier`` returns this when the "geometric" backend is
    requested or when no model path is supplied. It performs no image
    analysis and always reports the maximum score.
    """

    def name(self) -> str:
        """Return the backend identifier ``"geometric"``."""
        label = "geometric"
        return label

    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        """Return 1.0 unconditionally; *crops_bgr* is ignored."""
        del crops_bgr  # intentionally unused
        return 1.0
class YOLOv11Classifier(EyeClassifier):
    """Eye-state scorer backed by an Ultralytics YOLO classification checkpoint.

    A crop's score is the model probability of the class named "open" or
    "attentive" (case-insensitive). If no class has either name, the highest
    class index is used instead.
    """

    # Lower-cased class names treated as the positive ("eye open") class.
    _ATTENTIVE_NAMES = ("open", "attentive")

    def __init__(self, checkpoint_path: str, device: str = "cpu"):
        """Load the checkpoint and locate the attentive class index.

        Args:
            checkpoint_path: Weights path passed to ``ultralytics.YOLO``.
            device: Device string forwarded to ``predict`` (e.g. "cpu").

        Raises:
            ImportError: If ``ultralytics`` is not installed.
        """
        from ultralytics import YOLO

        self._model = YOLO(checkpoint_path)
        self._device = device
        names = self._model.names
        # Ultralytics exposes names as {index: name}; tolerate a plain
        # sequence as well so the lookup below does not crash on .items().
        if not isinstance(names, dict):
            names = dict(enumerate(names))
        # Case-insensitive match so e.g. "Open" is not missed (which would
        # otherwise fall through to the max-index fallback below).
        self._attentive_idx = next(
            (
                idx
                for idx, cls_name in names.items()
                if str(cls_name).lower() in self._ATTENTIVE_NAMES
            ),
            None,
        )
        if self._attentive_idx is None:
            # NOTE(review): assumes the positive class has the highest index
            # (true for e.g. ["closed", "open"]) — confirm for new checkpoints.
            self._attentive_idx = max(names.keys())
        print(f"[YOLO] Classes: {names}, attentive_idx={self._attentive_idx}")

    def name(self) -> str:
        """Return the backend identifier ``"yolo"``."""
        return "yolo"

    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        """Return the mean attentive-class probability over *crops_bgr*.

        ``None`` or zero-size crops cannot be classified and count as 1.0,
        matching EyeCNNClassifier. An empty list returns 1.0.
        """
        if not crops_bgr:
            return 1.0
        # Pre-fill with the neutral score; valid crops overwrite their slot,
        # so the averaging order matches the input order.
        scores = [1.0] * len(crops_bgr)
        valid_idx = [
            i for i, crop in enumerate(crops_bgr)
            if crop is not None and crop.size > 0
        ]
        if valid_idx:
            results = self._model.predict(
                [crops_bgr[i] for i in valid_idx],
                device=self._device,
                verbose=False,
            )
            # NOTE(review): r.probs is None for non-classification checkpoints
            # (e.g. detection weights); this then raises AttributeError.
            for i, result in zip(valid_idx, results):
                scores[i] = float(result.probs.data[self._attentive_idx])
        return sum(scores) / len(scores)
class EyeCNNClassifier(EyeClassifier):
    """Loader for the custom PyTorch EyeCNN (trained on Kaggle eye crops).

    Crops are resized to 96x96, ImageNet-normalised, and scored with the
    softmax probability of output class 1, which this code treats as "open".
    """

    def __init__(self, checkpoint_path: str, device: str = "cpu"):
        """Rebuild the EyeCNN architecture and load weights from disk.

        Args:
            checkpoint_path: ``.pth`` file holding either a dict with a
                ``"model_state_dict"`` entry (plus an optional ``"config"``
                dict), or a bare state dict.
            device: Torch device string, e.g. "cpu" or "cuda".
        """
        import torch
        import torch.nn as nn

        class EyeCNN(nn.Module):
            # Must mirror the training-time architecture exactly; otherwise
            # load_state_dict() (strict by default) rejects the weights.
            def __init__(self, num_classes=2, dropout_rate=0.3):
                super().__init__()
                self.conv_layers = nn.Sequential(
                    nn.Conv2d(3, 32, 3, 1, 1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2, 2),
                    nn.Conv2d(32, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2, 2),
                    nn.Conv2d(64, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU(), nn.MaxPool2d(2, 2),
                    nn.Conv2d(128, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(), nn.MaxPool2d(2, 2),
                )
                self.fc_layers = nn.Sequential(
                    nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(),
                    nn.Linear(256, 512), nn.ReLU(), nn.Dropout(dropout_rate),
                    nn.Linear(512, num_classes),
                )

            def forward(self, x):
                return self.fc_layers(self.conv_layers(x))

        self._device = torch.device(device)
        # SECURITY: weights_only=False unpickles arbitrary Python objects and
        # can execute code embedded in the file. Only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location=self._device, weights_only=False)
        # Accept both the {"model_state_dict": ..., "config": ...} layout and
        # a bare state dict; a missing or None config falls back to defaults.
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
            config = checkpoint.get("config") or {}
        else:
            state_dict = checkpoint
            config = {}
        # Dropout is a no-op in eval mode; the rate is kept only for fidelity.
        dropout_rate = config.get("dropout_rate", 0.35)
        self._model = EyeCNN(num_classes=2, dropout_rate=dropout_rate)
        self._model.load_state_dict(state_dict)
        self._model.to(self._device)
        self._model.eval()
        self._transform = None  # built lazily so torchvision loads on first use

    def _get_transform(self):
        """Return (building once) the preprocessing pipeline for RGB uint8 crops."""
        if self._transform is None:
            from torchvision import transforms
            self._transform = transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize((96, 96)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ])
        return self._transform

    @staticmethod
    def _to_rgb(crop: np.ndarray) -> np.ndarray:
        """Convert an OpenCV crop (gray, BGR or BGRA) to 3-channel RGB."""
        import cv2
        if crop.ndim == 2 or crop.shape[2] == 1:
            return cv2.cvtColor(crop, cv2.COLOR_GRAY2RGB)
        if crop.shape[2] == 4:
            return cv2.cvtColor(crop, cv2.COLOR_BGRA2RGB)
        return cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)

    def name(self) -> str:
        """Return the backend identifier ``"eye_cnn"``."""
        return "eye_cnn"

    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        """Return the mean "open" probability over *crops_bgr*.

        ``None`` or zero-size crops count as 1.0; an empty list returns 1.0.
        All valid crops are scored in a single batched forward pass.
        """
        if not crops_bgr:
            return 1.0
        import torch

        transform = self._get_transform()
        # Neutral score per slot; valid crops overwrite theirs, so the
        # averaging order matches the input order.
        scores = [1.0] * len(crops_bgr)
        valid_idx = []
        tensors = []
        for i, crop in enumerate(crops_bgr):
            if crop is None or crop.size == 0:
                continue
            valid_idx.append(i)
            tensors.append(transform(self._to_rgb(crop)))
        if tensors:
            # Resize((96, 96)) makes every tensor 3x96x96, so they stack.
            batch = torch.stack(tensors).to(self._device)
            with torch.no_grad():
                probs = torch.softmax(self._model(batch), dim=1)[:, 1]  # prob of "open"
            for i, prob in zip(valid_idx, probs.tolist()):
                scores[i] = prob
        return sum(scores) / len(scores)
| _EXT_TO_BACKEND = {".pth": "cnn", ".pt": "yolo"} | |
def load_eye_classifier(
    path: str | None = None,
    backend: str = "yolo",
    device: str = "cpu",
) -> EyeClassifier:
    """Build the eye classifier for *backend*, falling back to geometric-only.

    Args:
        path: Model checkpoint. A known extension (see ``_EXT_TO_BACKEND``)
            overrides *backend*. ``None`` or an empty string selects the
            geometric fallback.
        backend: "yolo", "cnn" or "geometric" (case/whitespace-insensitive).
        device: Device string forwarded to the model-backed classifiers.

    Returns:
        An EyeClassifier instance.

    Raises:
        ValueError: If *backend* is unknown and not implied by the extension.
        ImportError: If the chosen backend's ML library is not installed.
    """
    if isinstance(backend, str):
        # Accept "YOLO", " cnn " etc. instead of failing with ValueError.
        backend = backend.strip().lower()
    if backend == "geometric":
        return GeometricOnlyClassifier()
    # `not path` also catches "", which would otherwise reach the model loaders.
    if not path:
        print(f"[CLASSIFIER] No model path for backend {backend!r}, falling back to geometric")
        return GeometricOnlyClassifier()
    ext = os.path.splitext(path)[1].lower()
    inferred = _EXT_TO_BACKEND.get(ext)
    if inferred and inferred != backend:
        print(f"[CLASSIFIER] File extension {ext!r} implies backend {inferred!r}, "
              f"overriding requested {backend!r}")
        backend = inferred
    print(f"[CLASSIFIER] backend={backend!r}, path={path!r}")
    if backend == "cnn":
        try:
            return EyeCNNClassifier(path, device=device)
        except ImportError:
            print("[CLASSIFIER] torch required for CNN. pip install torch torchvision")
            raise
    if backend == "yolo":
        try:
            return YOLOv11Classifier(path, device=device)
        except ImportError:
            print("[CLASSIFIER] ultralytics required for YOLO. pip install ultralytics")
            raise
    raise ValueError(
        f"Unknown eye backend {backend!r}. Choose from: yolo, cnn, geometric"
    )