#!/usr/bin/env python3
r"""
Train a CFP image-quality-control model on EyeQ / EyePACS-style data.

Expected dataset format
-----------------------
EyePACS/
    train/
        10009_left.jpeg
        10009_right.jpeg
        ...
    test/
        ...

data/
    Label_EyeQ_train.csv
    Label_EyeQ_test.csv

Label CSV format:

    ,image,quality,DR_grade
    0,10009_left.jpeg,0,0
    1,10009_right.jpeg,0,0
    2,10014_left.jpeg,2,0

For EyeQ, this script assumes:

    quality = 0 -> Good
    quality = 1 -> Usable
    quality = 2 -> Reject

DR_grade is ignored because this script trains only the image-quality model.

Note: the "best" checkpoint is selected on the test split, so the reported best
test score is optimistically biased. Hold out a separate validation split if an
unbiased test estimate is required.

Example
-------
    python EyeQ_train.py \
        --images_dir /path/to/EyePACS \
        --csv_dir /path/to/data \
        --output_dir ./runs/eyeq_vit_base \
        --epochs 30 \
        --batch_size 32 \
        --lr 3e-5
"""

import argparse
import random
from pathlib import Path
from typing import Dict, Tuple

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

import timm
from sklearn.metrics import (
    accuracy_score,
    balanced_accuracy_score,
    classification_report,
    confusion_matrix,
)
from tqdm import tqdm


ID_TO_LABEL = {0: "Good", 1: "Usable", 2: "Reject"}

# Accepted spellings of the quality column (matched after strip().lower()).
LABEL_TO_ID: Dict[str, int] = {
    "good": 0,
    "usable": 1,
    "reject": 2,
    "0": 0,
    "1": 1,
    "2": 2,
}


