Spaces:
Running
Running
#!/usr/bin/env python3
"""
Train a color fundus photograph (CFP) image-quality-control model on
EyeQ / EyePACS-style data.

Expected dataset format
-----------------------
EyePACS/
    train/
        10009_left.jpeg
        10009_right.jpeg
        ...
    test/
        ...
data/
    Label_EyeQ_train.csv
    Label_EyeQ_test.csv

Label CSV format:

    ,image,quality,DR_grade
    0,10009_left.jpeg,0,0
    1,10009_right.jpeg,0,0
    2,10014_left.jpeg,2,0

For EyeQ, this script assumes:
    quality = 0 -> Good
    quality = 1 -> Usable
    quality = 2 -> Reject

DR_grade is ignored because this script trains only the image-quality model.

Example
-------
python EyeQ_train.py \
    --images_dir /path/to/EyePACS \
    --csv_dir /path/to/data \
    --output_dir ./runs/eyeq_vit_base \
    --epochs 30 \
    --batch_size 32 \
    --lr 3e-5
"""
| import argparse | |
| import random | |
| from pathlib import Path | |
| from typing import Dict, Tuple | |
| import numpy as np | |
| import pandas as pd | |
| from PIL import Image | |
| import torch | |
| import torch.nn as nn | |
| from torch.utils.data import Dataset, DataLoader | |
| from torchvision import transforms | |
| import timm | |
| from sklearn.metrics import accuracy_score, balanced_accuracy_score, classification_report, confusion_matrix | |
| from tqdm import tqdm | |
# Display name for each integer quality class id.
ID_TO_LABEL = dict(enumerate(("Good", "Usable", "Reject")))

# Accepted (lower-cased) spellings of a quality label, mapped to the class id:
# the three class names first, then the ids written as strings.
LABEL_TO_ID: Dict[str, int] = {
    **{name: idx for idx, name in enumerate(("good", "usable", "reject"))},
    **{str(idx): idx for idx in range(3)},
}
class EyeQDataset(Dataset):
    """Dataset that yields ``(image, label)`` pairs described by a label table.

    Each row of ``df`` supplies an ``image`` entry (joined onto ``images_dir``
    to locate the file) and a ``quality`` entry (converted with ``int``).
    """

    def __init__(self, df: pd.DataFrame, images_dir: str, transform=None):
        self.transform = transform
        self.images_dir = Path(images_dir)
        # Re-index from zero so positional lookups line up with the row count.
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        # Decode as 3-channel RGB regardless of the mode stored on disk.
        picture = Image.open(self.images_dir / str(record["image"])).convert("RGB")
        target = int(record["quality"])
        if self.transform is None:
            return picture, target
        return self.transform(picture), target
def seed_everything(seed: int, deterministic: bool = False):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value handed to every generator.
        deterministic: When ``False`` (the default, matching the previous
            behaviour) cuDNN benchmarking is switched on, which may pick
            different kernels from run to run. When ``True`` benchmarking is
            switched off and cuDNN is asked for deterministic kernels, so
            repeated runs with the same seed are more likely to match.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    else:
        # Leave cudnn.deterministic untouched so a caller's own setting survives.
        torch.backends.cudnn.benchmark = True
def normalize_quality_label(x) -> int:
    """Convert a raw quality annotation into its integer class id.

    The value is stringified, stripped and lower-cased, then looked up in
    ``LABEL_TO_ID``. Anything not found there is parsed as a number and
    accepted only if it is exactly 0, 1 or 2 (so ``"2.0"`` is fine).

    Args:
        x: Raw label value, e.g. ``0``, ``"1"``, ``2.0`` or ``"Good"``.

    Returns:
        The class id: 0, 1 or 2.

    Raises:
        ValueError: If the value is not a recognised label. This includes
            fractional numbers such as ``1.7`` (previously truncated to 1)
            and ``inf`` (previously escaped as ``OverflowError``).
    """
    key = str(x).strip().lower()
    if key in LABEL_TO_ID:
        return LABEL_TO_ID[key]
    try:
        number = float(key)
    except ValueError:
        number = None
    # Exact membership rejects fractions, nan and inf without truncating them.
    if number is not None and number in (0.0, 1.0, 2.0):
        return int(number)
    raise ValueError(f"Unknown quality label: {x}. Expected 0/1/2 or Good/Usable/Reject.")
