| """ |
| ELIAS β Eyelid Lesion Intelligent Analysis System |
| train.py |
| |
| Stratified 5-fold cross-validation training pipeline. |
| Extracted and refactored from gemini_crossval_masked.ipynb. |
| |
| Usage: |
| python train.py --data_dir ./data/data --output_dir ./outputs |
| |
| Data directory structure: |
| data/data/ |
| βββ epiblepharon/ (positive class) |
| βββ control/ (negative class) |
| """ |
|
|
| import argparse |
| import os |
|
|
| import matplotlib.pyplot as plt |
| import numpy as np |
| import pandas as pd |
| import seaborn as sns |
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from sklearn.metrics import auc, confusion_matrix, f1_score, roc_curve |
| from sklearn.model_selection import StratifiedKFold |
| from torch.utils.data import DataLoader, Subset |
| from torchvision import datasets, models, transforms |
|
|
| from model import build_elias_model |
|
|
|
|
| |
BATCH_SIZE = 32     # mini-batch size shared by the train and val DataLoaders
EPOCHS = 20         # training epochs per fold
LR = 1e-3           # Adam learning rate for the classifier head (model.fc)
N_FOLDS = 5         # number of stratified cross-validation folds
RANDOM_STATE = 42   # seed for the StratifiedKFold shuffle (reproducibility)
IMAGE_SIZE = 224    # square input resolution fed to the network
|
|
|
|
| |
|
|
class ApplyTransform(torch.utils.data.Dataset):
    """Dataset wrapper that lazily applies a transform to an underlying subset.

    Lets the same underlying ImageFolder split carry augmentation transforms
    on the training side and deterministic transforms on the validation side.
    """

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, index):
        sample, target = self.subset[index]
        if self.transform:
            sample = self.transform(sample)
        return sample, target
|
|
|
|
def get_transforms():
    """Return (train, val) torchvision transform pipelines.

    Both pipelines resize and grayscale the image (replicated to 3 channels)
    to reduce illumination variation across clinical photographs while
    keeping 3-channel input compatibility with an ImageNet-pretrained
    ResNet-18; the train pipeline additionally applies photometric jitter
    and random horizontal flips for augmentation.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    shared_head = [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.Grayscale(num_output_channels=3),
    ]
    shared_tail = [
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ]
    augmentation = [
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
        transforms.RandomHorizontalFlip(),
    ]
    train_tf = transforms.Compose(shared_head + augmentation + shared_tail)
    val_tf = transforms.Compose(shared_head + shared_tail)
    return train_tf, val_tf
|
|
|
|
| |
|
|
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over *loader*; return the mean per-sample loss."""
    model.train()
    total_loss = 0.0
    for batch_inputs, batch_labels in loader:
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)
        optimizer.zero_grad()
        batch_loss = criterion(model(batch_inputs), batch_labels)
        batch_loss.backward()
        optimizer.step()
        # Weight by batch size so the epoch average is per-sample, not per-batch.
        total_loss += batch_loss.item() * batch_inputs.size(0)
    return total_loss / len(loader.dataset)
|
|
|
|
@torch.no_grad()
def evaluate(model, loader, device):
    """Score *model* on *loader*.

    Returns (accuracy, true labels, P(class 1) per sample, argmax predictions),
    the last three as numpy arrays.
    """
    model.eval()
    n_correct = 0
    labels_all, probs_all, preds_all = [], [], []
    for batch_x, batch_y in loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        logits = model(batch_x)
        # Probability of the positive class (index 1).
        pos_probs = torch.softmax(logits, dim=1)[:, 1]
        batch_preds = torch.argmax(logits, dim=1)
        n_correct += (batch_preds == batch_y).sum().item()
        labels_all.extend(batch_y.cpu().numpy())
        probs_all.extend(pos_probs.cpu().numpy())
        preds_all.extend(batch_preds.cpu().numpy())
    accuracy = n_correct / len(loader.dataset)
    return accuracy, np.array(labels_all), np.array(probs_all), np.array(preds_all)
