Yingtao-Zheng's picture
Upload partially updated files
8bbb872
from __future__ import annotations
import os
from abc import ABC, abstractmethod
import numpy as np
class EyeClassifier(ABC):
    """Common interface for eye-crop scoring backends.

    Concrete subclasses expose a short backend identifier through ``name`` and
    collapse a list of BGR eye crops into a single float via ``predict_score``.
    In the implementations in this module, 1.0 corresponds to the most
    "open"/"attentive" outcome.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Short identifier of the backend (for example ``"yolo"``)."""

    @abstractmethod
    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        """Reduce the given BGR crops to one score.

        Args:
            crops_bgr: Eye crops as numpy images in BGR channel order.

        Returns:
            A single float summarising all crops.
        """
class GeometricOnlyClassifier(EyeClassifier):
    """Neutral classifier that never looks at its input.

    ``predict_score`` always reports 1.0 regardless of the crops supplied.
    ``load_eye_classifier`` returns this class for the ``"geometric"`` backend
    and whenever no model path is available.
    """

    @property
    def name(self) -> str:
        """Backend identifier: ``"geometric"``."""
        label = "geometric"
        return label

    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        """Ignore ``crops_bgr`` and return the constant score 1.0."""
        del crops_bgr  # intentionally unused
        return 1.0
class YOLOv11Classifier(EyeClassifier):
    """Eye-state scorer backed by an Ultralytics YOLO classification checkpoint.

    ``predict_score`` returns the mean probability of the class treated as
    "open"/"attentive" across the supplied crops.
    """

    def __init__(self, checkpoint_path: str, device: str = "cpu"):
        """Load the checkpoint and locate the positive ("open"/"attentive") class.

        Args:
            checkpoint_path: Path handed to ``ultralytics.YOLO``.
            device: Device string forwarded to ``predict``.

        Raises:
            ImportError: If ``ultralytics`` is not installed (imported lazily
                so the rest of the module works without it).
        """
        from ultralytics import YOLO

        self._model = YOLO(checkpoint_path)
        self._device = device
        names = self._model.names  # mapping of class index -> class name
        # First class whose name is "open" or "attentive", compared
        # case-insensitively so e.g. "Open" is not missed.
        self._attentive_idx = next(
            (
                idx
                for idx, cls_name in names.items()
                if str(cls_name).lower() in ("open", "attentive")
            ),
            None,
        )
        if self._attentive_idx is None:
            # No recognisable class name: fall back to the highest index.
            # NOTE(review): this is a guess about label order — confirm per checkpoint.
            self._attentive_idx = max(names.keys())
        print(f"[YOLO] Classes: {names}, attentive_idx={self._attentive_idx}")

    @property
    def name(self) -> str:
        """Backend identifier: ``"yolo"``."""
        return "yolo"

    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        """Return the mean positive-class probability over ``crops_bgr``.

        ``None`` or empty crops are not sent to the model (it cannot process
        them); each one contributes a neutral 1.0 instead, matching
        ``EyeCNNClassifier``. An empty list yields 1.0.
        """
        if not crops_bgr:
            return 1.0
        valid = [c for c in crops_bgr if c is not None and c.size > 0]
        scores = [1.0] * (len(crops_bgr) - len(valid))
        if valid:
            results = self._model.predict(valid, device=self._device, verbose=False)
            scores.extend(float(r.probs.data[self._attentive_idx]) for r in results)
        return sum(scores) / len(scores) if scores else 1.0
