# EyeQ / train.py
# farrell236: add src (commit d0344ce)
#!/usr/bin/env python3
"""
Train a CFP image-quality-control model on EyeQ / EyePACS-style data.
Expected dataset format
-----------------------
EyePACS/
train/
10009_left.jpeg
10009_right.jpeg
...
test/
...
data/
Label_EyeQ_train.csv
Label_EyeQ_test.csv
Label CSV format:
,image,quality,DR_grade
0,10009_left.jpeg,0,0
1,10009_right.jpeg,0,0
2,10014_left.jpeg,2,0
For EyeQ, this script assumes:
quality = 0 -> Good
quality = 1 -> Usable
quality = 2 -> Reject
DR_grade is ignored because this script trains only the image-quality model.
Example
-------
python train.py \
--images_dir /path/to/EyePACS \
--csv_dir /path/to/data \
--output_dir ./runs/eyeq_vit_base \
--epochs 30 \
--batch_size 32 \
--lr 3e-5
"""
import argparse
import random
from pathlib import Path
from typing import Dict, Tuple
import numpy as np
import pandas as pd
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import timm
from sklearn.metrics import accuracy_score, balanced_accuracy_score, classification_report, confusion_matrix
from tqdm import tqdm
# Class-id -> human-readable quality label for the EyeQ 3-class task.
ID_TO_LABEL = {0: "Good", 1: "Usable", 2: "Reject"}
# Reverse mapping used when normalizing CSV labels; accepts both textual
# labels (lower-cased) and stringified digits.
LABEL_TO_ID: Dict[str, int] = {
    "good": 0,
    "usable": 1,
    "reject": 2,
    "0": 0,
    "1": 1,
    "2": 2,
}
class EyeQDataset(Dataset):
    """Fundus photographs paired with their image-quality class ids.

    Each row of ``df`` must provide an 'image' filename (relative to
    ``images_dir``) and an integer-convertible 'quality' label (0/1/2).
    """

    def __init__(self, df: pd.DataFrame, images_dir: str, transform=None):
        self.df = df.reset_index(drop=True)
        self.images_dir = Path(images_dir)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        pil_image = Image.open(self.images_dir / str(record["image"])).convert("RGB")
        quality = int(record["quality"])
        if self.transform is not None:
            pil_image = self.transform(pil_image)
        return pil_image, quality
