# ELIAS-epiblepharon / train.py
# (Hugging Face upload residue: "cahsu's picture / Upload 5 files / fa50b6c verified")
"""
ELIAS β€” Eyelid Lesion Intelligent Analysis System
train.py
Stratified 5-fold cross-validation training pipeline.
Extracted and refactored from gemini_crossval_masked.ipynb.
Usage:
python train.py --data_dir ./data/data --output_dir ./outputs
Data directory structure:
data/data/
β”œβ”€β”€ epiblepharon/ (positive class)
└── control/ (negative class)
"""
import argparse
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import auc, confusion_matrix, f1_score, roc_curve
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, models, transforms
from model import build_elias_model
# ── Hyperparameters ────────────────────────────────────────────────────────────
BATCH_SIZE = 32  # mini-batch size for both train and val loaders
EPOCHS = 20  # training epochs per fold
LR = 1e-3  # Adam learning rate for the classifier head
N_FOLDS = 5  # number of stratified cross-validation folds
RANDOM_STATE = 42  # seed for reproducible fold splits
IMAGE_SIZE = 224  # input resolution (matches ImageNet-pretrained ResNet-18 per module docs)
# ── Dataset Utilities ──────────────────────────────────────────────────────────
class ApplyTransform(torch.utils.data.Dataset):
    """Dataset wrapper that lazily applies a transform to each sample.

    Lets one underlying dataset serve both the training and validation
    subsets with different augmentation pipelines.
    """

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, index):
        sample, label = self.subset[index]
        # Apply the transform only when one was supplied (truthy check,
        # matching the original semantics).
        return (self.transform(sample) if self.transform else sample, label)
def get_transforms():
    """Build and return the (train, val) torchvision transform pipelines.

    Both pipelines resize to IMAGE_SIZE and convert to 3-channel
    grayscale (normalizes illumination variation across clinical
    photographs while keeping 3-channel input compatibility with
    ImageNet-pretrained ResNet-18). The training pipeline additionally
    applies color jitter and a random horizontal flip before tensor
    conversion and ImageNet normalization.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    shared_head = [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.Grayscale(num_output_channels=3),
    ]
    augmentation = [
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
        transforms.RandomHorizontalFlip(),
    ]
    shared_tail = [
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ]
    train_tf = transforms.Compose(shared_head + augmentation + shared_tail)
    val_tf = transforms.Compose(shared_head + shared_tail)
    return train_tf, val_tf
# ── Training & Evaluation ──────────────────────────────────────────────────────
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over *loader*.

    Returns the sample-weighted mean training loss for the epoch.
    """
    model.train()
    total_loss = 0.0
    for batch_x, batch_y in loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        optimizer.zero_grad()
        loss = criterion(model(batch_x), batch_y)
        loss.backward()
        optimizer.step()
        # Weight by batch size so the epoch average is per-sample,
        # not per-batch (last batch may be smaller).
        total_loss += loss.item() * batch_x.size(0)
    return total_loss / len(loader.dataset)
@torch.no_grad()
def evaluate(model, loader, device):
    """Evaluate *model* on *loader* without gradient tracking.

    Returns (accuracy, y_true, y_probs, y_pred) as numpy arrays, where
    y_probs holds the softmax probability of class 1 for each sample.
    """
    model.eval()
    labels_all, probs_all, preds_all = [], [], []
    n_correct = 0
    for batch_x, batch_y in loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        logits = model(batch_x)
        # Positive-class probability for ROC/AUC downstream.
        batch_probs = torch.softmax(logits, dim=1)[:, 1]
        batch_preds = torch.argmax(logits, dim=1)
        n_correct += (batch_preds == batch_y).sum().item()
        labels_all.extend(batch_y.cpu().numpy())
        probs_all.extend(batch_probs.cpu().numpy())
        preds_all.extend(batch_preds.cpu().numpy())
    accuracy = n_correct / len(loader.dataset)
    return accuracy, np.array(labels_all), np.array(probs_all), np.array(preds_all)
