# File: rangoli-classifier/scripts/cross_validate.py
# (Hugging Face repository page header: uploaded by shashidharak99, "Upload 16 files", commit 0b3dd07, verified)
"""
============================================================
K-Fold Stratified Cross-Validation
============================================================
For statistical rigor required in IEEE publications.
Reports mean ± std across folds.
Usage:
python scripts/cross_validate.py --config configs/config.yaml --model resnet50 --folds 5
============================================================
"""
import os
import sys
import json
import yaml
import argparse
import numpy as np
from datetime import datetime
from copy import deepcopy
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import StratifiedKFold
from tqdm import tqdm
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from dataset.rangoli_dataset import RangoliDataset, get_train_transforms, get_val_transforms
from models.classifier import build_model, build_loss_function
from scripts.train import train_one_epoch, validate, get_optimizer, get_scheduler, EarlyStopping
from scripts.evaluate import compute_all_metrics
def run_kfold_cv(model_name, config, n_folds=5, device=None, num_workers=4, seed=42):
    """Run stratified k-fold cross-validation for one model.

    For every fold a fresh model is built, trained with early stopping, reset
    to the weights that gave the best validation accuracy, and scored on the
    fold's held-out samples. Per-metric mean and standard deviation across
    folds are printed, written to ``<reports>/<model_name>_cv_results.json``
    and returned.

    Args:
        model_name: Name passed to ``build_model`` and used in the report
            file name.
        config: Parsed configuration dict. Reads ``paths.train_dir``,
            ``paths.reports``, ``training.batch_size``,
            ``training.num_epochs`` and
            ``training.early_stopping_patience`` (plus whatever the imported
            helpers read).
        n_folds: Number of stratified folds.
        device: ``torch.device`` to train on. Defaults to CUDA when
            available, otherwise CPU.
        num_workers: Worker processes for each ``DataLoader``.
        seed: ``random_state`` for the stratified fold shuffling.

    Returns:
        Dict mapping each aggregated metric name to
        ``{"mean": float, "std": float}``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"\n{'#'*60}")
    print(f" {n_folds}-FOLD CROSS-VALIDATION: {model_name.upper()}")
    print(f"{'#'*60}")

    train_dir = config["paths"]["train_dir"]
    batch_size = config["training"]["batch_size"]

    # Transform-free dataset: used only for labels and the class mapping.
    # NOTE(review): only ``train_dir`` is read here; if train and val are
    # meant to be pooled for CV, confirm that split="train_cv" does that.
    full_dataset = RangoliDataset(train_dir, split="train_cv", transform=None)
    targets = full_dataset.targets

    # The two transformed views are identical for every fold, so build them
    # once instead of re-creating them inside the loop.
    # NOTE(review): fold indices are computed on the "train_cv" dataset and
    # applied to these "train"/"val" datasets; this assumes all three list the
    # samples in the same order - confirm in RangoliDataset.
    train_view = RangoliDataset(
        train_dir, split="train",
        transform=get_train_transforms(config),
        class_to_idx=full_dataset.class_to_idx,
    )
    val_view = RangoliDataset(
        train_dir, split="val",
        transform=get_val_transforms(config),
        class_to_idx=full_dataset.class_to_idx,
    )
    class_names = [full_dataset.idx_to_class[i] for i in range(full_dataset.num_classes)]

    skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed)
    fold_results = []

    # StratifiedKFold only needs the labels; the feature matrix is a dummy.
    for fold, (train_idx, val_idx) in enumerate(skf.split(np.zeros(len(targets)), targets)):
        print(f"\n --- Fold {fold+1}/{n_folds} ---")
        print(f" Train: {len(train_idx)} | Val: {len(val_idx)}")

        train_loader = DataLoader(
            Subset(train_view, train_idx), batch_size=batch_size,
            shuffle=True, num_workers=num_workers, pin_memory=True, drop_last=True,
        )
        val_loader = DataLoader(
            Subset(val_view, val_idx), batch_size=batch_size,
            shuffle=False, num_workers=num_workers, pin_memory=True,
        )

        # Fresh model per fold so no weights leak between folds.
        model = build_model(model_name, config).to(device)
        _train_fold(model, train_loader, val_loader, config, device)

        # Score the best weights on the held-out part of this fold.
        all_probs, all_true = _collect_probabilities(model, val_loader, device)
        all_preds_np = np.argmax(all_probs, axis=1)
        fold_metrics = compute_all_metrics(
            all_true, all_preds_np, all_probs, class_names, full_dataset.num_classes
        )
        fold_results.append(fold_metrics)
        print(f" Fold {fold+1}: Acc={fold_metrics['accuracy']:.4f}, "
              f"F1={fold_metrics['f1_macro']:.4f}, "
              f"Kappa={fold_metrics['cohen_kappa']:.4f}")

    # Aggregate results
    metric_keys = ["accuracy", "precision_macro", "recall_macro", "f1_macro",
                   "cohen_kappa", "matthews_corrcoef", "top_3_accuracy"]
    cv_summary = _summarize_folds(fold_results, metric_keys)

    print(f"\n{'='*60}")
    print(f" {n_folds}-FOLD CROSS-VALIDATION RESULTS: {model_name}")
    print(f"{'='*60}")
    for key, stats in cv_summary.items():
        print(f" {key:25s}: {stats['mean']:.4f} ± {stats['std']:.4f}")
    print(f"{'='*60}")

    # Save. An empty reports path means "current directory"; makedirs("")
    # would raise, so fall back to ".".
    reports_dir = config["paths"]["reports"]
    os.makedirs(reports_dir or ".", exist_ok=True)
    save_path = os.path.join(reports_dir, f"{model_name}_cv_results.json")
    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(cv_summary, f, indent=2)
    return cv_summary


def _train_fold(model, train_loader, val_loader, config, device):
    """Train ``model`` on one fold and leave it holding its best-validation weights."""
    criterion = build_loss_function(config, device=device)
    optimizer = get_optimizer(model, config)
    scheduler = get_scheduler(optimizer, config)
    # Gradient scaling is only enabled when training on CUDA.
    scaler = torch.cuda.amp.GradScaler(enabled=device.type == "cuda")
    early_stopping = EarlyStopping(patience=config["training"]["early_stopping_patience"])

    # Start below any possible accuracy so the first finite epoch is always
    # recorded, even when validation accuracy is exactly 0.
    best_val_acc = float("-inf")
    best_state = None
    for epoch in range(config["training"]["num_epochs"]):
        # NOTE(review): the 7th positional argument is passed as None, as in
        # the original call; its meaning is defined in scripts/train.py.
        train_one_epoch(
            model, train_loader, criterion, optimizer, scheduler,
            scaler, None, device, epoch, config
        )
        _, val_acc, _, _ = validate(model, val_loader, criterion, device)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # deepcopy: state_dict() returns references to the live tensors,
            # which later epochs would overwrite.
            best_state = deepcopy(model.state_dict())

        if early_stopping(val_acc):
            print(f" Early stopping at epoch {epoch+1}")
            break

        if (epoch + 1) % 10 == 0:
            print(f" Epoch {epoch+1}: val_acc={val_acc:.4f} (best={best_val_acc:.4f})")

    # best_state is still None when no epoch ran or every accuracy was NaN;
    # in that case keep the current weights instead of crashing.
    if best_state is not None:
        model.load_state_dict(best_state)


def _collect_probabilities(model, loader, device):
    """Return ``(softmax probabilities, true labels)`` as arrays over all batches of ``loader``."""
    model.eval()
    batch_probs = []
    batch_labels = []
    with torch.no_grad():
        for images, labels in loader:
            logits = model(images.to(device))
            batch_probs.append(torch.softmax(logits, dim=1).cpu().numpy())
            batch_labels.append(labels.numpy())
    return np.concatenate(batch_probs), np.concatenate(batch_labels)


def _summarize_folds(fold_results, metric_keys):
    """Return ``{metric: {"mean", "std"}}`` for each requested metric present in the fold results."""
    summary = {}
    for key in metric_keys:
        # A metric may be absent from some folds; aggregate only where present.
        values = [result[key] for result in fold_results if key in result]
        if values:
            summary[key] = {"mean": float(np.mean(values)), "std": float(np.std(values))}
    return summary
def main():
    """Parse command-line arguments and run cross-validation.

    With ``--model all`` every model listed under ``config["models"]`` is
    cross-validated and the combined summaries are additionally written to
    ``<reports>/all_cv_results.json``; otherwise only the named model is run.
    """
    parser = argparse.ArgumentParser(
        description="Stratified k-fold cross-validation"
    )
    parser.add_argument("--config", type=str, default="configs/config.yaml",
                        help="path to the YAML configuration file")
    parser.add_argument("--model", type=str, default="resnet50",
                        help="model name, or 'all' for every model in the config")
    parser.add_argument("--folds", type=int, default=5,
                        help="number of stratified folds (at least 2)")
    parser.add_argument("--gpu", type=int, default=0,
                        help="CUDA device index used when CUDA is available")
    args = parser.parse_args()

    # Fail fast with a usage message instead of erroring later inside
    # StratifiedKFold, after the config and dataset have been loaded.
    if args.folds < 2:
        parser.error("--folds must be at least 2")

    with open(args.config, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")

    if args.model == "all":
        all_cv = {}
        for model_name in config["models"]:
            all_cv[model_name] = run_kfold_cv(model_name, config, args.folds, device)

        # Save comparative CV results. Create the directory here as well so
        # the write does not depend on run_kfold_cv having been called
        # (e.g. when config["models"] is empty).
        reports_dir = config["paths"]["reports"]
        os.makedirs(reports_dir or ".", exist_ok=True)
        save_path = os.path.join(reports_dir, "all_cv_results.json")
        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(all_cv, f, indent=2)
    else:
        run_kfold_cv(args.model, config, args.folds, device)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()