"""
ELIAS — Eyelid Lesion Intelligent Analysis System
train.py

Stratified k-fold cross-validation training pipeline (N_FOLDS folds, 5 by default).
Extracted and refactored from gemini_crossval_masked.ipynb.

Usage:
    python train.py --data_dir ./data/data --output_dir ./outputs

Data directory structure:
    data/data/
    ├── epiblepharon/   (positive class)
    └── control/        (negative class)

Class indices come from torchvision's ImageFolder, which sorts the folder
names, so "control" -> 0 and "epiblepharon" -> 1. Every metric below treats
index 1 as the positive class; renaming the folders can silently flip it.
"""
import argparse
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import auc, confusion_matrix, f1_score, roc_curve
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, models, transforms

from model import build_elias_model

# ── Hyperparameters ────────────────────────────────────────────────────────────
BATCH_SIZE = 32
EPOCHS = 20
LR = 1e-3
N_FOLDS = 5
RANDOM_STATE = 42
IMAGE_SIZE = 224
NUM_WORKERS = 2

# ImageNet channel statistics used by Normalize in both pipelines.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


# ── Dataset Utilities ──────────────────────────────────────────────────────────
class ApplyTransform(torch.utils.data.Dataset):
    """Wrap a dataset and apply ``transform`` to each sample's input on access.

    Lets the train and validation ``Subset``s of a single ``ImageFolder`` use
    different transform pipelines (augmentation on the train side only).
    Labels are passed through unchanged.
    """

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __getitem__(self, index):
        x, y = self.subset[index]
        if self.transform is not None:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.subset)


def get_transforms():
    """Return ``(train_tf, val_tf)`` torchvision transform pipelines.

    Both pipelines resize to IMAGE_SIZE x IMAGE_SIZE, convert to grayscale
    replicated across 3 channels, and normalize with ImageNet statistics.
    Grayscale is used to reduce colour/illumination variation across clinical
    photographs while keeping 3-channel input compatibility with the
    ImageNet-pretrained ResNet-18. The train pipeline additionally applies
    brightness/contrast jitter and random horizontal flips.
    """
    train_tf = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.Grayscale(num_output_channels=3),
        # No saturation jitter: after Grayscale all channels are equal, so a
        # saturation change is an exact no-op and only wastes random draws.
        transforms.ColorJitter(brightness=0.5, contrast=0.5),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ])
    val_tf = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ])
    return train_tf, val_tf


# ── Training & Evaluation ──────────────────────────────────────────────────────
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Returns:
        The mean per-sample training loss over the epoch (batch losses are
        weighted by batch size, then divided by the dataset length).
    """
    # NOTE(review): model.train() also puts any BatchNorm layers in the frozen
    # backbone into training mode, so their running statistics keep updating
    # even though their weights are frozen — confirm this is intended.
    model.train()
    running_loss = 0.0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
    return running_loss / len(loader.dataset)


@torch.no_grad()
def evaluate(model, loader, device):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Returns:
        ``(accuracy, y_true, y_probs, y_pred)`` where ``y_probs`` is the softmax
        probability of class index 1 (the positive class) and ``y_pred`` is the
        argmax prediction, all as 1-D numpy arrays in loader order.
    """
    model.eval()
    y_true, y_probs, y_pred = [], [], []
    correct = 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        probs = torch.softmax(outputs, dim=1)[:, 1]
        preds = torch.argmax(outputs, dim=1)
        correct += (preds == labels).sum().item()
        y_true.extend(labels.cpu().numpy())
        y_probs.extend(probs.cpu().numpy())
        y_pred.extend(preds.cpu().numpy())
    acc = correct / len(loader.dataset)
    return acc, np.array(y_true), np.array(y_probs), np.array(y_pred)


def compute_fold_metrics(y_true, y_probs, y_pred, class_names):
    """Compute sensitivity, specificity, F1 and ROC AUC from one fold's predictions.

    Label 1 is the positive class. ``class_names`` is accepted for interface
    compatibility and is not used.

    Returns:
        dict with keys "sensitivity", "specificity", "f1", "auc", "fpr", "tpr"
        and "cm" (a 2x2 confusion matrix, rows = actual, cols = predicted).
    """
    # Pin the label set so the matrix is always 2x2; otherwise a fold in which
    # only one class appears yields a 1x1 matrix and the unpack below fails.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
    f1 = f1_score(y_true, y_pred)
    fpr, tpr, _ = roc_curve(y_true, y_probs)
    fold_auc = auc(fpr, tpr)
    return {
        "sensitivity": sensitivity,
        "specificity": specificity,
        "f1": f1,
        "auc": fold_auc,
        "fpr": fpr,
        "tpr": tpr,
        "cm": cm,
    }