def load_eyeq_csv(csv_path: str, images_dir: str) -> pd.DataFrame:
    """Read a label CSV and keep only the rows whose image file exists.

    Args:
        csv_path: CSV file with at least ``image`` and ``quality`` columns.
        images_dir: Directory the ``image`` entries are resolved against.

    Returns:
        A frame with columns ``image`` (str) and ``quality`` (normalised
        class id), re-indexed from zero.

    Raises:
        ValueError: If a required column is absent.
        RuntimeError: If no listed image exists under ``images_dir``.
    """
    table = pd.read_csv(csv_path)
    for column, article in (("image", "an"), ("quality", "a")):
        if column not in table.columns:
            raise ValueError(
                f"CSV must contain {article} '{column}' column. Found columns: {list(table.columns)}"
            )

    table = table.loc[:, ["image", "quality"]].copy()
    table["image"] = table["image"].astype(str)
    table["quality"] = table["quality"].apply(normalize_quality_label)

    images_dir = Path(images_dir)
    on_disk = table["image"].apply(lambda name: (images_dir / name).exists())
    n_missing = int((~on_disk).sum())
    if n_missing > 0:
        print(f"Warning: dropping {n_missing} rows with missing image files from {csv_path}")
        print(f" searched in: {images_dir}")

    table = table.loc[on_disk].reset_index(drop=True)
    if table.empty:
        raise RuntimeError(f"No valid images found for {csv_path}. Searched in: {images_dir}")
    return table
def build_transforms(img_size: int) -> Tuple[transforms.Compose, transforms.Compose]:
    """Create the training and evaluation preprocessing pipelines.

    Both pipelines resize to ``img_size`` x ``img_size``, convert to a tensor
    and normalise per channel. The training pipeline additionally applies a
    random horizontal flip, colour jitter and an occasional Gaussian blur.

    Returns:
        ``(train_transforms, test_transforms)``.
    """
    target_size = (img_size, img_size)
    # Per-channel statistics; these are the values torchvision documents for
    # its ImageNet-pretrained models.
    channel_mean = (0.485, 0.456, 0.406)
    channel_std = (0.229, 0.224, 0.225)

    def _tensor_and_normalize():
        # Fresh instances per pipeline, as the two Compose objects are independent.
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=channel_mean, std=channel_std),
        ]

    jitter = transforms.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.10, hue=0.02)
    blur = transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0))

    train_steps = [
        transforms.Resize(target_size),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomApply([jitter], p=0.8),
        transforms.RandomApply([blur], p=0.15),
        *_tensor_and_normalize(),
    ]
    eval_steps = [transforms.Resize(target_size), *_tensor_and_normalize()]
    return transforms.Compose(train_steps), transforms.Compose(eval_steps)
def build_model(model_name: str, num_classes: int, pretrained: bool):
    """Instantiate a timm model by name with the requested output size.

    Args:
        model_name: Architecture identifier passed to ``timm.create_model``.
        num_classes: Forwarded as ``num_classes``.
        pretrained: Forwarded as ``pretrained``.
    """
    options = {"pretrained": pretrained, "num_classes": num_classes}
    return timm.create_model(model_name, **options)