class EyeQDataset(Dataset):
    """Map-style dataset yielding ``(image, quality_label)`` pairs.

    Args:
        df: Frame with an ``image`` column (file name relative to
            ``images_dir``) and an integer ``quality`` column.
        images_dir: Directory against which ``image`` file names are resolved.
        transform: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, df: pd.DataFrame, images_dir: str, transform=None):
        self.df = df.reset_index(drop=True)
        self.images_dir = Path(images_dir)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image_path = self.images_dir / str(row["image"])
        # Context manager releases the file handle deterministically, even if
        # decoding fails part-way (relevant with many DataLoader workers).
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        label = int(row["quality"])
        if self.transform is not None:
            image = self.transform(image)
        return image, label


def seed_everything(seed: int):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Note: ``cudnn.benchmark`` is enabled for speed. It lets cuDNN choose
    algorithms at runtime, so GPU runs are not bit-for-bit reproducible even
    with a fixed seed.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = True


def normalize_quality_label(x) -> int:
    """Convert a raw CSV quality value to a class id in ``{0, 1, 2}``.

    Accepts the names ``Good``/``Usable``/``Reject`` (case-insensitive,
    surrounding whitespace ignored) and integral numerics such as ``2``,
    ``"2"`` or ``2.0``.

    Raises:
        ValueError: if ``x`` is none of the accepted forms. Fractional values
            (e.g. ``1.5``) and non-finite values (``nan``/``inf``) are rejected
            rather than silently truncated.
    """
    key = str(x).strip().lower()
    if key in LABEL_TO_ID:
        return LABEL_TO_ID[key]
    try:
        value = float(key)
    except ValueError:
        value = None
    # is_integer() is False for nan, inf and fractional values, so those reach
    # the ValueError below instead of being truncated (or overflowing) in int().
    if value is not None and value.is_integer() and int(value) in (0, 1, 2):
        return int(value)
    raise ValueError(f"Unknown quality label: {x}. Expected 0/1/2 or Good/Usable/Reject.")


def load_eyeq_csv(csv_path: str, images_dir: str) -> pd.DataFrame:
    """Load an EyeQ label CSV, keeping only rows whose image file exists.

    Returns:
        A frame with exactly the columns ``image`` (str) and ``quality``
        (int class id). Rows whose image is missing under ``images_dir`` are
        dropped with a printed warning.

    Raises:
        ValueError: if the ``image`` or ``quality`` column is absent, or a
            quality value cannot be normalized.
        RuntimeError: if no row refers to an existing image.
    """
    df = pd.read_csv(csv_path)

    if "image" not in df.columns:
        raise ValueError(f"CSV must contain an 'image' column. Found columns: {list(df.columns)}")
    if "quality" not in df.columns:
        raise ValueError(f"CSV must contain a 'quality' column. Found columns: {list(df.columns)}")

    df = df[["image", "quality"]].copy()
    df["image"] = df["image"].astype(str)
    df["quality"] = df["quality"].apply(normalize_quality_label)

    images_dir = Path(images_dir)
    exists = df["image"].apply(lambda name: (images_dir / name).exists())
    missing = int((~exists).sum())
    if missing > 0:
        print(f"Warning: dropping {missing} rows with missing image files from {csv_path}")
        print(f"  searched in: {images_dir}")

    df = df.loc[exists].reset_index(drop=True)
    if len(df) == 0:
        raise RuntimeError(f"No valid images found for {csv_path}. Searched in: {images_dir}")
    return df


def build_transforms(img_size: int) -> Tuple[transforms.Compose, transforms.Compose]:
    """Return ``(train_transforms, test_transforms)`` for square ``img_size`` input.

    Training adds mild flip / colour / blur augmentation; both pipelines
    normalize with ImageNet mean and std.
    """
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    train_tfms = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomApply([
            transforms.ColorJitter(
                brightness=0.15,
                contrast=0.15,
                saturation=0.10,
                hue=0.02,
            )
        ], p=0.8),
        transforms.RandomApply([
            transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0))
        ], p=0.15),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    test_tfms = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    return train_tfms, test_tfms


def build_model(model_name: str, num_classes: int, pretrained: bool):
    """Create a timm model with a freshly initialised ``num_classes`` head."""
    return timm.create_model(
        model_name,
        pretrained=pretrained,
        num_classes=num_classes,
    )


def train_one_epoch(model, loader, criterion, optimizer, scaler, device, epoch):
    """Run one training epoch.

    Args:
        scaler: A ``GradScaler`` to enable fp16 autocast, or ``None`` for fp32.
        device: Device (or device string) that batches are moved to.
        epoch: Epoch number, used only for the progress-bar label.

    Returns:
        ``(mean_loss, accuracy, balanced_accuracy)`` over the samples actually
        processed this epoch.

    Raises:
        RuntimeError: if the loader yields no batches.
    """
    model.train()
    device_type = torch.device(device).type
    running_loss = 0.0
    num_seen = 0
    all_preds = []
    all_labels = []

    pbar = tqdm(loader, desc=f"Train {epoch}", leave=False)
    for images, labels in pbar:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with torch.autocast(device_type=device_type, enabled=scaler is not None):
            logits = model(images)
            loss = criterion(logits, labels)

        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        batch_size = images.size(0)
        running_loss += loss.item() * batch_size
        num_seen += batch_size

        preds = logits.argmax(dim=1)
        all_preds.extend(preds.detach().cpu().tolist())
        all_labels.extend(labels.detach().cpu().tolist())

        pbar.set_postfix(loss=f"{loss.item():.4f}")

    if num_seen == 0:
        raise RuntimeError(
            "Training loader produced no batches; check dataset size, batch_size and drop_last."
        )

    # Normalise by samples actually seen: with drop_last=True this is smaller
    # than len(loader.dataset), which would otherwise bias the loss low.
    epoch_loss = running_loss / num_seen
    acc = accuracy_score(all_labels, all_preds)
    bal_acc = balanced_accuracy_score(all_labels, all_preds)
    return epoch_loss, acc, bal_acc


@torch.no_grad()
def evaluate(model, loader, criterion, device, split_name="Test"):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Returns:
        ``(mean_loss, accuracy, balanced_accuracy, y_true, y_pred)`` where the
        last two are NumPy arrays of integer class ids.

    Raises:
        RuntimeError: if the loader yields no batches.
    """
    model.eval()
    running_loss = 0.0
    num_seen = 0
    all_preds = []
    all_labels = []

    pbar = tqdm(loader, desc=split_name, leave=False)
    for images, labels in pbar:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        logits = model(images)
        loss = criterion(logits, labels)

        batch_size = images.size(0)
        running_loss += loss.item() * batch_size
        num_seen += batch_size

        preds = logits.argmax(dim=1)
        all_preds.extend(preds.detach().cpu().tolist())
        all_labels.extend(labels.detach().cpu().tolist())

    if num_seen == 0:
        raise RuntimeError(f"{split_name} loader produced no batches.")

    val_loss = running_loss / num_seen
    acc = accuracy_score(all_labels, all_preds)
    bal_acc = balanced_accuracy_score(all_labels, all_preds)
    return val_loss, acc, bal_acc, np.array(all_labels), np.array(all_preds)


