#!/usr/bin/env python3
"""
Evaluate an EyeQ CFP image-quality-control model on Label_EyeQ_test.csv.

Example
-------
python EyeQ_test.py \
    --images_dir /data/MIDS/datasets/retina/EyePACS \
    --csv_dir /data/MIDS/datasets/retina/EyeQ/data \
    --checkpoint ./checkpoints/eyeq_vit_base/best.pt \
    --output_dir ./checkpoints/eyeq_vit_base/test_eval \
    --batch_size 32 \
    --num_workers 24
"""
import argparse
from pathlib import Path
from typing import Dict

import numpy as np
import pandas as pd
import timm
import torch
import torch.nn as nn
from PIL import Image
from sklearn.metrics import (
    accuracy_score,
    balanced_accuracy_score,
    classification_report,
    confusion_matrix,
)
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm import tqdm
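

# EyeQ grades color fundus photographs on a three-way quality scale.
# LABEL_TO_ID accepts both string names and numeric strings so CSVs with
# either spelling normalize to the same integer ids.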
ID_TO_LABEL = {0: "Good", 1: "Usable", 2: "Reject"}
LABEL_TO_ID: Dict[str, int] = {
    "good": 0,
    "usable": 1,
    "reject": 2,
    "0": 0,
    "1": 1,
    "2": 2,
}
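

# Yields (image_tensor, label, image_name); the filename is carried through
# so per-image predictions can be joined back to the CSV when saving results.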
class EyeQDataset(Dataset):
    def __init__(self, df: pd.DataFrame, images_dir: str, transform=None):
        self.df = df.reset_index(drop=True)
        self.images_dir = Path(images_dir)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image_name = str(row["image"])
        image_path = self.images_dir / image_name
        image = Image.open(image_path).convert("RGB")
        label = int(row["quality"])
        if self.transform is not None:
            image = self.transform(image)
        return image, label, image_name
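

# Maps a raw CSV label (0/1/2, "Good"/"Usable"/"Reject", or a numeric string
# such as "1.0") to an integer class id, raising on anything unrecognized.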
def normalize_quality_label(x) -> int:
    key = str(x).strip().lower()
    if key in LABEL_TO_ID:
        return LABEL_TO_ID[key]
    try:
        value = int(float(key))
        if value in [0, 1, 2]:
            return value
    except ValueError:
        pass
    raise ValueError(f"Unknown quality label: {x}. Expected 0/1/2 or Good/Usable/Reject.")
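

# Loads and validates the EyeQ label CSV: requires 'image' and 'quality'
# columns, normalizes labels, and drops rows whose image file is missing on
# disk rather than failing mid-pass inside the DataLoader.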
def load_eyeq_csv(csv_path: str, images_dir: str) -> pd.DataFrame:
    df = pd.read_csv(csv_path)
    if "image" not in df.columns:
        raise ValueError(f"CSV must contain an 'image' column. Found columns: {list(df.columns)}")
    if "quality" not in df.columns:
        raise ValueError(f"CSV must contain a 'quality' column. Found columns: {list(df.columns)}")
    # Keep DR_grade if present for optional downstream inspection.
    keep_cols = ["image", "quality"]
    if "DR_grade" in df.columns:
        keep_cols.append("DR_grade")
    df = df[keep_cols].copy()
    df["image"] = df["image"].astype(str)
    df["quality"] = df["quality"].apply(normalize_quality_label)
    images_dir = Path(images_dir)
    exists = df["image"].apply(lambda x: (images_dir / x).exists())
    missing = int((~exists).sum())
    if missing > 0:
        print(f"Warning: dropping {missing} rows with missing image files from {csv_path}")
        print(f"  searched in: {images_dir}")
    df = df.loc[exists].reset_index(drop=True)
    if len(df) == 0:
        raise RuntimeError(f"No valid images found for {csv_path}. Searched in: {images_dir}")
    return df
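

# Deterministic eval-time preprocessing. The ImageNet mean/std below assume
# the checkpoint was trained from timm's ImageNet-pretrained weights with the
# same normalization; adjust if your training pipeline differed.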
def build_transform(img_size: int):
    return transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ])
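

# Assumed checkpoint layout (as written by the companion training script):
#   {"model": state_dict, "args": {"model": ..., "img_size": ...},
#    "id_to_label": {0: "Good", 1: "Usable", 2: "Reject"}}
# "args" and "id_to_label" are optional and fall back to the defaults above.
# Note: PyTorch >= 2.6 defaults torch.load to weights_only=True, which can
# reject checkpoints carrying plain-Python metadata; if loading fails on a
# trusted file, pass weights_only=False.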
def load_model(checkpoint_path: str, device: torch.device):
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    ckpt_args = ckpt.get("args", {})
    model_name = ckpt_args.get("model", "vit_base_patch16_224")
    img_size = int(ckpt_args.get("img_size", 224))
    id_to_label = ckpt.get("id_to_label", ID_TO_LABEL)
    id_to_label = {int(k): str(v) for k, v in id_to_label.items()}
    model = timm.create_model(
        model_name,
        pretrained=False,
        num_classes=len(id_to_label),
    )
    model.load_state_dict(ckpt["model"], strict=True)
    model.to(device)
    model.eval()
    return model, id_to_label, model_name, img_size, ckpt
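

# Single evaluation pass: gradients disabled via torch.no_grad(), optional
# CUDA mixed precision for the forward pass, and per-sample loss accumulated
# so the reported test_loss is an exact dataset mean.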
@torch.no_grad()
def evaluate(model, loader, criterion, device, amp=False):
    model.eval()
    running_loss = 0.0
    all_labels = []
    all_preds = []
    all_probs = []
    all_images = []
    for images, labels, image_names in tqdm(loader, desc="Test"):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        with torch.autocast(device_type=device.type, enabled=amp and device.type == "cuda"):
            logits = model(images)
            loss = criterion(logits, labels)
        probs = torch.softmax(logits, dim=1)
        preds = probs.argmax(dim=1)
        running_loss += loss.item() * images.size(0)
        all_labels.extend(labels.detach().cpu().numpy().tolist())
        all_preds.extend(preds.detach().cpu().numpy().tolist())
        all_probs.extend(probs.detach().cpu().numpy().tolist())
        all_images.extend(list(image_names))
    test_loss = running_loss / len(loader.dataset)
    y_true = np.array(all_labels)
    y_pred = np.array(all_preds)
    probs = np.array(all_probs)
    acc = accuracy_score(y_true, y_pred)
    bal_acc = balanced_accuracy_score(y_true, y_pred)
    return test_loss, acc, bal_acc, y_true, y_pred, probs, all_images


def print_label_counts(name: str, df: pd.DataFrame):
    print(f"{name}: {len(df)}")
    for label_id in [0, 1, 2]:
        count = int((df["quality"] == label_id).sum())
        print(f"  {ID_TO_LABEL[label_id]} ({label_id}): {count}")
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--images_dir", type=str, required=True,
                        help="EyePACS root containing train/ and test/ folders.")
    parser.add_argument("--csv_dir", type=str, required=True,
                        help="Directory containing Label_EyeQ_test.csv.")
    parser.add_argument("--checkpoint", type=str, default="./checkpoints/eyeq_vit_base/best.pt")
    parser.add_argument("--output_dir", type=str, default=None)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("--amp", action="store_true", default=True)
    parser.add_argument("--no_amp", dest="amp", action="store_false")
    parser.add_argument("--cpu", action="store_true")
    return parser.parse_args()
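

# Expected layout: <images_dir>/test/ holds the EyePACS test images and
# <csv_dir>/Label_EyeQ_test.csv maps each filename to a quality grade.
# The report, confusion matrix, and per-image predictions are written under
# --output_dir (default: <checkpoint dir>/test_eval).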
def main():
    args = parse_args()
    images_root = Path(args.images_dir)
    csv_root = Path(args.csv_dir)
    checkpoint_path = Path(args.checkpoint)
    test_images_dir = images_root / "test"
    test_csv = csv_root / "Label_EyeQ_test.csv"
    if args.output_dir is None:
        output_dir = checkpoint_path.parent / "test_eval"
    else:
        output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    if not test_images_dir.exists():
        raise FileNotFoundError(f"Test image directory not found: {test_images_dir}")
    if not test_csv.exists():
        raise FileNotFoundError(f"Test CSV not found: {test_csv}")
    device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")
    model, id_to_label, model_name, img_size, ckpt = load_model(str(checkpoint_path), device)
    transform = build_transform(img_size)
    test_df = load_eyeq_csv(str(test_csv), str(test_images_dir))
    test_ds = EyeQDataset(test_df, str(test_images_dir), transform)
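    # shuffle=False keeps loader order aligned with test_df rows so the
    # per-image predictions saved below match up. pin_memory speeds up
    # host-to-GPU copies; persistent_workers avoids worker respawn overhead.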
    test_loader = DataLoader(
        test_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=(device.type == "cuda"),
        persistent_workers=(args.num_workers > 0),
    )
    criterion = nn.CrossEntropyLoss()
| print("Evaluation summary") | |
| print(f"Checkpoint: {checkpoint_path}") | |
| print(f"Test CSV: {test_csv}") | |
| print(f"Test images: {test_images_dir}") | |
| print(f"Output dir: {output_dir}") | |
| print(f"Model: {model_name}") | |
| print(f"Image size: {img_size}") | |
| print(f"Device: {device}") | |
| print(f"Labels: {id_to_label}") | |
| print_label_counts("Test", test_df) | |
    test_loss, acc, bal_acc, y_true, y_pred, probs, image_names = evaluate(
        model=model,
        loader=test_loader,
        criterion=criterion,
        device=device,
        amp=args.amp,
    )
    target_names = [id_to_label[i] for i in [0, 1, 2]]
    report = classification_report(
        y_true,
        y_pred,
        labels=[0, 1, 2],
        target_names=target_names,
        digits=4,
    )
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1, 2])
    print()
    print(f"test_loss={test_loss:.4f}")
    print(f"test_acc={acc:.4f}")
    print(f"test_bal_acc={bal_acc:.4f}")
    print()
    print(report)
    print("Confusion matrix rows=true cols=pred, labels=[Good, Usable, Reject]")
    print(cm)
    # Save text report
    with open(output_dir / "test_report.txt", "w") as f:
        f.write(f"Checkpoint: {checkpoint_path}\n")
        f.write(f"Test CSV: {test_csv}\n")
        f.write(f"Test images: {test_images_dir}\n")
        f.write(f"Model: {model_name}\n")
        f.write(f"Image size: {img_size}\n")
        f.write(f"Device: {device}\n\n")
        f.write(f"test_loss={test_loss:.6f}\n")
        f.write(f"test_acc={acc:.6f}\n")
        f.write(f"test_bal_acc={bal_acc:.6f}\n\n")
        f.write(report)
        f.write("\nConfusion matrix rows=true cols=pred, labels=[Good, Usable, Reject]\n")
        f.write(str(cm))
        f.write("\n")
    # Save confusion matrix CSV
    cm_df = pd.DataFrame(
        cm,
        index=[f"true_{name}" for name in target_names],
        columns=[f"pred_{name}" for name in target_names],
    )
    cm_df.to_csv(output_dir / "test_confusion_matrix.csv")
    # Save per-image predictions
    pred_df = test_df.copy()
    pred_df["pred_quality"] = y_pred
    pred_df["true_label"] = [id_to_label[int(x)] for x in y_true]
    pred_df["pred_label"] = [id_to_label[int(x)] for x in y_pred]
    pred_df["prob_good"] = probs[:, 0]
    pred_df["prob_usable"] = probs[:, 1]
    pred_df["prob_reject"] = probs[:, 2]
    pred_df["correct"] = pred_df["quality"].values == pred_df["pred_quality"].values
    pred_df.to_csv(output_dir / "test_predictions.csv", index=False)
    print()
    print(f"Saved report: {output_dir / 'test_report.txt'}")
    print(f"Saved confusion: {output_dir / 'test_confusion_matrix.csv'}")
    print(f"Saved predictions: {output_dir / 'test_predictions.csv'}")


if __name__ == "__main__":
    main()