# Spaces: Runtime error | File size: 7,517 Bytes | 0b3dd07
"""
============================================================
K-Fold Stratified Cross-Validation
============================================================
For statistical rigor required in IEEE publications.
Reports mean ± std across folds.
Usage:
python scripts/cross_validate.py --config configs/config.yaml --model resnet50 --folds 5
============================================================
"""
import os
import sys
import json
import yaml
import argparse
import numpy as np
from datetime import datetime
from copy import deepcopy
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import StratifiedKFold
from tqdm import tqdm
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from dataset.rangoli_dataset import RangoliDataset, get_train_transforms, get_val_transforms
from models.classifier import build_model, build_loss_function
from scripts.train import train_one_epoch, validate, get_optimizer, get_scheduler, EarlyStopping
from scripts.evaluate import compute_all_metrics
def run_kfold_cv(model_name, config, n_folds=5, device=None):
    """Run stratified k-fold cross-validation for one model and save a summary.

    For every fold a new model, loss, optimizer, scheduler, grad scaler and
    early-stopping tracker are built. The model is trained for up to
    ``training.num_epochs`` epochs, the weights with the highest validation
    accuracy are restored, and metrics are computed on the fold's validation
    subset. Per-metric mean and standard deviation across folds are printed
    and written to ``<paths.reports>/<model_name>_cv_results.json``.

    Args:
        model_name: Name passed to ``build_model``; also used in the report
            file name.
        config: Parsed configuration mapping. Reads ``paths.train_dir``,
            ``paths.reports``, ``training.batch_size``,
            ``training.early_stopping_patience``, ``training.num_epochs`` and,
            optionally, ``training.num_workers`` (defaults to 4).
        n_folds: Number of stratified folds.
        device: ``torch.device`` to train on. Defaults to CUDA when available,
            otherwise CPU.

    Returns:
        dict mapping each aggregated metric name to
        ``{"mean": float, "std": float}``. The std is ``np.std`` with its
        default ``ddof=0`` (population standard deviation).
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"\n{'#'*60}")
    print(f" {n_folds}-FOLD CROSS-VALIDATION: {model_name.upper()}")
    print(f"{'#'*60}")

    train_transform = get_train_transforms(config)
    val_transform = get_val_transforms(config)

    # Untransformed dataset used only for labels, class mapping and fold indices.
    full_dataset = RangoliDataset(
        config["paths"]["train_dir"], split="train_cv", transform=None
    )
    targets = full_dataset.targets
    # The class list does not change between folds, so build it once.
    class_names = [full_dataset.idx_to_class[i] for i in range(full_dataset.num_classes)]

    batch_size = config["training"]["batch_size"]
    # Previously hard-coded to 4; configurable so CPU-only/small hosts can lower it.
    num_workers = config["training"].get("num_workers", 4)

    skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
    fold_results = []

    # StratifiedKFold only needs the labels; the zero array is a placeholder X.
    for fold, (train_idx, val_idx) in enumerate(skf.split(np.zeros(len(targets)), targets)):
        print(f"\n --- Fold {fold+1}/{n_folds} ---")
        print(f" Train: {len(train_idx)} | Val: {len(val_idx)}")

        # NOTE(review): the fold indices come from the split="train_cv" dataset
        # but are applied to datasets built with split="train" / split="val".
        # This assumes all three enumerate the same samples in the same order -
        # confirm against RangoliDataset.
        train_subset = Subset(
            RangoliDataset(config["paths"]["train_dir"], split="train",
                           transform=train_transform,
                           class_to_idx=full_dataset.class_to_idx),
            train_idx,
        )
        val_subset = Subset(
            RangoliDataset(config["paths"]["train_dir"], split="val",
                           transform=val_transform,
                           class_to_idx=full_dataset.class_to_idx),
            val_idx,
        )
        train_loader = DataLoader(train_subset, batch_size=batch_size,
                                  shuffle=True, num_workers=num_workers,
                                  pin_memory=True, drop_last=True)
        val_loader = DataLoader(val_subset, batch_size=batch_size,
                                shuffle=False, num_workers=num_workers,
                                pin_memory=True)

        # Everything trainable is rebuilt per fold so no state leaks across folds.
        model = build_model(model_name, config).to(device)
        criterion = build_loss_function(config, device=device)
        optimizer = get_optimizer(model, config)
        scheduler = get_scheduler(optimizer, config)
        scaler = torch.cuda.amp.GradScaler(enabled=device.type == "cuda")
        early_stopping = EarlyStopping(patience=config["training"]["early_stopping_patience"])

        # Start below any real accuracy so the first epoch always produces a
        # snapshot; starting at 0 left best_state as None whenever validation
        # accuracy never rose above 0.
        best_val_acc = float("-inf")
        best_state = None

        for epoch in range(config["training"]["num_epochs"]):
            train_one_epoch(
                model, train_loader, criterion, optimizer, scheduler,
                scaler, None, device, epoch, config
            )
            _, val_acc, _, _ = validate(model, val_loader, criterion, device)

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_state = deepcopy(model.state_dict())

            if early_stopping(val_acc):
                print(f" Early stopping at epoch {epoch+1}")
                break

            if (epoch + 1) % 10 == 0:
                print(f" Epoch {epoch+1}: val_acc={val_acc:.4f} (best={best_val_acc:.4f})")

        # Restore the best weights. best_state is still None when no epoch ran
        # (num_epochs == 0) or every val_acc was NaN; in that case the current
        # weights are evaluated instead of crashing in load_state_dict(None).
        if best_state is not None:
            model.load_state_dict(best_state)

        # One pass yields probabilities (for AUC / top-k) and, via argmax, the
        # predictions; the former extra validate() pass was discarded unused.
        all_probs, all_true = _predict_probabilities(model, val_loader, device)
        all_preds_np = np.argmax(all_probs, axis=1)

        fold_metrics = compute_all_metrics(
            all_true, all_preds_np, all_probs, class_names, full_dataset.num_classes
        )
        fold_results.append(fold_metrics)
        print(f" Fold {fold+1}: Acc={fold_metrics['accuracy']:.4f}, "
              f"F1={fold_metrics['f1_macro']:.4f}, "
              f"Kappa={fold_metrics['cohen_kappa']:.4f}")

    metric_keys = ["accuracy", "precision_macro", "recall_macro", "f1_macro",
                   "cohen_kappa", "matthews_corrcoef", "top_3_accuracy"]

    print(f"\n{'='*60}")
    print(f" {n_folds}-FOLD CROSS-VALIDATION RESULTS: {model_name}")
    print(f"{'='*60}")
    cv_summary = _summarize_folds(fold_results, metric_keys)
    print(f"{'='*60}")

    save_path = os.path.join(config["paths"]["reports"], f"{model_name}_cv_results.json")
    report_dir = os.path.dirname(save_path)
    if report_dir:  # os.makedirs("") raises, so skip when saving into the CWD
        os.makedirs(report_dir, exist_ok=True)
    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(cv_summary, f, indent=2)

    return cv_summary


def _predict_probabilities(model, loader, device):
    """Return ``(softmax_probabilities, true_labels)`` as NumPy arrays for *loader*."""
    model.eval()
    prob_chunks = []
    label_chunks = []
    with torch.no_grad():
        for images, targets_batch in loader:
            outputs = model(images.to(device))
            prob_chunks.append(torch.softmax(outputs, dim=1).cpu().numpy())
            label_chunks.append(targets_batch.numpy())
    return np.concatenate(prob_chunks), np.concatenate(label_chunks)


def _summarize_folds(fold_results, metric_keys):
    """Print and return the per-metric mean/std over the folds that report it."""
    summary = {}
    for key in metric_keys:
        values = [result[key] for result in fold_results if key in result]
        if not values:
            # Metric absent from every fold: leave it out of the summary.
            continue
        mean_val = np.mean(values)
        std_val = np.std(values)
        summary[key] = {"mean": float(mean_val), "std": float(std_val)}
        print(f" {key:25s}: {mean_val:.4f} ± {std_val:.4f}")
    return summary
def main():
    """Parse command-line arguments, load the YAML config and run cross-validation.

    Options:
        --config: Path to the YAML configuration file.
        --model:  Model name passed to ``run_kfold_cv``, or ``all`` to run
                  every key of ``config["models"]`` and additionally write a
                  combined ``all_cv_results.json`` to ``paths.reports``.
        --folds:  Number of folds; must be at least 2.
        --gpu:    CUDA device index, used only when CUDA is available.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="configs/config.yaml")
    parser.add_argument("--model", type=str, default="resnet50")
    parser.add_argument("--folds", type=int, default=5)
    parser.add_argument("--gpu", type=int, default=0)
    args = parser.parse_args()

    # StratifiedKFold rejects n_splits < 2; fail at the CLI with a clear
    # message instead of after the dataset has been loaded.
    if args.folds < 2:
        parser.error("--folds must be at least 2")

    with open(args.config, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")

    if args.model == "all":
        all_cv = {}
        for model_name in config["models"]:
            all_cv[model_name] = run_kfold_cv(model_name, config, args.folds, device)

        # Save comparative CV results. Create the directory here as well so
        # this write does not depend on run_kfold_cv having run at least once
        # (e.g. an empty "models" section).
        reports_dir = config["paths"]["reports"]
        if reports_dir:
            os.makedirs(reports_dir, exist_ok=True)
        save_path = os.path.join(reports_dir, "all_cv_results.json")
        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(all_cv, f, indent=2)
    else:
        run_kfold_cv(args.model, config, args.folds, device)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|