def save_checkpoint(path, model, optimizer, scheduler, epoch, best_metric, args):
    """Serialise model/optimizer/scheduler state plus run metadata to ``path``."""
    torch.save({
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict() if scheduler is not None else None,
        "best_metric": best_metric,
        "args": vars(args),
        "id_to_label": ID_TO_LABEL,
    }, path)


def parse_args():
    """Parse command-line options for the training run."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--images_dir", type=str, required=True,
                        help="EyePACS root containing train/ and test/ folders.")
    parser.add_argument("--csv_dir", type=str, required=True,
                        help="Directory containing Label_EyeQ_train.csv and Label_EyeQ_test.csv.")
    parser.add_argument("--output_dir", type=str, default="./runs/eyeq_vit_base")

    parser.add_argument("--model", type=str, default="vit_base_patch16_224")
    parser.add_argument("--img_size", type=int, default=224)
    # --pretrained is already the default; it is kept so existing command
    # lines keep working. Use --no_pretrained to disable.
    parser.add_argument("--pretrained", action="store_true", default=True)
    parser.add_argument("--no_pretrained", dest="pretrained", action="store_false")

    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("--lr", type=float, default=3e-5)
    parser.add_argument("--weight_decay", type=float, default=1e-4)
    parser.add_argument("--seed", type=int, default=42)

    # Same pattern as --pretrained: AMP is on by default (CUDA only).
    parser.add_argument("--amp", action="store_true", default=True)
    parser.add_argument("--no_amp", dest="amp", action="store_false")

    parser.add_argument("--class_weights", action="store_true",
                        help="Use inverse-frequency class weights.")

    return parser.parse_args()


def print_label_counts(name: str, df: pd.DataFrame):
    """Print the total row count and per-class counts of ``df["quality"]``."""
    print(f"{name}: {len(df)}")
    for label_id in [0, 1, 2]:
        count = int((df["quality"] == label_id).sum())
        print(f"  {ID_TO_LABEL[label_id]} ({label_id}): {count}")


def main():
    """Load data, train, evaluate after every epoch and save checkpoints.

    Writes ``last.pt`` every epoch. Whenever test balanced accuracy improves,
    it also writes ``best.pt`` and ``best_report.txt`` into ``--output_dir``.
    """
    args = parse_args()
    seed_everything(args.seed)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    images_root = Path(args.images_dir)
    csv_root = Path(args.csv_dir)

    train_images_dir = images_root / "train"
    test_images_dir = images_root / "test"

    train_csv = csv_root / "Label_EyeQ_train.csv"
    test_csv = csv_root / "Label_EyeQ_test.csv"

    train_df = load_eyeq_csv(str(train_csv), str(train_images_dir))
    test_df = load_eyeq_csv(str(test_csv), str(test_images_dir))

    train_tfms, test_tfms = build_transforms(args.img_size)

    train_ds = EyeQDataset(train_df, str(train_images_dir), train_tfms)
    test_ds = EyeQDataset(test_df, str(test_images_dir), test_tfms)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Pinned host memory only speeds up host-to-GPU copies; on CPU-only
    # machines requesting it merely triggers a warning.
    pin_memory = device.type == "cuda"

    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=pin_memory,
        # Dropping the ragged last batch avoids tiny, noisy updates. It must
        # not apply when the dataset is smaller than one batch, or the loader
        # would yield nothing at all.
        drop_last=len(train_ds) >= args.batch_size,
    )

    test_loader = DataLoader(
        test_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=pin_memory,
    )

    model = build_model(args.model, num_classes=3, pretrained=args.pretrained).to(device)

    if args.class_weights:
        class_counts = train_df["quality"].value_counts().reindex([0, 1, 2], fill_value=0)
        absent = [ID_TO_LABEL[label_id] for label_id, count in class_counts.items() if count == 0]
        if absent:
            print(f"Warning: no training samples for classes {absent}; treating their count as 1 for weighting.")
        # A floor of 1 avoids division by zero for absent classes.
        counts = class_counts.clip(lower=1).values
        weights = counts.sum() / (len(counts) * counts)
        weights = torch.tensor(weights, dtype=torch.float32, device=device)
        criterion = nn.CrossEntropyLoss(weight=weights)
        print(f"Using class weights: {weights.detach().cpu().numpy().round(3).tolist()}")
    else:
        criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    scaler = torch.cuda.amp.GradScaler() if args.amp and device.type == "cuda" else None

    print("Dataset summary")
    print(f"Train CSV: {train_csv}")
    print(f"Test CSV: {test_csv}")
    print(f"Train images: {train_images_dir}")
    print(f"Test images: {test_images_dir}")
    print_label_counts("Train", train_df)
    print_label_counts("Test", test_df)
    print(f"Model: {args.model}")
    print(f"Device: {device}")

    # NOTE(review): model selection uses the test split, so the best test
    # score below is optimistically biased. Consider a held-out validation split.
    best_bal_acc = -1.0

    for epoch in range(1, args.epochs + 1):
        train_loss, train_acc, train_bal_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, scaler, device, epoch
        )

        test_loss, test_acc, test_bal_acc, y_true, y_pred = evaluate(
            model, test_loader, criterion, device, split_name="Test"
        )

        scheduler.step()

        print(
            f"Epoch {epoch:03d}/{args.epochs} | "
            f"train_loss={train_loss:.4f} train_acc={train_acc:.4f} train_bal_acc={train_bal_acc:.4f} | "
            f"test_loss={test_loss:.4f} test_acc={test_acc:.4f} test_bal_acc={test_bal_acc:.4f}"
        )

        # Update the running best *before* writing last.pt so its stored
        # best_metric is current on an improving epoch.
        is_best = test_bal_acc > best_bal_acc
        if is_best:
            best_bal_acc = test_bal_acc

        save_checkpoint(output_dir / "last.pt", model, optimizer, scheduler, epoch, best_bal_acc, args)

        if is_best:
            best_path = output_dir / "best.pt"
            save_checkpoint(best_path, model, optimizer, scheduler, epoch, best_bal_acc, args)

            report = classification_report(
                y_true,
                y_pred,
                labels=[0, 1, 2],
                target_names=[ID_TO_LABEL[i] for i in [0, 1, 2]],
                digits=4,
            )
            cm = confusion_matrix(y_true, y_pred, labels=[0, 1, 2])

            with open(output_dir / "best_report.txt", "w", encoding="utf-8") as f:
                f.write(f"Best epoch: {epoch}\n")
                f.write(f"Best test balanced accuracy: {best_bal_acc:.4f}\n\n")
                f.write(report)
                f.write("\nConfusion matrix rows=true cols=pred, labels=[Good, Usable, Reject]\n")
                f.write(str(cm))
                f.write("\n")

            print(f"  Saved new best checkpoint: {best_path}")

    print(f"Training complete. Best test balanced accuracy: {best_bal_acc:.4f}")
    print(f"Outputs saved to: {output_dir}")


if __name__ == "__main__":
    main()