def compute_fold_metrics(y_true, y_probs, y_pred, class_names):
    """Compute sensitivity, specificity, F1, and ROC/AUC for one fold.

    Parameters
    ----------
    y_true : ground-truth labels (0 = negative, 1 = positive).
    y_probs : predicted probabilities for the positive class.
    y_pred : hard class predictions.
    class_names : class-name list; kept for interface parity with
        callers, not used in the computation itself.

    Returns a dict of scalar metrics plus the ROC curve points (fpr,
    tpr) and the 2x2 confusion matrix.
    """
    # Pin labels=[0, 1] so the matrix is always 2x2 even when a fold's
    # ground truth or predictions contain only one class; otherwise
    # cm.ravel() would not unpack into exactly four values.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    # Guard against zero-division when a class is absent from the fold.
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
    f1 = f1_score(y_true, y_pred)
    fpr, tpr, _ = roc_curve(y_true, y_probs)
    fold_auc = auc(fpr, tpr)
    return {
        "sensitivity": sensitivity,
        "specificity": specificity,
        "f1": f1,
        "auc": fold_auc,
        "fpr": fpr,
        "tpr": tpr,
        "cm": cm,
    }
# ── Plotting ───────────────────────────────────────────────────────────────────
def save_confusion_matrix(cm, class_names, fold_idx, output_dir):
    """Render one fold's confusion matrix as an annotated heatmap PNG."""
    plt.figure(figsize=(6, 5))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
    )
    plt.title(f"Confusion Matrix β€” Fold {fold_idx + 1}")
    plt.ylabel("Actual")
    plt.xlabel("Predicted")
    out_path = os.path.join(output_dir, f"confusion_matrix_fold_{fold_idx + 1}.png")
    plt.savefig(out_path, dpi=120, bbox_inches="tight")
    plt.close()
def save_roc_curves(roc_data, output_dir):
    """Plot every fold's ROC curve (AUC in the legend) on one figure."""
    plt.figure(figsize=(8, 6))
    for idx, (fpr, tpr, fold_auc) in enumerate(roc_data):
        plt.plot(fpr, tpr, label=f"Fold {idx + 1} (AUC = {fold_auc:.3f})")
    # Chance-level reference diagonal.
    plt.plot([0, 1], [0, 1], "k--", linewidth=1)
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curves β€” 5-Fold Cross-Validation")
    plt.legend(loc="lower right")
    out_path = os.path.join(output_dir, "roc_curves.png")
    plt.savefig(out_path, dpi=120, bbox_inches="tight")
    plt.close()
    print(f"[ELIAS] ROC curve saved β†’ {out_path}")
def save_learning_curves(all_train_loss, all_val_acc, output_dir):
    """Plot mean across-fold learning curves with a +/-1 SD shaded band."""

    def _mean_sd_panel(ax, per_fold, color=None):
        # One panel: mean curve over folds plus a +/-1 SD band.
        mean = np.mean(per_fold, axis=0)
        sd = np.std(per_fold, axis=0)
        if color is None:
            ax.plot(mean, linewidth=2)
            ax.fill_between(range(EPOCHS), mean - sd, mean + sd, alpha=0.2)
        else:
            ax.plot(mean, linewidth=2, color=color)
            ax.fill_between(range(EPOCHS), mean - sd, mean + sd, alpha=0.2, color=color)

    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    _mean_sd_panel(axes[0], all_train_loss)
    axes[0].set_title("Mean Training Loss (Β±SD)")
    axes[0].set_xlabel("Epoch")
    _mean_sd_panel(axes[1], all_val_acc, color="tab:orange")
    axes[1].set_title("Mean Validation Accuracy (Β±SD)")
    axes[1].set_xlabel("Epoch")
    plt.tight_layout()
    out_path = os.path.join(output_dir, "learning_curves.png")
    plt.savefig(out_path, dpi=120, bbox_inches="tight")
    plt.close()
    print(f"[ELIAS] Learning curves saved β†’ {out_path}")
