Spaces:
Runtime error
Runtime error
| """ | |
| ============================================================ | |
| K-Fold Stratified Cross-Validation | |
| ============================================================ | |
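| Trains a fresh model on each fold and reports the mean ± std of every | |
| metric across folds, giving the statistical rigor expected in IEEE publications. | |
| Usage (pass --model all to cross-validate every model listed in the config): | |
| python scripts/cross_validate.py --config configs/config.yaml --model resnet50 --folds 5 | |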
| ============================================================ | |
| """ | |
| import os | |
| import sys | |
| import json | |
| import yaml | |
| import argparse | |
| import numpy as np | |
| from datetime import datetime | |
| from copy import deepcopy | |
| import torch | |
| import torch.nn as nn | |
| from torch.utils.data import DataLoader, Subset | |
| from sklearn.model_selection import StratifiedKFold | |
| from tqdm import tqdm | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from dataset.rangoli_dataset import RangoliDataset, get_train_transforms, get_val_transforms | |
| from models.classifier import build_model, build_loss_function | |
| from scripts.train import train_one_epoch, validate, get_optimizer, get_scheduler, EarlyStopping | |
| from scripts.evaluate import compute_all_metrics | |
def _train_fold(model, train_loader, val_loader, criterion, config, device):
    """Train ``model`` in place for one fold and restore its best-epoch weights.

    "Best" means the highest validation accuracy returned by ``validate``.
    Training may stop early when ``EarlyStopping`` signals it.

    Returns:
        The best validation accuracy seen, or ``-inf`` if no epoch ran.
    """
    optimizer = get_optimizer(model, config)
    scheduler = get_scheduler(optimizer, config)
    scaler = torch.cuda.amp.GradScaler(enabled=device.type == "cuda")
    early_stopping = EarlyStopping(patience=config["training"]["early_stopping_patience"])

    # Start below any real accuracy so the first epoch's weights are always
    # captured, even when val_acc is exactly 0.
    best_val_acc = float("-inf")
    best_state = None
    for epoch in range(config["training"]["num_epochs"]):
        train_one_epoch(
            model, train_loader, criterion, optimizer, scheduler,
            scaler, None, device, epoch, config
        )
        _, val_acc, _, _ = validate(model, val_loader, criterion, device)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_state = deepcopy(model.state_dict())
        if early_stopping(val_acc):
            print(f" Early stopping at epoch {epoch+1}")
            break
        if (epoch + 1) % 10 == 0:
            print(f" Epoch {epoch+1}: val_acc={val_acc:.4f} (best={best_val_acc:.4f})")

    # With zero epochs there is nothing to restore; keep the fresh weights
    # instead of crashing on load_state_dict(None).
    if best_state is not None:
        model.load_state_dict(best_state)
    return best_val_acc


def _predict_proba(model, loader, device):
    """Return ``(softmax probabilities, true labels)`` for every batch in ``loader``.

    Both results are numpy arrays concatenated over batches. The label tensors
    are assumed to be CPU tensors as yielded by the DataLoader.
    """
    model.eval()
    probs, labels = [], []
    with torch.no_grad():
        for images, targets_batch in loader:
            outputs = model(images.to(device))
            probs.append(torch.softmax(outputs, dim=1).cpu().numpy())
            labels.append(targets_batch.numpy())
    return np.concatenate(probs), np.concatenate(labels)


def _summarize_folds(fold_results, metric_keys):
    """Aggregate per-fold metric dicts into ``{metric: {"mean": m, "std": s}}``.

    Metrics missing from every fold are omitted. ``std`` is the population
    standard deviation (numpy default, ``ddof=0``).
    """
    summary = {}
    for key in metric_keys:
        values = [r[key] for r in fold_results if key in r]
        if values:
            summary[key] = {"mean": float(np.mean(values)), "std": float(np.std(values))}
    return summary


def run_kfold_cv(model_name, config, n_folds=5, device=None):
    """Run stratified k-fold cross-validation for one model.

    For each fold a fresh model is built, trained with the project's training
    utilities, restored to its best-validation-accuracy weights, and scored with
    ``compute_all_metrics``. Fold metrics are aggregated to mean ± std, printed,
    and written to ``<config["paths"]["reports"]>/<model_name>_cv_results.json``.

    Args:
        model_name: Name passed to ``build_model``.
        config: Parsed YAML config (reads the ``paths`` and ``training`` sections).
        n_folds: Number of stratified folds.
        device: Torch device; defaults to CUDA when available, else CPU.

    Returns:
        dict mapping metric name to ``{"mean": float, "std": float}``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"\n{'#'*60}")
    print(f" {n_folds}-FOLD CROSS-VALIDATION: {model_name.upper()}")
    print(f"{'#'*60}")

    data_dir = config["paths"]["train_dir"]
    batch_size = config["training"]["batch_size"]
    num_workers = config["training"].get("num_workers", 4)
    pin_memory = device.type == "cuda"  # pinning only helps host->GPU copies

    # Combined train+val split: its sample order defines the index space
    # that StratifiedKFold partitions.
    full_dataset = RangoliDataset(data_dir, split="train_cv", transform=None)
    targets = full_dataset.targets
    class_to_idx = full_dataset.class_to_idx

    # Fold indices refer to full_dataset, so both views must come from the same
    # "train_cv" split and differ only in their transform. Indexing into the
    # separate "train"/"val" splits would select unrelated samples.
    # NOTE(review): assumes RangoliDataset enumerates samples in the same order
    # for identical (root, split) arguments -- confirm.
    train_view = RangoliDataset(data_dir, split="train_cv",
                                transform=get_train_transforms(config),
                                class_to_idx=class_to_idx)
    val_view = RangoliDataset(data_dir, split="train_cv",
                              transform=get_val_transforms(config),
                              class_to_idx=class_to_idx)

    class_names = [full_dataset.idx_to_class[i] for i in range(full_dataset.num_classes)]
    skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
    fold_results = []

    for fold, (train_idx, val_idx) in enumerate(skf.split(np.zeros(len(targets)), targets)):
        print(f"\n --- Fold {fold+1}/{n_folds} ---")
        print(f" Train: {len(train_idx)} | Val: {len(val_idx)}")

        train_loader = DataLoader(Subset(train_view, train_idx), batch_size=batch_size,
                                  shuffle=True, num_workers=num_workers,
                                  pin_memory=pin_memory, drop_last=True)
        val_loader = DataLoader(Subset(val_view, val_idx), batch_size=batch_size,
                                shuffle=False, num_workers=num_workers,
                                pin_memory=pin_memory)

        model = build_model(model_name, config).to(device)
        criterion = build_loss_function(config, device=device)
        _train_fold(model, train_loader, val_loader, criterion, config, device)

        # Single evaluation pass of the best weights; probabilities feed AUC/top-k.
        all_probs, all_true = _predict_proba(model, val_loader, device)
        all_preds = np.argmax(all_probs, axis=1)
        fold_metrics = compute_all_metrics(
            all_true, all_preds, all_probs, class_names, full_dataset.num_classes
        )
        fold_results.append(fold_metrics)
        print(f" Fold {fold+1}: Acc={fold_metrics['accuracy']:.4f}, "
              f"F1={fold_metrics['f1_macro']:.4f}, "
              f"Kappa={fold_metrics['cohen_kappa']:.4f}")

    metric_keys = ["accuracy", "precision_macro", "recall_macro", "f1_macro",
                   "cohen_kappa", "matthews_corrcoef", "top_3_accuracy"]
    cv_summary = _summarize_folds(fold_results, metric_keys)

    print(f"\n{'='*60}")
    print(f" {n_folds}-FOLD CROSS-VALIDATION RESULTS: {model_name}")
    print(f"{'='*60}")
    for key, stats in cv_summary.items():
        print(f" {key:25s}: {stats['mean']:.4f} ± {stats['std']:.4f}")
    print(f"{'='*60}")

    reports_dir = config["paths"]["reports"]
    save_path = os.path.join(reports_dir, f"{model_name}_cv_results.json")
    os.makedirs(reports_dir or ".", exist_ok=True)
    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(cv_summary, f, indent=2)
    return cv_summary
def main():
    """CLI entry point: cross-validate one model, or every model with ``--model all``.

    With ``--model all``, each key of ``config["models"]`` is cross-validated and
    the per-model summaries are also written together to
    ``<config["paths"]["reports"]>/all_cv_results.json``.
    """
    parser = argparse.ArgumentParser(description="Stratified k-fold cross-validation")
    parser.add_argument("--config", type=str, default="configs/config.yaml",
                        help="Path to the YAML config file")
    parser.add_argument("--model", type=str, default="resnet50",
                        help='Model name, or "all" to run every model in config["models"]')
    parser.add_argument("--folds", type=int, default=5, help="Number of CV folds (>= 2)")
    parser.add_argument("--gpu", type=int, default=0, help="CUDA device index")
    args = parser.parse_args()
    if args.folds < 2:
        # Fail fast with a CLI error rather than deep inside StratifiedKFold.
        parser.error("--folds must be at least 2")

    with open(args.config, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")

    if args.model != "all":
        run_kfold_cv(args.model, config, args.folds, device)
        return

    all_cv = {
        model_name: run_kfold_cv(model_name, config, args.folds, device)
        for model_name in config["models"]
    }

    # run_kfold_cv creates the reports directory, but only if at least one
    # model ran; create it here so an empty model list does not break the save.
    reports_dir = config["paths"]["reports"]
    os.makedirs(reports_dir or ".", exist_ok=True)
    save_path = os.path.join(reports_dir, "all_cv_results.json")
    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(all_cv, f, indent=2)


if __name__ == "__main__":
    main()