|
|
|
|
def compute_fold_metrics(y_true, y_probs, y_pred, class_names):
    """Summarize one fold's predictions.

    Returns a dict with sensitivity, specificity, F1, AUC, the ROC curve
    points (fpr/tpr) and the raw confusion matrix.  ``class_names`` is kept
    for signature compatibility but is not used here.
    """
    cm = confusion_matrix(y_true, y_pred)
    # Binary task: ravel() yields the 2x2 cells in tn/fp/fn/tp order.
    tn, fp, fn, tp = cm.ravel()
    fpr, tpr, _ = roc_curve(y_true, y_probs)
    return {
        "sensitivity": tp / (tp + fn) if (tp + fn) > 0 else 0.0,
        "specificity": tn / (tn + fp) if (tn + fp) > 0 else 0.0,
        "f1": f1_score(y_true, y_pred),
        "auc": auc(fpr, tpr),
        "fpr": fpr,
        "tpr": tpr,
        "cm": cm,
    }
|
|
|
|
| |
|
|
def save_confusion_matrix(cm, class_names, fold_idx, output_dir):
    """Render *cm* as an annotated heatmap PNG in *output_dir* (1-based fold name)."""
    plt.figure(figsize=(6, 5))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
    )
    plt.title(f"Confusion Matrix β Fold {fold_idx + 1}")
    plt.ylabel("Actual")
    plt.xlabel("Predicted")
    out_path = os.path.join(output_dir, f"confusion_matrix_fold_{fold_idx + 1}.png")
    plt.savefig(out_path, dpi=120, bbox_inches="tight")
    plt.close()
|
|
|
|
def save_roc_curves(roc_data, output_dir):
    """Overlay every fold's ROC curve (AUC in the legend) in one PNG."""
    plt.figure(figsize=(8, 6))
    for idx, fold_curve in enumerate(roc_data):
        fold_fpr, fold_tpr, fold_auc = fold_curve
        plt.plot(fold_fpr, fold_tpr, label=f"Fold {idx + 1} (AUC = {fold_auc:.3f})")
    # Chance-level diagonal for visual reference.
    plt.plot([0, 1], [0, 1], "k--", linewidth=1)
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curves β 5-Fold Cross-Validation")
    plt.legend(loc="lower right")
    out_path = os.path.join(output_dir, "roc_curves.png")
    plt.savefig(out_path, dpi=120, bbox_inches="tight")
    plt.close()
    print(f"[ELIAS] ROC curve saved β {out_path}")
|
|
|
|
def save_learning_curves(all_train_loss, all_val_acc, output_dir):
    """Plot the mean ± SD training-loss and validation-accuracy curves.

    Parameters
    ----------
    all_train_loss : ndarray of shape (N_FOLDS, EPOCHS)
        Per-fold, per-epoch training loss.
    all_val_acc : ndarray of shape (N_FOLDS, EPOCHS)
        Per-fold, per-epoch validation accuracy.
    output_dir : str
        Directory that receives ``learning_curves.png``.
    """
    # Hoist the fold statistics: each array is reduced once instead of the
    # original four mean/std calls per panel.
    loss_mean = np.mean(all_train_loss, axis=0)
    loss_std = np.std(all_train_loss, axis=0)
    acc_mean = np.mean(all_val_acc, axis=0)
    acc_std = np.std(all_val_acc, axis=0)
    epochs = range(EPOCHS)

    _, axes = plt.subplots(1, 2, figsize=(12, 4))
    axes[0].plot(loss_mean, linewidth=2)
    axes[0].fill_between(epochs, loss_mean - loss_std, loss_mean + loss_std, alpha=0.2)
    axes[0].set_title("Mean Training Loss (±SD)")
    axes[0].set_xlabel("Epoch")

    axes[1].plot(acc_mean, linewidth=2, color="tab:orange")
    axes[1].fill_between(
        epochs, acc_mean - acc_std, acc_mean + acc_std, alpha=0.2, color="tab:orange"
    )
    axes[1].set_title("Mean Validation Accuracy (±SD)")
    axes[1].set_xlabel("Epoch")

    plt.tight_layout()
    path = os.path.join(output_dir, "learning_curves.png")
    plt.savefig(path, dpi=120, bbox_inches="tight")
    plt.close()
    # NOTE(review): restored mis-encoded characters in the titles/arrow above.
    print(f"[ELIAS] Learning curves saved → {path}")