def seed_everything(seed: int):
    """Seed every RNG used by the pipeline (python, numpy, torch CPU/CUDA).

    NOTE: cudnn.benchmark is enabled for speed, so cuDNN algorithm selection
    may still introduce run-to-run nondeterminism on GPU.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    torch.backends.cudnn.benchmark = True
def normalize_quality_label(x) -> int:
    """Map a raw quality value (int, float, or string) to class id 0/1/2.

    Accepts 'Good'/'Usable'/'Reject' (case-insensitive, whitespace-tolerant)
    as well as any numeric representation of 0, 1, or 2.

    Raises:
        ValueError: if the value cannot be mapped to a known class.
    """
    key = str(x).strip().lower()
    mapped = LABEL_TO_ID.get(key)
    if mapped is not None:
        return mapped
    try:
        value = int(float(key))
    except ValueError:
        pass
    else:
        if value in (0, 1, 2):
            return value
    raise ValueError(f"Unknown quality label: {x}. Expected 0/1/2 or Good/Usable/Reject.")
def load_eyeq_csv(csv_path: str, images_dir: str) -> pd.DataFrame:
    """Load an EyeQ label CSV and drop rows whose image file is absent.

    Returns a dataframe with exactly two columns: 'image' (str filename,
    relative to ``images_dir``) and 'quality' (int class id 0/1/2).

    Raises:
        ValueError: when a required column is missing.
        RuntimeError: when no row has an existing image file.
    """
    df = pd.read_csv(csv_path)
    # Articles differ deliberately to keep the original error text intact.
    for article, required in (("an", "image"), ("a", "quality")):
        if required not in df.columns:
            raise ValueError(f"CSV must contain {article} '{required}' column. Found columns: {list(df.columns)}")

    df = df[["image", "quality"]].copy()
    df["image"] = df["image"].astype(str)
    df["quality"] = df["quality"].apply(normalize_quality_label)

    root = Path(images_dir)
    present = df["image"].apply(lambda name: (root / name).exists())
    missing = int((~present).sum())
    if missing > 0:
        print(f"Warning: dropping {missing} rows with missing image files from {csv_path}")
        print(f" searched in: {root}")
    df = df.loc[present].reset_index(drop=True)
    if len(df) == 0:
        raise RuntimeError(f"No valid images found for {csv_path}. Searched in: {root}")
    return df
def build_transforms(img_size: int) -> Tuple[transforms.Compose, transforms.Compose]:
    """Return (train, eval) torchvision pipelines for CFP images.

    Training adds horizontal flips, mild color jitter, and occasional
    Gaussian blur; both pipelines resize to a square and normalize with
    ImageNet statistics.
    """
    resize = transforms.Resize((img_size, img_size))
    imagenet_norm = transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

    augmented_steps = [
        resize,
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomApply(
            [transforms.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.10, hue=0.02)],
            p=0.8,
        ),
        transforms.RandomApply(
            [transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0))],
            p=0.15,
        ),
        transforms.ToTensor(),
        imagenet_norm,
    ]
    eval_steps = [resize, transforms.ToTensor(), imagenet_norm]
    return transforms.Compose(augmented_steps), transforms.Compose(eval_steps)
def build_model(model_name: str, num_classes: int, pretrained: bool):
    """Instantiate a timm classifier with a ``num_classes``-way head."""
    model = timm.create_model(model_name, pretrained=pretrained, num_classes=num_classes)
    return model
def train_one_epoch(model, loader, criterion, optimizer, scaler, device, epoch):
    """Run one training epoch.

    Args:
        model: classifier producing (batch, num_classes) logits.
        loader: training DataLoader (may use drop_last=True).
        criterion: loss taking (logits, labels).
        optimizer: optimizer stepping model parameters.
        scaler: torch.cuda.amp.GradScaler or None for full precision.
        device: target device for batches.
        epoch: epoch number, used only for the progress-bar label.

    Returns:
        (mean_loss, accuracy, balanced_accuracy) over the samples processed.
    """
    model.train()
    running_loss = 0.0
    all_preds = []
    all_labels = []
    pbar = tqdm(loader, desc=f"Train {epoch}", leave=False)
    for images, labels in pbar:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        # Mixed precision only when a GradScaler was supplied (CUDA + AMP).
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            logits = model(images)
            loss = criterion(logits, labels)
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        running_loss += loss.item() * images.size(0)
        preds = logits.argmax(dim=1)
        all_preds.extend(preds.detach().cpu().numpy().tolist())
        all_labels.extend(labels.detach().cpu().numpy().tolist())
        pbar.set_postfix(loss=f"{loss.item():.4f}")
    # BUG FIX: average over samples actually processed. The train loader uses
    # drop_last=True, so the skipped tail batch is absent from running_loss;
    # dividing by len(loader.dataset) would bias the reported mean loss low.
    n_seen = len(all_labels)
    epoch_loss = running_loss / max(n_seen, 1)
    acc = accuracy_score(all_labels, all_preds)
    bal_acc = balanced_accuracy_score(all_labels, all_preds)
    return epoch_loss, acc, bal_acc
@torch.no_grad()
def evaluate(model, loader, criterion, device, split_name="Test"):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Returns:
        (mean_loss, accuracy, balanced_accuracy, y_true, y_pred) where the
        last two are numpy arrays of integer class ids.
    """
    model.eval()
    total_loss = 0.0
    preds_accum = []
    labels_accum = []
    for images, labels in tqdm(loader, desc=split_name, leave=False):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        logits = model(images)
        batch_loss = criterion(logits, labels)
        total_loss += batch_loss.item() * images.size(0)
        preds_accum.extend(logits.argmax(dim=1).detach().cpu().numpy().tolist())
        labels_accum.extend(labels.detach().cpu().numpy().tolist())
    val_loss = total_loss / len(loader.dataset)
    acc = accuracy_score(labels_accum, preds_accum)
    bal_acc = balanced_accuracy_score(labels_accum, preds_accum)
    return val_loss, acc, bal_acc, np.array(labels_accum), np.array(preds_accum)
def save_checkpoint(path, model, optimizer, scheduler, epoch, best_metric, args):
    """Serialize model/optimizer/scheduler state plus run metadata to ``path``."""
    state = {
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        # Scheduler is optional; store None when absent so loading stays simple.
        "scheduler": None if scheduler is None else scheduler.state_dict(),
        "best_metric": best_metric,
        "args": vars(args),
        "id_to_label": ID_TO_LABEL,
    }
    torch.save(state, path)
def parse_args():
    """Parse command-line options for the EyeQ quality-model training run."""
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    # Data locations.
    add("--images_dir", type=str, required=True, help="EyePACS root containing train/ and test/ folders.")
    add("--csv_dir", type=str, required=True, help="Directory containing Label_EyeQ_train.csv and Label_EyeQ_test.csv.")
    add("--output_dir", type=str, default="./runs/eyeq_vit_base")
    # Model / input resolution; --no_pretrained flips the default-on flag.
    add("--model", type=str, default="vit_base_patch16_224")
    add("--img_size", type=int, default=224)
    add("--pretrained", action="store_true", default=True)
    add("--no_pretrained", dest="pretrained", action="store_false")
    # Optimization schedule.
    add("--epochs", type=int, default=30)
    add("--batch_size", type=int, default=32)
    add("--num_workers", type=int, default=8)
    add("--lr", type=float, default=3e-5)
    add("--weight_decay", type=float, default=1e-4)
    add("--seed", type=int, default=42)
    # Mixed precision is on by default; --no_amp disables it.
    add("--amp", action="store_true", default=True)
    add("--no_amp", dest="amp", action="store_false")
    add("--class_weights", action="store_true", help="Use inverse-frequency class weights.")
    return parser.parse_args()
