| | """DiaFoot.AI v2 — 5-Fold Cross Validation. |
| | |
| | Trains U-Net++ segmentation on 5 folds for robust performance estimation. |
| | Reports mean +/- std across folds. |
| | |
| | Usage: |
| | python scripts/run_cross_val.py --fold 0 --device cuda --epochs 50 |
| | (run with --fold 0,1,2,3,4 as SLURM array job) |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import argparse |
| | import csv |
| | import json |
| | import logging |
| | import sys |
| | from pathlib import Path |
| |
|
| | import numpy as np |
| | import torch |
| |
|
| | sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) |
| |
|
| | from src.data.augmentation import get_train_transforms, get_val_transforms |
| | from src.data.torch_dataset import DFUDataset |
| | from src.evaluation.metrics import ( |
| | aggregate_metrics, |
| | compute_segmentation_metrics, |
| | ) |
| | from src.models.unetpp import build_unetpp |
| | from src.training.losses import DiceCELoss |
| | from src.training.schedulers import CosineAnnealingWithWarmup |
| | from src.training.trainer import TrainConfig, Trainer |
| |
|
| |
|
def create_fold_splits(
    train_csv: str | Path,
    val_csv: str | Path,
    fold: int,
    n_folds: int = 5,
    output_dir: str | Path = "data/splits/cv",
    filter_classes: list[str] | None = None,
) -> tuple[Path, Path]:
    """Create the train/val CSV pair for one cross-validation fold.

    Pools the rows of *train_csv* and *val_csv*, shuffles them with a fixed
    seed (42) so every array-job worker agrees on the permutation, then
    carves out the ``fold``-th contiguous slice as validation. The last fold
    absorbs the remainder when ``n_folds`` does not divide the row count.

    Args:
        train_csv: Path to the original training-split CSV.
        val_csv: Path to the original validation-split CSV.
        fold: Zero-based fold index in ``[0, n_folds)``.
        n_folds: Total number of folds.
        output_dir: Directory where the per-fold CSVs are written.
        filter_classes: If given, keep only rows whose ``class`` column value
            is in this list.

    Returns:
        ``(train_csv_path, val_csv_path)`` for the requested fold.

    Raises:
        ValueError: If ``fold`` is outside ``[0, n_folds)``.
    """
    if not 0 <= fold < n_folds:
        raise ValueError(f"fold must be in [0, {n_folds}), got {fold}")

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Pool both source splits; the header is taken from the first file read.
    all_rows: list[dict] = []
    fieldnames = None
    for csv_path in [train_csv, val_csv]:
        with open(csv_path) as f:
            reader = csv.DictReader(f)
            if fieldnames is None:
                fieldnames = reader.fieldnames
            for row in reader:
                if filter_classes and row.get("class", "") not in filter_classes:
                    continue
                all_rows.append(row)

    # Fixed seed: the permutation is identical across folds, so the fold
    # slices are disjoint and jointly cover every row.
    rng = np.random.RandomState(42)
    indices = list(range(len(all_rows)))
    rng.shuffle(indices)

    fold_size = len(indices) // n_folds
    val_start = fold * fold_size
    # The last fold takes any remainder rows.
    val_end = val_start + fold_size if fold < n_folds - 1 else len(indices)

    # Preserve shuffle order in BOTH outputs (the original wrote val rows in
    # arbitrary set-iteration order); the set is only for O(1) membership.
    val_list = indices[val_start:val_end]
    val_set = set(val_list)
    train_list = [i for i in indices if i not in val_set]

    fold_train = output_dir / f"train_fold{fold}.csv"
    fold_val = output_dir / f"val_fold{fold}.csv"

    for out_path, idx_list in [(fold_train, train_list), (fold_val, val_list)]:
        with open(out_path, "w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames or [])
            writer.writeheader()
            for i in idx_list:
                writer.writerow(all_rows[i])

    return fold_train, fold_val
| |
|
| |
|
def train_fold(fold: int, args: argparse.Namespace) -> dict:
    """Train and evaluate U-Net++ on a single cross-validation fold.

    Creates the fold's train/val CSVs, trains with Dice+CE loss, AdamW and
    cosine-warmup scheduling, then computes per-sample Dice/IoU on the
    fold's validation set.

    Args:
        fold: Zero-based fold index.
        args: Parsed CLI namespace (splits_dir, device, epochs, batch_size,
            num_workers).

    Returns:
        Summary dict with keys ``fold``, ``dice``, ``iou``, ``n_val``.
    """
    n_folds = 5  # must match create_fold_splits' n_folds default
    logger = logging.getLogger(f"fold_{fold}")
    logger.info("Starting fold %d/%d", fold + 1, n_folds)

    fold_train, fold_val = create_fold_splits(
        Path(args.splits_dir) / "train.csv",
        Path(args.splits_dir) / "val.csv",
        fold=fold,
        n_folds=n_folds,
        filter_classes=["dfu", "non_dfu"],
    )

    train_ds = DFUDataset(str(fold_train), transform=get_train_transforms())
    val_ds = DFUDataset(str(fold_val), transform=get_val_transforms())

    train_loader = torch.utils.data.DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        persistent_workers=args.num_workers > 0,
        drop_last=True,  # avoid a tiny trailing batch destabilizing training
    )
    val_loader = torch.utils.data.DataLoader(
        val_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        persistent_workers=args.num_workers > 0,
    )
    logger.info("Fold %d: %d train, %d val samples", fold, len(train_ds), len(val_ds))

    # Seed BEFORE building the model so per-fold weight initialization (the
    # non-pretrained decoder parts) and shuffling are reproducible. The
    # original seeded after model construction, leaving init unseeded.
    torch.manual_seed(42 + fold)

    model = build_unetpp(
        encoder_name="efficientnet-b4",
        encoder_weights="imagenet",
        classes=1,
        decoder_attention_type="scse",
    )

    loss_fn = DiceCELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
    scheduler = CosineAnnealingWithWarmup(
        optimizer,
        warmup_epochs=5,
        max_epochs=args.epochs,
    )

    config = TrainConfig(
        epochs=args.epochs,
        precision="bf16-mixed",
        compile_model=False,
        gradient_clip=1.0,
        checkpoint_dir=f"checkpoints/cv_fold{fold}",
        monitor_metric="val/loss",
        monitor_mode="min",
        device=args.device,
        early_stopping_patience=15,
    )

    trainer = Trainer(model=model, config=config)
    trainer.fit(train_loader, val_loader, loss_fn, optimizer, scheduler)

    # Final evaluation on the held-out fold.
    # NOTE(review): this scores the model state as left by fit(); if Trainer
    # does not restore the best checkpoint, these are last-epoch metrics —
    # confirm against Trainer's implementation.
    model.eval()
    fold_metrics = []
    with torch.no_grad():
        for batch in val_loader:
            images = batch["image"].to(args.device)
            masks = batch["mask"].numpy()
            logits = model(images)
            # Binarize sigmoid probabilities at 0.5 and drop the channel dim.
            preds = (torch.sigmoid(logits) > 0.5).squeeze(1).cpu().numpy().astype(np.uint8)
            for i in range(len(images)):
                m = compute_segmentation_metrics(preds[i], masks[i])
                fold_metrics.append(m)

    summary = aggregate_metrics(fold_metrics)
    dice = summary.get("dice", {}).get("mean", 0)
    iou = summary.get("iou", {}).get("mean", 0)
    logger.info("Fold %d results: Dice=%.4f, IoU=%.4f", fold, dice, iou)

    return {"fold": fold, "dice": dice, "iou": iou, "n_val": len(val_ds)}
| |
|
| |
|
def main() -> None:
    """CLI entry point: train one CV fold and dump its metrics to JSON."""
    parser = argparse.ArgumentParser(description="5-Fold Cross Validation")
    parser.add_argument(
        "--fold",
        type=int,
        required=True,
        choices=range(5),  # help promised 0-4; reject bad SLURM indices at parse time
        help="Fold index (0-4)",
    )
    parser.add_argument("--splits-dir", type=str, default="data/splits")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--num-workers", type=int, default=8)
    parser.add_argument("--verbose", action="store_true")
    args = parser.parse_args()

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%H:%M:%S",
    )

    result = train_fold(args.fold, args)

    # One JSON file per fold; a separate aggregation step collects them
    # into the cross-fold mean +/- std report.
    output = Path(f"results/cv_fold{args.fold}.json")
    output.parent.mkdir(parents=True, exist_ok=True)
    with open(output, "w") as f:
        json.dump(result, f, indent=2)
| |
|
| |
|
# Script entry point — intended to be launched once per fold (SLURM array job).
if __name__ == "__main__":
    main()
| |
|