class EyeCNNClassifier(EyeClassifier):
    """Loader for the custom PyTorch EyeCNN (trained on Kaggle eye crops).

    ``predict_score`` returns the mean softmax probability of output index 1
    (labelled "open" by the original author) across the supplied crops.
    torch, torchvision and cv2 are imported lazily, so this module can be
    imported without them.
    """

    def __init__(self, checkpoint_path: str, device: str = "cpu"):
        """Rebuild the EyeCNN architecture and load weights from a checkpoint.

        Args:
            checkpoint_path: Path to a ``torch.save``'d dict that contains
                ``"model_state_dict"`` and optionally ``"config"`` holding
                ``"dropout_rate"``.
            device: Torch device string, e.g. ``"cpu"`` or ``"cuda"``.
        """
        import torch
        import torch.nn as nn

        class EyeCNN(nn.Module):
            # Must mirror the training-time architecture exactly; otherwise the
            # (strict) load_state_dict below fails on missing/unexpected keys.
            def __init__(self, num_classes=2, dropout_rate=0.3):
                super().__init__()
                self.conv_layers = nn.Sequential(
                    nn.Conv2d(3, 32, 3, 1, 1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2, 2),
                    nn.Conv2d(32, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2, 2),
                    nn.Conv2d(64, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU(), nn.MaxPool2d(2, 2),
                    nn.Conv2d(128, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(), nn.MaxPool2d(2, 2),
                )
                self.fc_layers = nn.Sequential(
                    nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(),
                    nn.Linear(256, 512), nn.ReLU(), nn.Dropout(dropout_rate),
                    nn.Linear(512, num_classes),
                )

            def forward(self, x):
                return self.fc_layers(self.conv_layers(x))

        self._device = torch.device(device)
        # SECURITY: weights_only=False unpickles arbitrary Python objects.
        # Only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=self._device, weights_only=False)
        # Tolerate a missing or None "config" entry. Dropout is disabled in
        # eval mode, so this value only affects how the module is constructed.
        config = checkpoint.get("config") or {}
        dropout_rate = config.get("dropout_rate", 0.35)
        self._model = EyeCNN(num_classes=2, dropout_rate=dropout_rate)
        self._model.load_state_dict(checkpoint["model_state_dict"])
        self._model.to(self._device)
        self._model.eval()
        self._transform = None  # built lazily by _get_transform

    def _get_transform(self):
        """Build (once) and return the preprocessing pipeline.

        numpy RGB image -> PIL -> resized to 96x96 -> float tensor normalised
        with the standard ImageNet mean/std.
        """
        if self._transform is None:
            from torchvision import transforms
            self._transform = transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize((96, 96)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ])
        return self._transform

    @property
    def name(self) -> str:
        """Backend identifier: ``"eye_cnn"``."""
        return "eye_cnn"

    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        """Return the mean "open" probability over ``crops_bgr``.

        ``None`` or empty crops contribute a neutral 1.0. Single-channel
        (grayscale) crops are expanded to 3 channels instead of crashing in
        the BGR->RGB conversion. All valid crops are scored in one batched,
        gradient-free forward pass. An empty list yields 1.0.
        """
        if not crops_bgr:
            return 1.0
        import torch
        import cv2

        transform = self._get_transform()
        scores: list[float] = []
        tensors = []
        for crop in crops_bgr:
            if crop is None or crop.size == 0:
                scores.append(1.0)
                continue
            is_gray = crop.ndim == 2 or (crop.ndim == 3 and crop.shape[2] == 1)
            code = cv2.COLOR_GRAY2RGB if is_gray else cv2.COLOR_BGR2RGB
            tensors.append(transform(cv2.cvtColor(crop, code)))
        if tensors:
            # Every tensor is 3x96x96 after the Resize, so they stack cleanly.
            batch = torch.stack(tensors).to(self._device)
            with torch.no_grad():
                # Column 1 is treated as "open" (per the original author's label).
                probs = torch.softmax(self._model(batch), dim=1)[:, 1]
            scores.extend(probs.tolist())
        return sum(scores) / len(scores)
_EXT_TO_BACKEND = {".pth": "cnn", ".pt": "yolo"}
def load_eye_classifier(
    path: str | None = None,
    backend: str = "yolo",
    device: str = "cpu",
) -> EyeClassifier:
    """Construct the eye classifier for ``backend``.

    Resolution order:
      1. ``backend == "geometric"`` -> ``GeometricOnlyClassifier`` (path ignored).
      2. ``path is None`` -> ``GeometricOnlyClassifier``, after a printed notice.
      3. A recognised file extension (``.pth`` -> cnn, ``.pt`` -> yolo)
         overrides the requested backend, with a printed notice.
      4. Build ``EyeCNNClassifier`` for "cnn" or ``YOLOv11Classifier`` for "yolo".

    NOTE(review): an unrecognised ``backend`` combined with ``path=None``
    falls back to geometric at step 2 instead of raising.

    Args:
        path: Checkpoint file, or None for no learned model.
        backend: One of "yolo", "cnn", "geometric".
        device: Device string passed to the learned classifiers.

    Raises:
        ValueError: If the resolved backend is neither "cnn" nor "yolo".
        ImportError: Re-raised (after a hint) when ultralytics is missing.
    """
    if backend == "geometric":
        return GeometricOnlyClassifier()
    if path is None:
        print(f"[CLASSIFIER] No model path for backend {backend!r}, falling back to geometric")
        return GeometricOnlyClassifier()

    _stem, suffix = os.path.splitext(path)
    suffix = suffix.lower()
    resolved = backend
    implied = _EXT_TO_BACKEND.get(suffix)
    if implied and implied != backend:
        print(f"[CLASSIFIER] File extension {suffix!r} implies backend {implied!r}, "
              f"overriding requested {backend!r}")
        resolved = implied
    print(f"[CLASSIFIER] backend={resolved!r}, path={path!r}")

    if resolved == "cnn":
        return EyeCNNClassifier(path, device=device)
    if resolved != "yolo":
        raise ValueError(
            f"Unknown eye backend {resolved!r}. Choose from: yolo, cnn, geometric"
        )
    try:
        return YOLOv11Classifier(path, device=device)
    except ImportError:
        print("[CLASSIFIER] ultralytics required for YOLO. pip install ultralytics")
        raise