# ── Plotting ───────────────────────────────────────────────────────────────────
def save_confusion_matrix(cm, class_names, fold_idx, output_dir):
    """Save an annotated confusion-matrix heatmap for fold ``fold_idx`` (0-based)."""
    plt.figure(figsize=(6, 5))
    sns.heatmap(
        cm, annot=True, fmt="d", cmap="Blues",
        xticklabels=class_names, yticklabels=class_names,
    )
    plt.title(f"Confusion Matrix — Fold {fold_idx + 1}")
    plt.ylabel("Actual")
    plt.xlabel("Predicted")
    path = os.path.join(output_dir, f"confusion_matrix_fold_{fold_idx + 1}.png")
    plt.savefig(path, dpi=120, bbox_inches="tight")
    plt.close()


def save_roc_curves(roc_data, output_dir):
    """Overlay per-fold ROC curves; ``roc_data`` holds ``(fpr, tpr, auc)`` tuples."""
    plt.figure(figsize=(8, 6))
    for fold_idx, (fpr, tpr, fold_auc) in enumerate(roc_data):
        plt.plot(fpr, tpr, label=f"Fold {fold_idx + 1} (AUC = {fold_auc:.3f})")
    plt.plot([0, 1], [0, 1], "k--", linewidth=1)
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    # Derive the fold count from the data instead of hard-coding "5".
    plt.title(f"ROC Curves — {len(roc_data)}-Fold Cross-Validation")
    plt.legend(loc="lower right")
    path = os.path.join(output_dir, "roc_curves.png")
    plt.savefig(path, dpi=120, bbox_inches="tight")
    plt.close()
    print(f"[ELIAS] ROC curve saved → {path}")


def save_learning_curves(all_train_loss, all_val_acc, output_dir):
    """Plot mean ± SD (across folds) of training loss and validation accuracy.

    Both inputs are ``(n_folds, n_epochs)`` arrays; the epoch axis length is
    taken from the arrays rather than the global EPOCHS.
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    panels = (
        (axes[0], all_train_loss, "Mean Training Loss (±SD)", {}),
        (axes[1], all_val_acc, "Mean Validation Accuracy (±SD)", {"color": "tab:orange"}),
    )
    for ax, curves, title, style in panels:
        curves = np.asarray(curves)
        mean = curves.mean(axis=0)
        std = curves.std(axis=0)
        epochs = np.arange(curves.shape[1])
        ax.plot(epochs, mean, linewidth=2, **style)
        ax.fill_between(epochs, mean - std, mean + std, alpha=0.2, **style)
        ax.set_title(title)
        ax.set_xlabel("Epoch")
    fig.tight_layout()
    path = os.path.join(output_dir, "learning_curves.png")
    fig.savefig(path, dpi=120, bbox_inches="tight")
    plt.close(fig)
    print(f"[ELIAS] Learning curves saved → {path}")


# ── Main ───────────────────────────────────────────────────────────────────────
def _make_loaders(full_dataset, train_ids, val_ids, train_tf, val_tf, pin_memory):
    """Build the (train, val) DataLoaders for one fold's index split."""
    train_data = ApplyTransform(Subset(full_dataset, train_ids), transform=train_tf)
    val_data = ApplyTransform(Subset(full_dataset, val_ids), transform=val_tf)
    train_loader = DataLoader(
        train_data, batch_size=BATCH_SIZE, shuffle=True,
        num_workers=NUM_WORKERS, pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_data, batch_size=BATCH_SIZE, shuffle=False,
        num_workers=NUM_WORKERS, pin_memory=pin_memory,
    )
    return train_loader, val_loader


def _run_fold(fold, train_loader, val_loader, class_names, device, output_dir):
    """Train a fresh model on one fold, save its artefacts, and report metrics.

    Returns:
        ``(train_losses, val_accs, result_row, roc_entry)`` where the first two
        are length-EPOCHS arrays, ``result_row`` is the summary-table row and
        ``roc_entry`` is ``(fpr, tpr, auc)``.
    """
    model = build_elias_model(num_classes=2, freeze_backbone=True).to(device)
    criterion = nn.CrossEntropyLoss()
    # Only model.fc parameters are handed to the optimizer.
    optimizer = optim.Adam(model.fc.parameters(), lr=LR)

    train_losses = np.zeros(EPOCHS)
    val_accs = np.zeros(EPOCHS)
    final_eval = None
    for epoch in range(EPOCHS):
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
        # Keep the full result: after the last epoch it doubles as the final
        # fold evaluation, avoiding a redundant extra validation pass.
        final_eval = evaluate(model, val_loader, device)
        val_acc = final_eval[0]
        train_losses[epoch] = train_loss
        val_accs[epoch] = val_acc
        print(
            f" Epoch {epoch + 1:02d}/{EPOCHS} "
            f"loss={train_loss:.4f} val_acc={val_acc:.4f}"
        )
    if final_eval is None:  # only when EPOCHS == 0
        final_eval = evaluate(model, val_loader, device)

    # Final fold evaluation
    val_acc, y_true, y_probs, y_pred = final_eval
    metrics = compute_fold_metrics(y_true, y_probs, y_pred, class_names)
    print(
        f"\n ✅ Fold {fold + 1} | "
        f"AUC={metrics['auc']:.4f} "
        f"Sen={metrics['sensitivity']:.3f} "
        f"Spe={metrics['specificity']:.3f} "
        f"F1={metrics['f1']:.3f}"
    )

    save_confusion_matrix(metrics["cm"], class_names, fold, output_dir)

    # Save the final-epoch weights for this fold (no best-epoch selection is done).
    ckpt_path = os.path.join(output_dir, f"pytorch_model_fold{fold + 1}.pt")
    torch.save(model.state_dict(), ckpt_path)

    result_row = {
        "Fold": fold + 1,
        "Accuracy": val_acc,
        "Sensitivity": metrics["sensitivity"],
        "Specificity": metrics["specificity"],
        "F1 Score": metrics["f1"],
        "AUC": metrics["auc"],
    }
    roc_entry = (metrics["fpr"], metrics["tpr"], metrics["auc"])
    return train_losses, val_accs, result_row, roc_entry