def print_label_counts(name: str, df: pd.DataFrame):
    """Print a split's total row count followed by its per-class counts."""
    print(f"{name}: {len(df)}")
    for label_id in (0, 1, 2):
        count = int((df["quality"] == label_id).sum())
        print(f" {ID_TO_LABEL[label_id]} ({label_id}): {count}")
def main():
    """Train the EyeQ 3-class image-quality model end to end.

    Loads train/test CSVs, builds datasets/loaders, trains for args.epochs
    epochs, and checkpoints: `last.pt` every epoch, `best.pt` plus a text
    classification report whenever test balanced accuracy improves.
    """
    args = parse_args()
    seed_everything(args.seed)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    # Expected layout: <images_dir>/{train,test}/ image files and
    # <csv_dir>/Label_EyeQ_{train,test}.csv label files.
    images_root = Path(args.images_dir)
    csv_root = Path(args.csv_dir)
    train_images_dir = images_root / "train"
    test_images_dir = images_root / "test"
    train_csv = csv_root / "Label_EyeQ_train.csv"
    test_csv = csv_root / "Label_EyeQ_test.csv"
    # Rows whose image file is missing on disk are dropped inside the loader.
    train_df = load_eyeq_csv(str(train_csv), str(train_images_dir))
    test_df = load_eyeq_csv(str(test_csv), str(test_images_dir))
    train_tfms, test_tfms = build_transforms(args.img_size)
    train_ds = EyeQDataset(train_df, str(train_images_dir), train_tfms)
    test_ds = EyeQDataset(test_df, str(test_images_dir), test_tfms)
    # drop_last=True keeps training batch sizes uniform.
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
    )
    test_loader = DataLoader(
        test_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = build_model(args.model, num_classes=3, pretrained=args.pretrained).to(device)
    if args.class_weights:
        # Inverse-frequency weights; reindex with fill_value=1 guards against a
        # class absent from the training CSV (avoids a divide-by-zero below).
        counts = train_df["quality"].value_counts().sort_index().reindex([0, 1, 2], fill_value=1).values
        weights = counts.sum() / (len(counts) * counts)
        weights = torch.tensor(weights, dtype=torch.float32, device=device)
        criterion = nn.CrossEntropyLoss(weight=weights)
        print(f"Using class weights: {weights.detach().cpu().numpy().round(3).tolist()}")
    else:
        criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # Cosine decay over the full run; stepped once per epoch in the loop below.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    # AMP only when requested AND on CUDA; a None scaler selects the
    # full-precision path in train_one_epoch.
    scaler = torch.cuda.amp.GradScaler() if args.amp and device.type == "cuda" else None
    print("Dataset summary")
    print(f"Train CSV: {train_csv}")
    print(f"Test CSV: {test_csv}")
    print(f"Train images: {train_images_dir}")
    print(f"Test images: {test_images_dir}")
    print_label_counts("Train", train_df)
    print_label_counts("Test", test_df)
    print(f"Model: {args.model}")
    print(f"Device: {device}")
    # Model selection metric: balanced accuracy on the test split.
    best_bal_acc = -1.0
    for epoch in range(1, args.epochs + 1):
        train_loss, train_acc, train_bal_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, scaler, device, epoch
        )
        test_loss, test_acc, test_bal_acc, y_true, y_pred = evaluate(
            model, test_loader, criterion, device, split_name="Test"
        )
        scheduler.step()
        print(
            f"Epoch {epoch:03d}/{args.epochs} | "
            f"train_loss={train_loss:.4f} train_acc={train_acc:.4f} train_bal_acc={train_bal_acc:.4f} | "
            f"test_loss={test_loss:.4f} test_acc={test_acc:.4f} test_bal_acc={test_bal_acc:.4f}"
        )
        # NOTE(review): last.pt is written before best_bal_acc is updated, so
        # its stored best_metric lags one improvement behind — confirm intended.
        save_checkpoint(output_dir / "last.pt", model, optimizer, scheduler, epoch, best_bal_acc, args)
        if test_bal_acc > best_bal_acc:
            best_bal_acc = test_bal_acc
            best_path = output_dir / "best.pt"
            save_checkpoint(best_path, model, optimizer, scheduler, epoch, best_bal_acc, args)
            # Persist a per-class report and confusion matrix for the new best.
            report = classification_report(
                y_true,
                y_pred,
                labels=[0, 1, 2],
                target_names=[ID_TO_LABEL[i] for i in [0, 1, 2]],
                digits=4,
            )
            cm = confusion_matrix(y_true, y_pred, labels=[0, 1, 2])
            with open(output_dir / "best_report.txt", "w") as f:
                f.write(f"Best epoch: {epoch}\n")
                f.write(f"Best test balanced accuracy: {best_bal_acc:.4f}\n\n")
                f.write(report)
                f.write("\nConfusion matrix rows=true cols=pred, labels=[Good, Usable, Reject]\n")
                f.write(str(cm))
                f.write("\n")
            print(f" Saved new best checkpoint: {best_path}")
    print(f"Training complete. Best test balanced accuracy: {best_bal_acc:.4f}")
    print(f"Outputs saved to: {output_dir}")


if __name__ == "__main__":
    main()