def train_one_epoch(model, loader, criterion, optimizer, scaler, device, epoch):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train; switched to training mode.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(logits, labels)``; assumed to
            return a per-batch mean (TODO confirm if a custom loss is used).
        optimizer: Optimiser stepped once per batch.
        scaler: Gradient scaler, or ``None`` to disable autocast and scaling.
        device: Device the batches are moved to.
        epoch: Epoch number, used only in the progress-bar label.

    Returns:
        ``(mean_loss, accuracy, balanced_accuracy)`` over the samples that
        were actually processed.
    """
    model.train()
    use_amp = scaler is not None
    running_loss = 0.0
    num_seen = 0
    all_preds = []
    all_labels = []

    pbar = tqdm(loader, desc=f"Train {epoch}", leave=False)
    for images, labels in pbar:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=use_amp):
            logits = model(images)
            loss = criterion(logits, labels)

        if use_amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        # Read the scalar once; each .item() call forces a device sync.
        batch_loss = loss.item()
        batch_size = images.size(0)
        running_loss += batch_loss * batch_size
        num_seen += batch_size

        preds = logits.argmax(dim=1)
        all_preds.extend(preds.cpu().tolist())
        all_labels.extend(labels.cpu().tolist())
        pbar.set_postfix(loss=f"{batch_loss:.4f}")

    # Normalise by the samples seen, not len(loader.dataset): with
    # drop_last=True the final partial batch is skipped, and dividing by the
    # full dataset size would under-report the mean loss.
    epoch_loss = running_loss / max(num_seen, 1)
    acc = accuracy_score(all_labels, all_preds)
    bal_acc = balanced_accuracy_score(all_labels, all_preds)
    return epoch_loss, acc, bal_acc
def evaluate(model, loader, criterion, device, split_name="Test"):
    """Score ``model`` on every batch of ``loader`` without updating it.

    Args:
        model: Network to evaluate; switched to eval mode.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(logits, labels)``; assumed to
            return a per-batch mean (TODO confirm if a custom loss is used).
        device: Device the batches are moved to.
        split_name: Label shown on the progress bar.

    Returns:
        ``(mean_loss, accuracy, balanced_accuracy, y_true, y_pred)`` where the
        last two are NumPy arrays of labels and predictions.
    """
    model.eval()
    running_loss = 0.0
    num_seen = 0
    all_preds = []
    all_labels = []

    pbar = tqdm(loader, desc=split_name, leave=False)
    # Gradients are never used here; disabling them avoids building autograd
    # graphs for every batch, which only costs memory and time.
    with torch.no_grad():
        for images, labels in pbar:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            logits = model(images)
            loss = criterion(logits, labels)

            batch_size = images.size(0)
            running_loss += loss.item() * batch_size
            num_seen += batch_size

            preds = logits.argmax(dim=1)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    # Average over the samples actually scored so the value stays correct
    # even if the loader drops or subsamples batches.
    val_loss = running_loss / max(num_seen, 1)
    acc = accuracy_score(all_labels, all_preds)
    bal_acc = balanced_accuracy_score(all_labels, all_preds)
    return val_loss, acc, bal_acc, np.array(all_labels), np.array(all_preds)
def save_checkpoint(path, model, optimizer, scheduler, epoch, best_metric, args):
    """Serialise the training state to ``path`` with ``torch.save``.

    The saved dict holds the epoch, the model / optimizer / scheduler state
    dicts (scheduler entry is ``None`` when no scheduler is given), the best
    metric so far, the parsed arguments as a plain dict, and the id-to-label
    mapping.
    """
    payload = {"epoch": epoch}
    payload["model"] = model.state_dict()
    payload["optimizer"] = optimizer.state_dict()
    if scheduler is None:
        payload["scheduler"] = None
    else:
        payload["scheduler"] = scheduler.state_dict()
    payload["best_metric"] = best_metric
    payload["args"] = vars(args)
    payload["id_to_label"] = ID_TO_LABEL
    torch.save(payload, path)
def parse_args():
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` of parsed options.
    """
    cli = argparse.ArgumentParser()
    add = cli.add_argument

    # Dataset locations (both mandatory).
    add("--images_dir", type=str, required=True, help="EyePACS root containing train/ and test/ folders.")
    add("--csv_dir", type=str, required=True, help="Directory containing Label_EyeQ_train.csv and Label_EyeQ_test.csv.")

    # Output location and architecture.
    add("--output_dir", type=str, default="./runs/eyeq_vit_base")
    add("--model", type=str, default="vit_base_patch16_224")
    add("--img_size", type=int, default=224)

    # Pretrained weights are on unless --no_pretrained is given.
    add("--pretrained", action="store_true", default=True)
    add("--no_pretrained", dest="pretrained", action="store_false")

    # Plain numeric hyper-parameters: (flag, type, default).
    numeric_options = (
        ("--epochs", int, 30),
        ("--batch_size", int, 32),
        ("--num_workers", int, 8),
        ("--lr", float, 3e-5),
        ("--weight_decay", float, 1e-4),
        ("--seed", int, 42),
    )
    for flag, kind, default in numeric_options:
        add(flag, type=kind, default=default)

    # Mixed precision is on unless --no_amp is given.
    add("--amp", action="store_true", default=True)
    add("--no_amp", dest="amp", action="store_false")
    add("--class_weights", action="store_true", help="Use inverse-frequency class weights.")

    return cli.parse_args()
def print_label_counts(name: str, df: pd.DataFrame):
    """Print the row count of ``df`` followed by one line per quality class.

    Args:
        name: Heading printed before the counts (e.g. the split name).
        df: Frame with a ``quality`` column of class ids.
    """
    print(f"{name}: {len(df)}")
    quality = df["quality"]
    for label_id in (0, 1, 2):
        n_rows = int(quality.eq(label_id).sum())
        print(f" {ID_TO_LABEL[label_id]} ({label_id}): {n_rows}")