|
|
|
|
| |
|
|
def main():
    """Run stratified K-fold cross-validation and write all artifacts.

    For each fold: trains the classifier head of an ELIAS model (backbone
    frozen) for EPOCHS epochs, evaluates on the held-out split, and saves
    the fold's confusion matrix and model checkpoint.  Afterwards writes an
    Excel summary table, overlaid ROC curves, and averaged learning curves
    to ``--output_dir``.
    """
    parser = argparse.ArgumentParser(description="ELIAS 5-Fold Cross-Validation")
    parser.add_argument("--data_dir", type=str, default="./data/data")
    parser.add_argument("--output_dir", type=str, default="./outputs")
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[ELIAS] Device: {device}")

    # Load the dataset once without transforms; per-fold wrappers attach
    # train/val transforms to the same underlying images.
    full_dataset = datasets.ImageFolder(args.data_dir)
    labels = np.array(full_dataset.targets)
    class_names = full_dataset.classes
    print(f"[ELIAS] Classes: {class_names}")
    print(f"[ELIAS] Total samples: {len(full_dataset)}")

    train_tf, val_tf = get_transforms()

    skf = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=RANDOM_STATE)

    all_train_loss = np.zeros((N_FOLDS, EPOCHS))
    all_val_acc = np.zeros((N_FOLDS, EPOCHS))
    fold_results = []
    roc_data = []

    # skf.split only needs the label vector for stratification, so a dummy
    # zero array stands in for X.
    for fold, (train_ids, val_ids) in enumerate(skf.split(np.zeros(len(labels)), labels)):
        print(f"\n{'='*20} FOLD {fold + 1}/{N_FOLDS} {'='*20}")
        print(f" Train: {len(train_ids)} | Val: {len(val_ids)}")

        train_data = ApplyTransform(Subset(full_dataset, train_ids), transform=train_tf)
        val_data = ApplyTransform(Subset(full_dataset, val_ids), transform=val_tf)
        train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
        val_loader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

        # Fresh model per fold; only the classification head (model.fc) is optimized.
        model = build_elias_model(num_classes=2, freeze_backbone=True).to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.fc.parameters(), lr=LR)

        for epoch in range(EPOCHS):
            train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
            val_acc, _, _, _ = evaluate(model, val_loader, device)
            all_train_loss[fold, epoch] = train_loss
            all_val_acc[fold, epoch] = val_acc
            print(
                f" Epoch {epoch + 1:02d}/{EPOCHS} "
                f"loss={train_loss:.4f} val_acc={val_acc:.4f}"
            )

        # Final pass over the validation split for fold-level metrics.
        val_acc, y_true, y_probs, y_pred = evaluate(model, val_loader, device)
        metrics = compute_fold_metrics(y_true, y_probs, y_pred, class_names)

        # NOTE(review): the original contained a mojibake-corrupted,
        # unterminated f-string here (a syntax error); restored as a plain
        # one-line fold summary.
        print(
            f"\n Fold {fold + 1} | "
            f"AUC={metrics['auc']:.4f} "
            f"Sen={metrics['sensitivity']:.3f} "
            f"Spe={metrics['specificity']:.3f} "
            f"F1={metrics['f1']:.3f}"
        )

        fold_results.append({
            "Fold": fold + 1,
            "Accuracy": val_acc,
            "Sensitivity": metrics["sensitivity"],
            "Specificity": metrics["specificity"],
            "F1 Score": metrics["f1"],
            "AUC": metrics["auc"],
        })
        roc_data.append((metrics["fpr"], metrics["tpr"], metrics["auc"]))

        save_confusion_matrix(metrics["cm"], class_names, fold, args.output_dir)

        # Persist the fold's weights for later inference/ensembling.
        ckpt_path = os.path.join(args.output_dir, f"pytorch_model_fold{fold + 1}.pt")
        torch.save(model.state_dict(), ckpt_path)

    # Aggregate per-fold rows plus an "Average" summary row.
    results_df = pd.DataFrame(fold_results)
    avg_row = results_df.mean(numeric_only=True).to_dict()
    avg_row["Fold"] = "Average"
    results_df = pd.concat([results_df, pd.DataFrame([avg_row])], ignore_index=True)

    excel_path = os.path.join(args.output_dir, "model_performance_results.xlsx")
    results_df.to_excel(excel_path, index=False)

    print(f"\n{'='*60}")
    print(" CROSS-VALIDATION SUMMARY")
    print(f"{'='*60}")
    print(results_df.to_string(index=False))

    save_roc_curves(roc_data, args.output_dir)
    save_learning_curves(all_train_loss, all_val_acc, args.output_dir)

    print(f"\n[ELIAS] All outputs saved to: {args.output_dir}")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|