# ── Main ───────────────────────────────────────────────────────────────────────
def main():
    """Run ELIAS stratified 5-fold cross-validation end to end.

    Parses --data_dir/--output_dir, trains one classifier head per fold
    (frozen backbone), and writes per-fold confusion matrices, ROC
    curves, learning curves, checkpoints, and an Excel metrics summary
    to the output directory.
    """
    parser = argparse.ArgumentParser(description="ELIAS 5-Fold Cross-Validation")
    parser.add_argument("--data_dir", type=str, default="./data/data")
    parser.add_argument("--output_dir", type=str, default="./outputs")
    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[ELIAS] Device: {device}")

    # Load the dataset without transforms; fold-specific pipelines are
    # attached later via ApplyTransform so train/val transforms differ.
    full_dataset = datasets.ImageFolder(args.data_dir)
    labels = np.array(full_dataset.targets)
    class_names = full_dataset.classes
    print(f"[ELIAS] Classes: {class_names}")
    print(f"[ELIAS] Total samples: {len(full_dataset)}")
    train_tf, val_tf = get_transforms()

    # Cross-validation bookkeeping.
    skf = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=RANDOM_STATE)
    all_train_loss = np.zeros((N_FOLDS, EPOCHS))
    all_val_acc = np.zeros((N_FOLDS, EPOCHS))
    fold_results = []
    roc_data = []

    for fold, (train_ids, val_ids) in enumerate(skf.split(np.zeros(len(labels)), labels)):
        print(f"\n{'='*20} FOLD {fold + 1}/{N_FOLDS} {'='*20}")
        print(f" Train: {len(train_ids)} | Val: {len(val_ids)}")
        train_data = ApplyTransform(Subset(full_dataset, train_ids), transform=train_tf)
        val_data = ApplyTransform(Subset(full_dataset, val_ids), transform=val_tf)
        train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
        val_loader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

        # Fresh model per fold; backbone frozen, only the head (model.fc)
        # is optimized.
        model = build_elias_model(num_classes=2, freeze_backbone=True).to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.fc.parameters(), lr=LR)

        # Epoch loop. Keep the last epoch's validation predictions so no
        # redundant extra evaluation pass is needed after training (the
        # model does not change between the final epoch and fold scoring).
        y_true = y_probs = y_pred = None
        val_acc = 0.0
        for epoch in range(EPOCHS):
            train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
            val_acc, y_true, y_probs, y_pred = evaluate(model, val_loader, device)
            all_train_loss[fold, epoch] = train_loss
            all_val_acc[fold, epoch] = val_acc
            print(
                f" Epoch {epoch + 1:02d}/{EPOCHS} "
                f"loss={train_loss:.4f} val_acc={val_acc:.4f}"
            )

        # Final fold metrics from the last epoch's validation predictions.
        metrics = compute_fold_metrics(y_true, y_probs, y_pred, class_names)
        print(
            f"\n βœ… Fold {fold + 1} | "
            f"AUC={metrics['auc']:.4f} "
            f"Sen={metrics['sensitivity']:.3f} "
            f"Spe={metrics['specificity']:.3f} "
            f"F1={metrics['f1']:.3f}"
        )
        fold_results.append({
            "Fold": fold + 1,
            "Accuracy": val_acc,
            "Sensitivity": metrics["sensitivity"],
            "Specificity": metrics["specificity"],
            "F1 Score": metrics["f1"],
            "AUC": metrics["auc"],
        })
        roc_data.append((metrics["fpr"], metrics["tpr"], metrics["auc"]))
        save_confusion_matrix(metrics["cm"], class_names, fold, args.output_dir)

        # NOTE: this saves the FINAL-epoch weights for the fold — no
        # best-epoch tracking is performed anywhere in this script.
        ckpt_path = os.path.join(args.output_dir, f"pytorch_model_fold{fold + 1}.pt")
        torch.save(model.state_dict(), ckpt_path)

    # Aggregate per-fold metrics and append an "Average" row.
    results_df = pd.DataFrame(fold_results)
    avg_row = results_df.mean(numeric_only=True).to_dict()
    avg_row["Fold"] = "Average"
    results_df = pd.concat([results_df, pd.DataFrame([avg_row])], ignore_index=True)
    excel_path = os.path.join(args.output_dir, "model_performance_results.xlsx")
    results_df.to_excel(excel_path, index=False)
    print(f"\n{'='*60}")
    print(" CROSS-VALIDATION SUMMARY")
    print(f"{'='*60}")
    print(results_df.to_string(index=False))

    save_roc_curves(roc_data, args.output_dir)
    save_learning_curves(all_train_loss, all_val_acc, args.output_dir)
    print(f"\n[ELIAS] All outputs saved to: {args.output_dir}")


if __name__ == "__main__":
    main()