def _write_best_report(report_path, epoch, best_bal_acc, y_true, y_pred):
    """Write the per-class report and confusion matrix for a new best epoch."""
    class_ids = [0, 1, 2]
    report = classification_report(
        y_true,
        y_pred,
        labels=class_ids,
        target_names=[ID_TO_LABEL[i] for i in class_ids],
        digits=4,
    )
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)
    # Explicit encoding so the file does not depend on the platform locale.
    with open(report_path, "w", encoding="utf-8") as f:
        f.write(f"Best epoch: {epoch}\n")
        f.write(f"Best test balanced accuracy: {best_bal_acc:.4f}\n\n")
        f.write(report)
        f.write("\nConfusion matrix rows=true cols=pred, labels=[Good, Usable, Reject]\n")
        f.write(str(cm))
        f.write("\n")


def main():
    """Train the quality model end to end.

    Parses the command line, loads the train/test label CSVs, builds the
    datasets, loaders, model, loss, optimiser and scheduler, then trains for
    ``args.epochs`` epochs. After every epoch ``last.pt`` is written; whenever
    the test balanced accuracy improves, ``best.pt`` and ``best_report.txt``
    are written as well.
    """
    args = parse_args()
    seed_everything(args.seed)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    images_root = Path(args.images_dir)
    csv_root = Path(args.csv_dir)
    train_images_dir = images_root / "train"
    test_images_dir = images_root / "test"
    train_csv = csv_root / "Label_EyeQ_train.csv"
    test_csv = csv_root / "Label_EyeQ_test.csv"

    train_df = load_eyeq_csv(str(train_csv), str(train_images_dir))
    test_df = load_eyeq_csv(str(test_csv), str(test_images_dir))

    train_tfms, test_tfms = build_transforms(args.img_size)
    train_ds = EyeQDataset(train_df, str(train_images_dir), train_tfms)
    test_ds = EyeQDataset(test_df, str(test_images_dir), test_tfms)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Pinned host memory only helps host-to-GPU copies; requesting it on a
    # CPU-only run just produces a warning.
    use_cuda = device.type == "cuda"

    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=use_cuda,
        drop_last=True,
    )
    test_loader = DataLoader(
        test_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=use_cuda,
    )

    model = build_model(args.model, num_classes=3, pretrained=args.pretrained).to(device)

    if args.class_weights:
        # Inverse-frequency weights; a class absent from the training set is
        # counted as 1 so the division stays finite.
        counts = train_df["quality"].value_counts().sort_index().reindex([0, 1, 2], fill_value=1).values
        weights = counts.sum() / (len(counts) * counts)
        weights = torch.tensor(weights, dtype=torch.float32, device=device)
        criterion = nn.CrossEntropyLoss(weight=weights)
        print(f"Using class weights: {weights.detach().cpu().numpy().round(3).tolist()}")
    else:
        criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    scaler = torch.cuda.amp.GradScaler() if args.amp and use_cuda else None

    print("Dataset summary")
    print(f"Train CSV: {train_csv}")
    print(f"Test CSV: {test_csv}")
    print(f"Train images: {train_images_dir}")
    print(f"Test images: {test_images_dir}")
    print_label_counts("Train", train_df)
    print_label_counts("Test", test_df)
    print(f"Model: {args.model}")
    print(f"Device: {device}")

    # NOTE(review): the best checkpoint is selected on the test split, so the
    # reported best test metric is optimistically biased; consider holding out
    # a validation split for model selection.
    best_bal_acc = -1.0
    for epoch in range(1, args.epochs + 1):
        train_loss, train_acc, train_bal_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, scaler, device, epoch
        )
        test_loss, test_acc, test_bal_acc, y_true, y_pred = evaluate(
            model, test_loader, criterion, device, split_name="Test"
        )
        scheduler.step()

        print(
            f"Epoch {epoch:03d}/{args.epochs} | "
            f"train_loss={train_loss:.4f} train_acc={train_acc:.4f} train_bal_acc={train_bal_acc:.4f} | "
            f"test_loss={test_loss:.4f} test_acc={test_acc:.4f} test_bal_acc={test_bal_acc:.4f}"
        )

        # Update the running best BEFORE writing last.pt so the metric stored
        # in it includes the current epoch instead of lagging one step behind.
        is_best = test_bal_acc > best_bal_acc
        if is_best:
            best_bal_acc = test_bal_acc

        save_checkpoint(output_dir / "last.pt", model, optimizer, scheduler, epoch, best_bal_acc, args)

        if is_best:
            best_path = output_dir / "best.pt"
            save_checkpoint(best_path, model, optimizer, scheduler, epoch, best_bal_acc, args)
            _write_best_report(output_dir / "best_report.txt", epoch, best_bal_acc, y_true, y_pred)
            print(f" Saved new best checkpoint: {best_path}")

    print(f"Training complete. Best test balanced accuracy: {best_bal_acc:.4f}")
    print(f"Outputs saved to: {output_dir}")
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()