def _save_summary(fold_results, output_dir):
    """Append an "Average" row, write the results table, and print it."""
    results_df = pd.DataFrame(fold_results)
    avg_row = results_df.mean(numeric_only=True).to_dict()
    avg_row["Fold"] = "Average"
    results_df = pd.concat([results_df, pd.DataFrame([avg_row])], ignore_index=True)

    # Always write a CSV: .xlsx export needs an optional engine (e.g. openpyxl),
    # and a missing engine must not abort the run before the plots are saved.
    csv_path = os.path.join(output_dir, "model_performance_results.csv")
    results_df.to_csv(csv_path, index=False)
    excel_path = os.path.join(output_dir, "model_performance_results.xlsx")
    try:
        results_df.to_excel(excel_path, index=False)
    except ImportError as err:
        print(f"[ELIAS] Excel export skipped ({err}); results saved to {csv_path}")

    print(f"\n{'='*60}")
    print(" CROSS-VALIDATION SUMMARY")
    print(f"{'='*60}")
    print(results_df.to_string(index=False))


def main():
    """Parse CLI args and run stratified N_FOLDS-fold cross-validation."""
    parser = argparse.ArgumentParser(description=f"ELIAS {N_FOLDS}-Fold Cross-Validation")
    parser.add_argument("--data_dir", type=str, default="./data/data")
    parser.add_argument("--output_dir", type=str, default="./outputs")
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # RANDOM_STATE previously seeded only the fold split; seed torch and numpy
    # too so model initialisation, shuffling and augmentation are repeatable.
    torch.manual_seed(RANDOM_STATE)
    np.random.seed(RANDOM_STATE)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[ELIAS] Device: {device}")

    # ── Load dataset ─────────────────────────────────────────────────────
    full_dataset = datasets.ImageFolder(args.data_dir)
    labels = np.array(full_dataset.targets)
    class_names = full_dataset.classes
    print(f"[ELIAS] Classes: {class_names}")
    print(f"[ELIAS] Total samples: {len(full_dataset)}")
    # The model head and all metrics assume a binary task with label 1 positive.
    if len(class_names) != 2:
        raise ValueError(
            f"Expected exactly 2 class folders in {args.data_dir}, found {class_names}"
        )

    train_tf, val_tf = get_transforms()
    pin_memory = device.type == "cuda"

    # ── Cross-validation setup ────────────────────────────────────────────
    skf = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=RANDOM_STATE)
    all_train_loss = np.zeros((N_FOLDS, EPOCHS))
    all_val_acc = np.zeros((N_FOLDS, EPOCHS))
    fold_results = []
    roc_data = []

    # ── Fold loop ─────────────────────────────────────────────────────────
    for fold, (train_ids, val_ids) in enumerate(skf.split(np.zeros(len(labels)), labels)):
        print(f"\n{'='*20} FOLD {fold + 1}/{N_FOLDS} {'='*20}")
        print(f" Train: {len(train_ids)} | Val: {len(val_ids)}")

        train_loader, val_loader = _make_loaders(
            full_dataset, train_ids, val_ids, train_tf, val_tf, pin_memory
        )
        train_losses, val_accs, result_row, roc_entry = _run_fold(
            fold, train_loader, val_loader, class_names, device, args.output_dir
        )
        all_train_loss[fold] = train_losses
        all_val_acc[fold] = val_accs
        fold_results.append(result_row)
        roc_data.append(roc_entry)

    # ── Aggregate results ─────────────────────────────────────────────────
    _save_summary(fold_results, args.output_dir)

    # ── Save plots ────────────────────────────────────────────────────────
    save_roc_curves(roc_data, args.output_dir)
    save_learning_curves(all_train_loss, all_val_acc, args.output_dir)

    print(f"\n[ELIAS] All outputs saved to: {args.output_dir}")


if __name__ == "__main__":
    main()