RuthvikBandari commited on
Commit
88a3f32
·
verified ·
1 Parent(s): c5b096b

Upload scripts/run_cross_val.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/run_cross_val.py +208 -0
scripts/run_cross_val.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """DiaFoot.AI v2 — 5-Fold Cross Validation.
2
+
3
+ Trains U-Net++ segmentation on 5 folds for robust performance estimation.
4
+ Reports mean +/- std across folds.
5
+
6
+ Usage:
7
+ python scripts/run_cross_val.py --fold 0 --device cuda --epochs 50
8
+ (run with --fold 0,1,2,3,4 as SLURM array job)
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import argparse
14
+ import csv
15
+ import json
16
+ import logging
17
+ import sys
18
+ from pathlib import Path
19
+
20
+ import numpy as np
21
+ import torch
22
+
23
+ sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
24
+
25
+ from src.data.augmentation import get_train_transforms, get_val_transforms
26
+ from src.data.torch_dataset import DFUDataset
27
+ from src.evaluation.metrics import (
28
+ aggregate_metrics,
29
+ compute_segmentation_metrics,
30
+ )
31
+ from src.models.unetpp import build_unetpp
32
+ from src.training.losses import DiceCELoss
33
+ from src.training.schedulers import CosineAnnealingWithWarmup
34
+ from src.training.trainer import TrainConfig, Trainer
35
+
36
+
37
+ def create_fold_splits(
38
+ train_csv: str | Path,
39
+ val_csv: str | Path,
40
+ fold: int,
41
+ n_folds: int = 5,
42
+ output_dir: str | Path = "data/splits/cv",
43
+ filter_classes: list[str] | None = None,
44
+ ) -> tuple[Path, Path]:
45
+ """Create train/val split for a specific fold.
46
+
47
+ Combines train+val, then splits into n_folds.
48
+ """
49
+ output_dir = Path(output_dir)
50
+ output_dir.mkdir(parents=True, exist_ok=True)
51
+
52
+ # Load all data
53
+ all_rows = []
54
+ fieldnames = None
55
+ for csv_path in [train_csv, val_csv]:
56
+ with open(csv_path) as f:
57
+ reader = csv.DictReader(f)
58
+ if fieldnames is None:
59
+ fieldnames = reader.fieldnames
60
+ for row in reader:
61
+ if filter_classes and row.get("class", "") not in filter_classes:
62
+ continue
63
+ all_rows.append(row)
64
+
65
+ # Shuffle deterministically
66
+ rng = np.random.RandomState(42)
67
+ indices = list(range(len(all_rows)))
68
+ rng.shuffle(indices)
69
+
70
+ # Split into folds
71
+ fold_size = len(indices) // n_folds
72
+ val_start = fold * fold_size
73
+ val_end = val_start + fold_size if fold < n_folds - 1 else len(indices)
74
+
75
+ val_indices = set(indices[val_start:val_end])
76
+ train_indices = [i for i in indices if i not in val_indices]
77
+
78
+ # Write fold CSVs
79
+ fold_train = output_dir / f"train_fold{fold}.csv"
80
+ fold_val = output_dir / f"val_fold{fold}.csv"
81
+
82
+ for out_path, idx_list in [(fold_train, train_indices), (fold_val, list(val_indices))]:
83
+ with open(out_path, "w", newline="") as f:
84
+ writer = csv.DictWriter(f, fieldnames=fieldnames or [])
85
+ writer.writeheader()
86
+ for i in idx_list:
87
+ writer.writerow(all_rows[i])
88
+
89
+ return fold_train, fold_val
90
+
91
+
92
+ def train_fold(fold: int, args: argparse.Namespace) -> dict:
93
+ """Train and evaluate one fold."""
94
+ logger = logging.getLogger(f"fold_{fold}")
95
+ logger.info("Starting fold %d/%d", fold + 1, 5)
96
+
97
+ # Create fold splits
98
+ fold_train, fold_val = create_fold_splits(
99
+ Path(args.splits_dir) / "train.csv",
100
+ Path(args.splits_dir) / "val.csv",
101
+ fold=fold,
102
+ filter_classes=["dfu", "non_dfu"],
103
+ )
104
+
105
+ train_ds = DFUDataset(str(fold_train), transform=get_train_transforms())
106
+ val_ds = DFUDataset(str(fold_val), transform=get_val_transforms())
107
+
108
+ train_loader = torch.utils.data.DataLoader(
109
+ train_ds,
110
+ batch_size=args.batch_size,
111
+ shuffle=True,
112
+ num_workers=args.num_workers,
113
+ pin_memory=True,
114
+ persistent_workers=args.num_workers > 0,
115
+ drop_last=True,
116
+ )
117
+ val_loader = torch.utils.data.DataLoader(
118
+ val_ds,
119
+ batch_size=args.batch_size,
120
+ shuffle=False,
121
+ num_workers=args.num_workers,
122
+ pin_memory=True,
123
+ persistent_workers=args.num_workers > 0,
124
+ )
125
+ logger.info("Fold %d: %d train, %d val samples", fold, len(train_ds), len(val_ds))
126
+
127
+ # Model
128
+ model = build_unetpp(
129
+ encoder_name="efficientnet-b4",
130
+ encoder_weights="imagenet",
131
+ classes=1,
132
+ decoder_attention_type="scse",
133
+ )
134
+
135
+ loss_fn = DiceCELoss()
136
+ optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
137
+ scheduler = CosineAnnealingWithWarmup(
138
+ optimizer,
139
+ warmup_epochs=5,
140
+ max_epochs=args.epochs,
141
+ )
142
+ torch.manual_seed(42 + fold)
143
+
144
+ config = TrainConfig(
145
+ epochs=args.epochs,
146
+ precision="bf16-mixed",
147
+ compile_model=False,
148
+ gradient_clip=1.0,
149
+ checkpoint_dir=f"checkpoints/cv_fold{fold}",
150
+ monitor_metric="val/loss",
151
+ monitor_mode="min",
152
+ device=args.device,
153
+ early_stopping_patience=15,
154
+ )
155
+
156
+ trainer = Trainer(model=model, config=config)
157
+ trainer.fit(train_loader, val_loader, loss_fn, optimizer, scheduler)
158
+
159
+ # Evaluate on fold validation set
160
+ model.eval()
161
+ fold_metrics = []
162
+ with torch.no_grad():
163
+ for batch in val_loader:
164
+ images = batch["image"].to(args.device)
165
+ masks = batch["mask"].numpy()
166
+ logits = model(images)
167
+ preds = (torch.sigmoid(logits) > 0.5).squeeze(1).cpu().numpy().astype(np.uint8)
168
+ for i in range(len(images)):
169
+ m = compute_segmentation_metrics(preds[i], masks[i])
170
+ fold_metrics.append(m)
171
+
172
+ summary = aggregate_metrics(fold_metrics)
173
+ dice = summary.get("dice", {}).get("mean", 0)
174
+ iou = summary.get("iou", {}).get("mean", 0)
175
+ logger.info("Fold %d results: Dice=%.4f, IoU=%.4f", fold, dice, iou)
176
+
177
+ return {"fold": fold, "dice": dice, "iou": iou, "n_val": len(val_ds)}
178
+
179
+
180
+ def main() -> None:
181
+ """Run cross-validation."""
182
+ parser = argparse.ArgumentParser(description="5-Fold Cross Validation")
183
+ parser.add_argument("--fold", type=int, required=True, help="Fold index (0-4)")
184
+ parser.add_argument("--splits-dir", type=str, default="data/splits")
185
+ parser.add_argument("--device", type=str, default="cuda")
186
+ parser.add_argument("--epochs", type=int, default=50)
187
+ parser.add_argument("--batch-size", type=int, default=16)
188
+ parser.add_argument("--num-workers", type=int, default=8)
189
+ parser.add_argument("--verbose", action="store_true")
190
+ args = parser.parse_args()
191
+
192
+ logging.basicConfig(
193
+ level=logging.DEBUG if args.verbose else logging.INFO,
194
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
195
+ datefmt="%H:%M:%S",
196
+ )
197
+
198
+ result = train_fold(args.fold, args)
199
+
200
+ # Save fold result
201
+ output = Path(f"results/cv_fold{args.fold}.json")
202
+ output.parent.mkdir(parents=True, exist_ok=True)
203
+ with open(output, "w") as f:
204
+ json.dump(result, f, indent=2)
205
+
206
+
207
+ if __name__ == "__main__":
208
+ main()