""" ============================================================ K-Fold Stratified Cross-Validation ============================================================ For statistical rigor required in IEEE publications. Reports mean ± std across folds. Usage: python scripts/cross_validate.py --config configs/config.yaml --model resnet50 --folds 5 ============================================================ """ import os import sys import json import yaml import argparse import numpy as np from datetime import datetime from copy import deepcopy import torch import torch.nn as nn from torch.utils.data import DataLoader, Subset from sklearn.model_selection import StratifiedKFold from tqdm import tqdm sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from dataset.rangoli_dataset import RangoliDataset, get_train_transforms, get_val_transforms from models.classifier import build_model, build_loss_function from scripts.train import train_one_epoch, validate, get_optimizer, get_scheduler, EarlyStopping from scripts.evaluate import compute_all_metrics def run_kfold_cv(model_name, config, n_folds=5, device=None): """Run stratified k-fold cross-validation.""" if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"\n{'#'*60}") print(f" {n_folds}-FOLD CROSS-VALIDATION: {model_name.upper()}") print(f"{'#'*60}") # Load full dataset (train + val combined) train_transform = get_train_transforms(config) val_transform = get_val_transforms(config) # Combine train and val for CV full_dataset = RangoliDataset( config["paths"]["train_dir"], split="train_cv", transform=None ) targets = full_dataset.targets skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42) fold_results = [] for fold, (train_idx, val_idx) in enumerate(skf.split(np.zeros(len(targets)), targets)): print(f"\n --- Fold {fold+1}/{n_folds} ---") print(f" Train: {len(train_idx)} | Val: {len(val_idx)}") # Create fold-specific datasets with transforms train_subset = Subset( RangoliDataset(config["paths"]["train_dir"], split="train", transform=train_transform, class_to_idx=full_dataset.class_to_idx), train_idx ) val_subset = Subset( RangoliDataset(config["paths"]["train_dir"], split="val", transform=val_transform, class_to_idx=full_dataset.class_to_idx), val_idx ) train_loader = DataLoader(train_subset, batch_size=config["training"]["batch_size"], shuffle=True, num_workers=4, pin_memory=True, drop_last=True) val_loader = DataLoader(val_subset, batch_size=config["training"]["batch_size"], shuffle=False, num_workers=4, pin_memory=True) # Build fresh model model = build_model(model_name, config).to(device) criterion = build_loss_function(config, device=device) optimizer = get_optimizer(model, config) scheduler = get_scheduler(optimizer, config) scaler = torch.cuda.amp.GradScaler(enabled=device.type == "cuda") early_stopping = EarlyStopping(patience=config["training"]["early_stopping_patience"]) best_val_acc = 0 best_state = None for epoch in range(config["training"]["num_epochs"]): train_loss, train_acc = train_one_epoch( model, train_loader, criterion, optimizer, scheduler, scaler, None, device, epoch, config ) val_loss, val_acc, val_preds, val_targets = validate( model, val_loader, criterion, device ) if val_acc > best_val_acc: best_val_acc = val_acc best_state = deepcopy(model.state_dict()) if early_stopping(val_acc): print(f" Early stopping at epoch {epoch+1}") break if (epoch + 1) % 10 == 0: print(f" Epoch {epoch+1}: val_acc={val_acc:.4f} (best={best_val_acc:.4f})") # Evaluate best model on fold's val set model.load_state_dict(best_state) _, _, val_preds, val_targets = validate(model, val_loader, criterion, device) # Get probabilities for AUC model.eval() all_probs = [] all_true = [] with torch.no_grad(): for images, targets_batch in val_loader: images = images.to(device) outputs = model(images) probs = torch.softmax(outputs, dim=1) all_probs.append(probs.cpu().numpy()) all_true.append(targets_batch.numpy()) all_probs = np.concatenate(all_probs) all_true = np.concatenate(all_true) all_preds_np = np.argmax(all_probs, axis=1) class_names = [full_dataset.idx_to_class[i] for i in range(full_dataset.num_classes)] fold_metrics = compute_all_metrics( all_true, all_preds_np, all_probs, class_names, full_dataset.num_classes ) fold_results.append(fold_metrics) print(f" Fold {fold+1}: Acc={fold_metrics['accuracy']:.4f}, " f"F1={fold_metrics['f1_macro']:.4f}, " f"Kappa={fold_metrics['cohen_kappa']:.4f}") # Aggregate results metric_keys = ["accuracy", "precision_macro", "recall_macro", "f1_macro", "cohen_kappa", "matthews_corrcoef", "top_3_accuracy"] print(f"\n{'='*60}") print(f" {n_folds}-FOLD CROSS-VALIDATION RESULTS: {model_name}") print(f"{'='*60}") cv_summary = {} for key in metric_keys: values = [r[key] for r in fold_results if key in r] if values: mean_val = np.mean(values) std_val = np.std(values) cv_summary[key] = {"mean": float(mean_val), "std": float(std_val)} print(f" {key:25s}: {mean_val:.4f} ± {std_val:.4f}") print(f"{'='*60}") # Save save_path = os.path.join(config["paths"]["reports"], f"{model_name}_cv_results.json") os.makedirs(os.path.dirname(save_path), exist_ok=True) with open(save_path, "w") as f: json.dump(cv_summary, f, indent=2) return cv_summary def main(): parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, default="configs/config.yaml") parser.add_argument("--model", type=str, default="resnet50") parser.add_argument("--folds", type=int, default=5) parser.add_argument("--gpu", type=int, default=0) args = parser.parse_args() with open(args.config) as f: config = yaml.safe_load(f) device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu") if args.model == "all": all_cv = {} for model_name in config["models"].keys(): all_cv[model_name] = run_kfold_cv(model_name, config, args.folds, device) # Save comparative CV results save_path = os.path.join(config["paths"]["reports"], "all_cv_results.json") with open(save_path, "w") as f: json.dump(all_cv, f, indent=2) else: run_kfold_cv(args.model, config, args.folds, device) if __name__ == "__main__": main()