"""DiaFoot.AI v2 — 5-Fold Cross Validation.
Trains U-Net++ segmentation on 5 folds for robust performance estimation.
Reports mean +/- std across folds.
Usage:
python scripts/run_cross_val.py --fold 0 --device cuda --epochs 50
(run with --fold 0,1,2,3,4 as SLURM array job)
"""
from __future__ import annotations
import argparse
import csv
import json
import logging
import sys
from pathlib import Path
import numpy as np
import torch
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from src.data.augmentation import get_train_transforms, get_val_transforms
from src.data.torch_dataset import DFUDataset
from src.evaluation.metrics import (
aggregate_metrics,
compute_segmentation_metrics,
)
from src.models.unetpp import build_unetpp
from src.training.losses import DiceCELoss
from src.training.schedulers import CosineAnnealingWithWarmup
from src.training.trainer import TrainConfig, Trainer
def create_fold_splits(
train_csv: str | Path,
val_csv: str | Path,
fold: int,
n_folds: int = 5,
output_dir: str | Path = "data/splits/cv",
filter_classes: list[str] | None = None,
) -> tuple[Path, Path]:
"""Create train/val split for a specific fold.
Combines train+val, then splits into n_folds.
"""
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# Load all data
all_rows = []
fieldnames = None
for csv_path in [train_csv, val_csv]:
with open(csv_path) as f:
reader = csv.DictReader(f)
if fieldnames is None:
fieldnames = reader.fieldnames
for row in reader:
if filter_classes and row.get("class", "") not in filter_classes:
continue
all_rows.append(row)
# Shuffle deterministically
rng = np.random.RandomState(42)
indices = list(range(len(all_rows)))
rng.shuffle(indices)
# Split into folds
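    # Each fold takes one contiguous slice of the shuffled indices as validation;
    # the last fold also absorbs the remainder (e.g. 103 samples with 5 folds gives
    # val sizes 20, 20, 20, 20, 23).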
fold_size = len(indices) // n_folds
val_start = fold * fold_size
val_end = val_start + fold_size if fold < n_folds - 1 else len(indices)
val_indices = set(indices[val_start:val_end])
train_indices = [i for i in indices if i not in val_indices]
# Write fold CSVs
fold_train = output_dir / f"train_fold{fold}.csv"
fold_val = output_dir / f"val_fold{fold}.csv"
for out_path, idx_list in [(fold_train, train_indices), (fold_val, list(val_indices))]:
with open(out_path, "w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=fieldnames or [])
writer.writeheader()
for i in idx_list:
writer.writerow(all_rows[i])
return fold_train, fold_val
def train_fold(fold: int, args: argparse.Namespace) -> dict:
"""Train and evaluate one fold."""
logger = logging.getLogger(f"fold_{fold}")
logger.info("Starting fold %d/%d", fold + 1, 5)
# Create fold splits
fold_train, fold_val = create_fold_splits(
Path(args.splits_dir) / "train.csv",
Path(args.splits_dir) / "val.csv",
fold=fold,
filter_classes=["dfu", "non_dfu"],
)
train_ds = DFUDataset(str(fold_train), transform=get_train_transforms())
val_ds = DFUDataset(str(fold_val), transform=get_val_transforms())
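    # persistent_workers keeps worker processes alive between epochs (only valid when
    # num_workers > 0); drop_last discards the final incomplete training batch
    # (a size-1 batch would break BatchNorm in train mode).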
train_loader = torch.utils.data.DataLoader(
train_ds,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.num_workers,
pin_memory=True,
persistent_workers=args.num_workers > 0,
drop_last=True,
)
val_loader = torch.utils.data.DataLoader(
val_ds,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
pin_memory=True,
persistent_workers=args.num_workers > 0,
)
logger.info("Fold %d: %d train, %d val samples", fold, len(train_ds), len(val_ds))
    # Model (seed before construction so the randomly initialized decoder weights
    # and loader shuffling are reproducible per fold)
    torch.manual_seed(42 + fold)
model = build_unetpp(
encoder_name="efficientnet-b4",
encoder_weights="imagenet",
classes=1,
decoder_attention_type="scse",
)
loss_fn = DiceCELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
scheduler = CosineAnnealingWithWarmup(
optimizer,
warmup_epochs=5,
max_epochs=args.epochs,
)
config = TrainConfig(
epochs=args.epochs,
precision="bf16-mixed",
compile_model=False,
gradient_clip=1.0,
checkpoint_dir=f"checkpoints/cv_fold{fold}",
monitor_metric="val/loss",
monitor_mode="min",
device=args.device,
early_stopping_patience=15,
)
trainer = Trainer(model=model, config=config)
trainer.fit(train_loader, val_loader, loss_fn, optimizer, scheduler)
# Evaluate on fold validation set
    model.to(args.device)
    model.eval()
fold_metrics = []
with torch.no_grad():
for batch in val_loader:
images = batch["image"].to(args.device)
masks = batch["mask"].numpy()
logits = model(images)
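            # Threshold sigmoid probabilities at 0.5 to get binary masks, then score each sample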
preds = (torch.sigmoid(logits) > 0.5).squeeze(1).cpu().numpy().astype(np.uint8)
for i in range(len(images)):
m = compute_segmentation_metrics(preds[i], masks[i])
fold_metrics.append(m)
summary = aggregate_metrics(fold_metrics)
dice = summary.get("dice", {}).get("mean", 0)
iou = summary.get("iou", {}).get("mean", 0)
logger.info("Fold %d results: Dice=%.4f, IoU=%.4f", fold, dice, iou)
return {"fold": fold, "dice": dice, "iou": iou, "n_val": len(val_ds)}
def main() -> None:
"""Run cross-validation."""
parser = argparse.ArgumentParser(description="5-Fold Cross Validation")
parser.add_argument("--fold", type=int, required=True, help="Fold index (0-4)")
parser.add_argument("--splits-dir", type=str, default="data/splits")
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--epochs", type=int, default=50)
parser.add_argument("--batch-size", type=int, default=16)
parser.add_argument("--num-workers", type=int, default=8)
parser.add_argument("--verbose", action="store_true")
args = parser.parse_args()
logging.basicConfig(
level=logging.DEBUG if args.verbose else logging.INFO,
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%H:%M:%S",
)
result = train_fold(args.fold, args)
# Save fold result
output = Path(f"results/cv_fold{args.fold}.json")
output.parent.mkdir(parents=True, exist_ok=True)
with open(output, "w") as f:
json.dump(result, f, indent=2)
if __name__ == "__main__":
main()
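# Aggregating across folds (a minimal sketch, assuming all five array tasks have
# finished and written results/cv_fold{0..4}.json with the keys saved above):
#
#   import json
#   import numpy as np
#   from pathlib import Path
#
#   results = [json.loads(Path(f"results/cv_fold{k}.json").read_text()) for k in range(5)]
#   for metric in ("dice", "iou"):
#       vals = np.array([r[metric] for r in results])
#       print(f"{metric}: {vals.mean():.4f} +/- {vals.std():.4f}")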