#!/usr/bin/env python3
r"""
Evaluate an EyeQ CFP image-quality-control model on Label_EyeQ_test.csv.

Example
-------
python EyeQ_test.py \
    --images_dir /data/MIDS/datasets/retina/EyePACS \
    --csv_dir /data/MIDS/datasets/retina/EyeQ/data \
    --checkpoint ./checkpoints/eyeq_vit_base/best.pt \
    --output_dir ./checkpoints/eyeq_vit_base/test_eval \
    --batch_size 32 \
    --num_workers 24
"""

import argparse
from pathlib import Path
from typing import Dict, Tuple

import numpy as np
import pandas as pd
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import timm
from sklearn.metrics import (
    accuracy_score,
    balanced_accuracy_score,
    classification_report,
    confusion_matrix,
)
from tqdm import tqdm


ID_TO_LABEL = {0: "Good", 1: "Usable", 2: "Reject"}

LABEL_TO_ID: Dict[str, int] = {
    "good": 0,
    "usable": 1,
    "reject": 2,
    "0": 0,
    "1": 1,
    "2": 2,
}

# Class ids this script reports on (metrics, confusion matrix, probability columns).
CLASS_IDS = [0, 1, 2]


class EyeQDataset(Dataset):
    """Map-style dataset yielding ``(image, label, image_name)`` per CSV row.

    Args:
        df: Table with at least an ``image`` column (file name relative to
            ``images_dir``) and a ``quality`` column (integer class id).
        images_dir: Directory the ``image`` names are resolved against.
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, df: pd.DataFrame, images_dir: str, transform=None):
        self.df = df.reset_index(drop=True)
        self.images_dir = Path(images_dir)
        self.transform = transform

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        image_name = str(row["image"])
        image_path = self.images_dir / image_name
        # The context manager closes the file handle promptly; convert() returns
        # an independent image, so it stays valid after the file is closed.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        label = int(row["quality"])
        if self.transform is not None:
            image = self.transform(image)
        return image, label, image_name


def normalize_quality_label(x) -> int:
    """Convert a raw EyeQ quality label to its integer class id.

    Accepts case-insensitive names (Good/Usable/Reject) and integral numeric
    values 0/1/2 in any textual or numeric form (``1``, ``"1"``, ``1.0``,
    ``" 2 "``).

    Raises:
        ValueError: if ``x`` is not a recognised label. Non-integral numbers
            such as ``1.5`` and non-finite values (nan/inf) are rejected
            instead of being truncated to a class id.
    """
    key = str(x).strip().lower()
    if key in LABEL_TO_ID:
        return LABEL_TO_ID[key]
    try:
        value = float(key)
    except ValueError:
        value = None
    # float.is_integer() is False for nan/inf, so int() never sees them.
    if value is not None and value.is_integer() and int(value) in CLASS_IDS:
        return int(value)
    raise ValueError(f"Unknown quality label: {x}. Expected 0/1/2 or Good/Usable/Reject.")


def load_eyeq_csv(csv_path: str, images_dir: str) -> pd.DataFrame:
    """Load an EyeQ label CSV and keep only rows whose image file exists.

    Args:
        csv_path: Path to the CSV; must contain ``image`` and ``quality``.
        images_dir: Directory the ``image`` file names are resolved against.

    Returns:
        DataFrame with ``image`` (str), ``quality`` (int class id) and, if
        present in the CSV, ``DR_grade``, re-indexed from 0.

    Raises:
        ValueError: if a required column is missing or a label is unknown.
        RuntimeError: if no listed image exists on disk.
    """
    df = pd.read_csv(csv_path)

    if "image" not in df.columns:
        raise ValueError(f"CSV must contain an 'image' column. Found columns: {list(df.columns)}")
    if "quality" not in df.columns:
        raise ValueError(f"CSV must contain a 'quality' column. Found columns: {list(df.columns)}")

    # Keep DR_grade if present for optional downstream inspection.
    keep_cols = ["image", "quality"]
    if "DR_grade" in df.columns:
        keep_cols.append("DR_grade")

    df = df[keep_cols].copy()
    df["image"] = df["image"].astype(str)
    df["quality"] = df["quality"].apply(normalize_quality_label)

    images_root = Path(images_dir)
    # astype(bool) keeps the mask boolean even when the frame is empty.
    exists = df["image"].map(lambda name: (images_root / name).exists()).astype(bool)
    missing = int((~exists).sum())
    if missing > 0:
        print(f"Warning: dropping {missing} rows with missing image files from {csv_path}")
        print(f"  searched in: {images_root}")
        df = df.loc[exists].reset_index(drop=True)

    if len(df) == 0:
        raise RuntimeError(f"No valid images found for {csv_path}. Searched in: {images_root}")

    return df


def build_transform(img_size: int):
    """Return the eval transform: square resize, tensor conversion, ImageNet normalization."""
    return transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ])


def _strip_module_prefix(state_dict):
    """Remove the ``module.`` prefix that (Distributed)DataParallel adds to every key."""
    prefix = "module."
    if state_dict and all(str(k).startswith(prefix) for k in state_dict):
        return {k[len(prefix):]: v for k, v in state_dict.items()}
    return state_dict


def load_model(checkpoint_path: str, device: torch.device):
    """Rebuild the timm classifier stored in a training checkpoint.

    The checkpoint is read as a dict with a ``"model"`` state dict and,
    optionally, ``"args"`` (``model`` / ``img_size`` entries) and
    ``"id_to_label"``. Missing entries fall back to ``vit_base_patch16_224``,
    224 px and ``ID_TO_LABEL``.

    Returns:
        ``(model, id_to_label, model_name, img_size, ckpt)``. The model has
        been moved to ``device`` and put in eval mode.

    Raises:
        ValueError: if the checkpoint's label map does not have exactly the
            class ids 0/1/2, which the rest of this script assumes.
    """
    # NOTE(security): torch.load unpickles the file; only evaluate checkpoints
    # from trusted sources.
    ckpt = torch.load(checkpoint_path, map_location="cpu")

    ckpt_args = ckpt.get("args") or {}
    if not isinstance(ckpt_args, dict):
        # Tolerate checkpoints that stored an argparse.Namespace directly.
        ckpt_args = vars(ckpt_args)

    model_name = ckpt_args.get("model", "vit_base_patch16_224")
    img_size = int(ckpt_args.get("img_size", 224))

    raw_labels = ckpt.get("id_to_label", ID_TO_LABEL)
    # Keys may have been serialized as strings; normalize to int.
    id_to_label = {int(k): str(v) for k, v in raw_labels.items()}
    if sorted(id_to_label) != CLASS_IDS:
        raise ValueError(
            f"Checkpoint id_to_label must have exactly the class ids {CLASS_IDS}; got {id_to_label}"
        )

    model = timm.create_model(
        model_name,
        pretrained=False,
        num_classes=len(id_to_label),
    )
    model.load_state_dict(_strip_module_prefix(ckpt["model"]), strict=True)
    model.to(device)
    model.eval()

    return model, id_to_label, model_name, img_size, ckpt


@torch.no_grad()
def evaluate(
    model, loader, criterion, device, amp=False
) -> Tuple[float, float, float, np.ndarray, np.ndarray, np.ndarray, list]:
    """Run the model over ``loader`` and compute loss and accuracy metrics.

    Args:
        model: Classifier returning logits of shape (batch, num_classes).
        loader: Yields ``(images, labels, image_names)`` batches.
        criterion: Loss with mean reduction over the batch.
        device: Device the batches are moved to.
        amp: Enable mixed precision; only takes effect on CUDA devices.

    Returns:
        ``(test_loss, acc, bal_acc, y_true, y_pred, probs, image_names)``.
        ``test_loss`` is the per-sample mean. ``probs`` has shape
        (n_samples, num_classes). The arrays are in loader order.

    Raises:
        ValueError: if the loader's dataset is empty.
    """
    if len(loader.dataset) == 0:
        raise ValueError("Cannot evaluate on an empty dataset.")

    model.eval()
    use_amp = amp and device.type == "cuda"

    total_loss = 0.0
    label_chunks = []
    pred_chunks = []
    prob_chunks = []
    all_images = []

    for images, labels, image_names in tqdm(loader, desc="Test"):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        with torch.autocast(device_type=device.type, enabled=use_amp):
            logits = model(images)
            loss = criterion(logits, labels)

        # Softmax in float32: under AMP the logits may be half precision, and the
        # saved probabilities would otherwise be rounded to fp16.
        probs = torch.softmax(logits.float(), dim=1)
        preds = probs.argmax(dim=1)

        # criterion is a batch mean; re-weight by batch size for a per-sample mean.
        total_loss += loss.item() * images.size(0)
        label_chunks.append(labels.cpu().numpy())
        pred_chunks.append(preds.cpu().numpy())
        prob_chunks.append(probs.cpu().numpy())
        all_images.extend(image_names)

    test_loss = total_loss / len(loader.dataset)
    y_true = np.concatenate(label_chunks)
    y_pred = np.concatenate(pred_chunks)
    probs = np.concatenate(prob_chunks)

    acc = accuracy_score(y_true, y_pred)
    bal_acc = balanced_accuracy_score(y_true, y_pred)

    return test_loss, acc, bal_acc, y_true, y_pred, probs, all_images


def print_label_counts(name: str, df: pd.DataFrame):
    """Print the total row count and the per-class counts of ``df["quality"]``."""
    print(f"{name}: {len(df)}")
    for label_id in CLASS_IDS:
        count = int((df["quality"] == label_id).sum())
        print(f"  {ID_TO_LABEL[label_id]} ({label_id}): {count}")


def parse_args():
    """Parse command-line options for the evaluation script."""
    parser = argparse.ArgumentParser()

    parser.add_argument("--images_dir", type=str, required=True,
                        help="EyePACS root containing train/ and test/ folders.")
    parser.add_argument("--csv_dir", type=str, required=True,
                        help="Directory containing Label_EyeQ_test.csv.")
    parser.add_argument("--checkpoint", type=str, default="./checkpoints/eyeq_vit_base/best.pt")
    parser.add_argument("--output_dir", type=str, default=None,
                        help="Defaults to <checkpoint dir>/test_eval.")

    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--num_workers", type=int, default=8)

    # AMP is on by default; --amp is kept for compatibility, --no_amp disables it.
    parser.add_argument("--amp", action="store_true", default=True)
    parser.add_argument("--no_amp", dest="amp", action="store_false")
    parser.add_argument("--cpu", action="store_true")

    return parser.parse_args()


def main():
    """Evaluate a checkpoint on the EyeQ test split and write reports to disk."""
    args = parse_args()

    images_root = Path(args.images_dir)
    csv_root = Path(args.csv_dir)
    checkpoint_path = Path(args.checkpoint)

    test_images_dir = images_root / "test"
    test_csv = csv_root / "Label_EyeQ_test.csv"

    # Validate inputs before touching the filesystem so that a typo does not
    # leave behind an empty output directory.
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    if not test_images_dir.exists():
        raise FileNotFoundError(f"Test image directory not found: {test_images_dir}")
    if not test_csv.exists():
        raise FileNotFoundError(f"Test CSV not found: {test_csv}")

    if args.output_dir is None:
        output_dir = checkpoint_path.parent / "test_eval"
    else:
        output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")

    model, id_to_label, model_name, img_size, ckpt = load_model(str(checkpoint_path), device)
    transform = build_transform(img_size)

    test_df = load_eyeq_csv(str(test_csv), str(test_images_dir))
    test_ds = EyeQDataset(test_df, str(test_images_dir), transform)

    # shuffle=False keeps predictions aligned with test_df rows (relied on below).
    test_loader = DataLoader(
        test_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=(device.type == "cuda"),
        persistent_workers=(args.num_workers > 0),
    )

    criterion = nn.CrossEntropyLoss()

    print("Evaluation summary")
    print(f"Checkpoint:  {checkpoint_path}")
    print(f"Test CSV:    {test_csv}")
    print(f"Test images: {test_images_dir}")
    print(f"Output dir:  {output_dir}")
    print(f"Model:       {model_name}")
    print(f"Image size:  {img_size}")
    print(f"Device:      {device}")
    print(f"Labels:      {id_to_label}")
    print_label_counts("Test", test_df)

    test_loss, acc, bal_acc, y_true, y_pred, probs, image_names = evaluate(
        model=model,
        loader=test_loader,
        criterion=criterion,
        device=device,
        amp=args.amp,
    )

    target_names = [id_to_label[i] for i in CLASS_IDS]
    report = classification_report(
        y_true,
        y_pred,
        labels=CLASS_IDS,
        target_names=target_names,
        digits=4,
    )
    cm = confusion_matrix(y_true, y_pred, labels=CLASS_IDS)
    # Label header follows the checkpoint's names (identical for the defaults).
    cm_header = f"Confusion matrix rows=true cols=pred, labels=[{', '.join(target_names)}]"

    print()
    print(f"test_loss={test_loss:.4f}")
    print(f"test_acc={acc:.4f}")
    print(f"test_bal_acc={bal_acc:.4f}")
    print()
    print(report)
    print(cm_header)
    print(cm)

    # Save text report
    report_path = output_dir / "test_report.txt"
    with open(report_path, "w", encoding="utf-8") as f:
        f.write(f"Checkpoint: {checkpoint_path}\n")
        f.write(f"Test CSV: {test_csv}\n")
        f.write(f"Test images: {test_images_dir}\n")
        f.write(f"Model: {model_name}\n")
        f.write(f"Image size: {img_size}\n")
        f.write(f"Device: {device}\n\n")
        f.write(f"test_loss={test_loss:.6f}\n")
        f.write(f"test_acc={acc:.6f}\n")
        f.write(f"test_bal_acc={bal_acc:.6f}\n\n")
        f.write(report)
        f.write(f"\n{cm_header}\n")
        f.write(str(cm))
        f.write("\n")

    # Save confusion matrix CSV
    cm_path = output_dir / "test_confusion_matrix.csv"
    cm_df = pd.DataFrame(
        cm,
        index=[f"true_{name}" for name in target_names],
        columns=[f"pred_{name}" for name in target_names],
    )
    cm_df.to_csv(cm_path)

    # Save per-image predictions (row order matches test_df because shuffle=False).
    pred_path = output_dir / "test_predictions.csv"
    pred_df = test_df.copy()
    pred_df["pred_quality"] = y_pred
    pred_df["true_label"] = [id_to_label[int(x)] for x in y_true]
    pred_df["pred_label"] = [id_to_label[int(x)] for x in y_pred]
    pred_df["prob_good"] = probs[:, 0]
    pred_df["prob_usable"] = probs[:, 1]
    pred_df["prob_reject"] = probs[:, 2]
    pred_df["correct"] = pred_df["quality"].values == pred_df["pred_quality"].values
    pred_df.to_csv(pred_path, index=False)

    print()
    print(f"Saved report:      {report_path}")
    print(f"Saved confusion:   {cm_path}")
    print(f"Saved predictions: {pred_path}")


if __name__ == "__main__":
    main()