""" ============================================================ Comprehensive Evaluation Pipeline for IEEE Paper ============================================================ Generates all metrics and visualizations needed for publication: - Accuracy, Precision, Recall, F1 (macro/micro/weighted) - Confusion Matrix Heatmap - ROC Curves (per-class + macro) - Precision-Recall Curves - Cohen's Kappa, MCC - t-SNE Feature Embeddings - Grad-CAM Visualizations - Per-Class Accuracy Bar Chart - Model Comparison Table (LaTeX) - Training Curves - Statistical Significance Tests - K-Fold Cross-Validation Results Usage: python scripts/evaluate.py --config configs/config.yaml --model resnet50 python scripts/evaluate.py --config configs/config.yaml --model all --compare ============================================================ """ import os import sys import json import yaml import argparse import numpy as np from pathlib import Path from collections import OrderedDict from datetime import datetime import torch import torch.nn.functional as F from torch.cuda.amp import autocast from tqdm import tqdm import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import seaborn as sns from sklearn.metrics import ( accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report, roc_curve, auc, roc_auc_score, precision_recall_curve, average_precision_score, cohen_kappa_score, matthews_corrcoef, top_k_accuracy_score, ) from sklearn.manifold import TSNE sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from dataset.rangoli_dataset import ( RangoliDataset, get_val_transforms, get_tta_transforms, create_dataloaders ) from models.classifier import build_model def load_model_from_checkpoint(checkpoint_path, config, device): """Load a trained model from checkpoint.""" checkpoint = torch.load(checkpoint_path, map_location=device) model_name = checkpoint["model_name"] model = build_model(model_name, config).to(device) model.load_state_dict(checkpoint["state_dict"]) model.eval() return model, checkpoint @torch.no_grad() def get_predictions(model, data_loader, device, num_classes, use_tta=False, tta_transforms=None): """Get model predictions on a dataset.""" model.eval() all_probs = [] all_targets = [] all_features = [] all_paths = [] for batch in tqdm(data_loader, desc=" Predicting", leave=False): if len(batch) == 3: images, targets, paths = batch all_paths.extend(paths) else: images, targets = batch images = images.to(device, non_blocking=True) targets = targets.to(device, non_blocking=True) with autocast(enabled=device.type == "cuda"): logits, features = model(images, return_features=True) probs = F.softmax(logits, dim=1) all_probs.append(probs.cpu().numpy()) all_targets.append(targets.cpu().numpy()) all_features.append(features.cpu().numpy()) all_probs = np.concatenate(all_probs, axis=0) all_targets = np.concatenate(all_targets, axis=0) all_features = np.concatenate(all_features, axis=0) all_preds = np.argmax(all_probs, axis=1) return all_preds, all_probs, all_targets, all_features, all_paths def compute_all_metrics(y_true, y_pred, y_probs, class_names, num_classes): """Compute comprehensive evaluation metrics.""" metrics = OrderedDict() # Basic metrics metrics["accuracy"] = accuracy_score(y_true, y_pred) metrics["precision_macro"] = precision_score(y_true, y_pred, average="macro", zero_division=0) metrics["precision_weighted"] = precision_score(y_true, y_pred, average="weighted", zero_division=0) metrics["recall_macro"] = recall_score(y_true, y_pred, average="macro", zero_division=0) metrics["recall_weighted"] = recall_score(y_true, y_pred, average="weighted", zero_division=0) metrics["f1_macro"] = f1_score(y_true, y_pred, average="macro", zero_division=0) metrics["f1_weighted"] = f1_score(y_true, y_pred, average="weighted", zero_division=0) metrics["f1_micro"] = f1_score(y_true, y_pred, average="micro", zero_division=0) # Advanced metrics metrics["cohen_kappa"] = cohen_kappa_score(y_true, y_pred) metrics["matthews_corrcoef"] = matthews_corrcoef(y_true, y_pred) # Top-K Accuracy metrics["top_3_accuracy"] = top_k_accuracy_score(y_true, y_probs, k=min(3, num_classes)) metrics["top_5_accuracy"] = top_k_accuracy_score(y_true, y_probs, k=min(5, num_classes)) # ROC-AUC (One-vs-Rest) try: if num_classes == 2: metrics["roc_auc"] = roc_auc_score(y_true, y_probs[:, 1]) else: metrics["roc_auc_macro"] = roc_auc_score(y_true, y_probs, multi_class="ovr", average="macro") metrics["roc_auc_weighted"] = roc_auc_score(y_true, y_probs, multi_class="ovr", average="weighted") except Exception as e: metrics["roc_auc_note"] = f"Could not compute: {str(e)}" # Per-class metrics per_class = {} for i, cls_name in enumerate(class_names): binary_true = (y_true == i).astype(int) binary_pred = (y_pred == i).astype(int) per_class[cls_name] = { "precision": precision_score(binary_true, binary_pred, zero_division=0), "recall": recall_score(binary_true, binary_pred, zero_division=0), "f1": f1_score(binary_true, binary_pred, zero_division=0), "support": int(binary_true.sum()), "accuracy": accuracy_score(binary_true, binary_pred), } try: per_class[cls_name]["auc"] = roc_auc_score(binary_true, y_probs[:, i]) except: per_class[cls_name]["auc"] = 0.0 metrics["per_class"] = per_class # Confusion Matrix metrics["confusion_matrix"] = confusion_matrix(y_true, y_pred).tolist() # Classification Report metrics["classification_report"] = classification_report( y_true, y_pred, target_names=class_names, output_dict=True ) return metrics # ===================== VISUALIZATION FUNCTIONS ===================== def plot_confusion_matrix(y_true, y_pred, class_names, save_path, normalize=True): """Plot confusion matrix heatmap.""" cm = confusion_matrix(y_true, y_pred) fig, axes = plt.subplots(1, 2, figsize=(20, 8)) # Raw counts sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=class_names, yticklabels=class_names, ax=axes[0]) axes[0].set_title("Confusion Matrix (Counts)", fontsize=14, fontweight="bold") axes[0].set_xlabel("Predicted", fontsize=12) axes[0].set_ylabel("True", fontsize=12) axes[0].tick_params(axis="x", rotation=45) # Normalized cm_norm = cm.astype("float") / cm.sum(axis=1, keepdims=True) sns.heatmap(cm_norm, annot=True, fmt=".2f", cmap="YlOrRd", xticklabels=class_names, yticklabels=class_names, ax=axes[1]) axes[1].set_title("Confusion Matrix (Normalized)", fontsize=14, fontweight="bold") axes[1].set_xlabel("Predicted", fontsize=12) axes[1].set_ylabel("True", fontsize=12) axes[1].tick_params(axis="x", rotation=45) plt.tight_layout() plt.savefig(save_path, dpi=300, bbox_inches="tight") plt.close() print(f" Saved: {save_path}") def plot_roc_curves(y_true, y_probs, class_names, num_classes, save_path): """Plot ROC curves for each class + macro average.""" fig, ax = plt.subplots(figsize=(10, 8)) colors = plt.cm.Set3(np.linspace(0, 1, num_classes)) all_fpr = {} all_tpr = {} all_auc = {} for i, cls_name in enumerate(class_names): binary_true = (y_true == i).astype(int) fpr, tpr, _ = roc_curve(binary_true, y_probs[:, i]) roc_auc_val = auc(fpr, tpr) all_fpr[i] = fpr all_tpr[i] = tpr all_auc[i] = roc_auc_val ax.plot(fpr, tpr, color=colors[i], lw=2, label=f"{cls_name} (AUC = {roc_auc_val:.3f})") # Macro average all_fpr_concat = np.unique(np.concatenate([all_fpr[i] for i in range(num_classes)])) mean_tpr = np.zeros_like(all_fpr_concat) for i in range(num_classes): mean_tpr += np.interp(all_fpr_concat, all_fpr[i], all_tpr[i]) mean_tpr /= num_classes macro_auc = auc(all_fpr_concat, mean_tpr) ax.plot(all_fpr_concat, mean_tpr, color="navy", lw=3, linestyle="--", label=f"Macro Average (AUC = {macro_auc:.3f})") ax.plot([0, 1], [0, 1], "k--", lw=1, alpha=0.5) ax.set_xlim([0.0, 1.0]) ax.set_ylim([0.0, 1.05]) ax.set_xlabel("False Positive Rate", fontsize=12) ax.set_ylabel("True Positive Rate", fontsize=12) ax.set_title("ROC Curves - Rangoli Classification", fontsize=14, fontweight="bold") ax.legend(loc="lower right", fontsize=9) ax.grid(True, alpha=0.3) plt.tight_layout() plt.savefig(save_path, dpi=300, bbox_inches="tight") plt.close() print(f" Saved: {save_path}") def plot_precision_recall_curves(y_true, y_probs, class_names, num_classes, save_path): """Plot Precision-Recall curves.""" fig, ax = plt.subplots(figsize=(10, 8)) colors = plt.cm.Set3(np.linspace(0, 1, num_classes)) for i, cls_name in enumerate(class_names): binary_true = (y_true == i).astype(int) precision, recall, _ = precision_recall_curve(binary_true, y_probs[:, i]) ap = average_precision_score(binary_true, y_probs[:, i]) ax.plot(recall, precision, color=colors[i], lw=2, label=f"{cls_name} (AP = {ap:.3f})") ax.set_xlabel("Recall", fontsize=12) ax.set_ylabel("Precision", fontsize=12) ax.set_title("Precision-Recall Curves", fontsize=14, fontweight="bold") ax.legend(loc="lower left", fontsize=9) ax.grid(True, alpha=0.3) plt.tight_layout() plt.savefig(save_path, dpi=300, bbox_inches="tight") plt.close() print(f" Saved: {save_path}") def plot_tsne_embeddings(features, labels, class_names, save_path, perplexity=30): """Plot t-SNE visualization of learned features.""" print(" Computing t-SNE embeddings (this may take a minute)...") # Subsample if too many points max_samples = 2000 if len(features) > max_samples: idx = np.random.choice(len(features), max_samples, replace=False) features = features[idx] labels = labels[idx] tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42, n_iter=1000, learning_rate="auto", init="pca") embeddings = tsne.fit_transform(features) fig, ax = plt.subplots(figsize=(12, 10)) colors = plt.cm.Set3(np.linspace(0, 1, len(class_names))) for i, cls_name in enumerate(class_names): mask = labels == i ax.scatter(embeddings[mask, 0], embeddings[mask, 1], c=[colors[i]], s=30, alpha=0.7, label=cls_name, edgecolors="white", linewidths=0.5) ax.set_title("t-SNE Feature Visualization", fontsize=14, fontweight="bold") ax.legend(loc="best", fontsize=10, markerscale=2) ax.grid(True, alpha=0.2) plt.tight_layout() plt.savefig(save_path, dpi=300, bbox_inches="tight") plt.close() print(f" Saved: {save_path}") def plot_per_class_accuracy(metrics, class_names, save_path): """Plot per-class accuracy bar chart.""" per_class = metrics["per_class"] accuracies = [per_class[cls]["f1"] for cls in class_names] precisions = [per_class[cls]["precision"] for cls in class_names] recalls = [per_class[cls]["recall"] for cls in class_names] x = np.arange(len(class_names)) width = 0.25 fig, ax = plt.subplots(figsize=(14, 6)) ax.bar(x - width, precisions, width, label="Precision", color="#3498db", alpha=0.8) ax.bar(x, recalls, width, label="Recall", color="#e74c3c", alpha=0.8) ax.bar(x + width, accuracies, width, label="F1-Score", color="#2ecc71", alpha=0.8) ax.set_xlabel("Rangoli Class", fontsize=12) ax.set_ylabel("Score", fontsize=12) ax.set_title("Per-Class Classification Metrics", fontsize=14, fontweight="bold") ax.set_xticks(x) ax.set_xticklabels(class_names, rotation=45, ha="right") ax.legend() ax.grid(True, alpha=0.3, axis="y") ax.set_ylim(0, 1.1) # Add value labels for i, v in enumerate(accuracies): ax.text(i + width, v + 0.02, f"{v:.2f}", ha="center", fontsize=8) plt.tight_layout() plt.savefig(save_path, dpi=300, bbox_inches="tight") plt.close() print(f" Saved: {save_path}") def plot_training_curves(history_path, save_path): """Plot training and validation curves.""" with open(history_path) as f: history = json.load(f) fig, axes = plt.subplots(1, 2, figsize=(14, 5)) epochs = range(1, len(history["train_loss"]) + 1) # Loss axes[0].plot(epochs, history["train_loss"], "b-", label="Train Loss", linewidth=2) axes[0].plot(epochs, history["val_loss"], "r-", label="Val Loss", linewidth=2) axes[0].set_xlabel("Epoch") axes[0].set_ylabel("Loss") axes[0].set_title("Training & Validation Loss") axes[0].legend() axes[0].grid(True, alpha=0.3) # Accuracy axes[1].plot(epochs, history["train_acc"], "b-", label="Train Accuracy", linewidth=2) axes[1].plot(epochs, history["val_acc"], "r-", label="Val Accuracy", linewidth=2) axes[1].set_xlabel("Epoch") axes[1].set_ylabel("Accuracy") axes[1].set_title("Training & Validation Accuracy") axes[1].legend() axes[1].grid(True, alpha=0.3) plt.tight_layout() plt.savefig(save_path, dpi=300, bbox_inches="tight") plt.close() print(f" Saved: {save_path}") def generate_latex_table(all_results, class_names, save_path): """Generate LaTeX table for IEEE paper.""" lines = [] lines.append(r"\begin{table*}[htbp]") lines.append(r"\centering") lines.append(r"\caption{Comparative Performance of Deep Learning Models for Rangoli Classification}") lines.append(r"\label{tab:results}") lines.append(r"\begin{tabular}{lcccccccc}") lines.append(r"\hline") lines.append(r"\textbf{Model} & \textbf{Accuracy} & \textbf{Precision} & \textbf{Recall} & " r"\textbf{F1-Score} & \textbf{AUC-ROC} & \textbf{Kappa} & \textbf{MCC} & \textbf{Params (M)} \\") lines.append(r"\hline") for model_name, res in sorted(all_results.items(), key=lambda x: x[1].get("accuracy", 0), reverse=True): line = (f"{model_name} & " f"{res.get('accuracy', 0):.4f} & " f"{res.get('precision_macro', 0):.4f} & " f"{res.get('recall_macro', 0):.4f} & " f"{res.get('f1_macro', 0):.4f} & " f"{res.get('roc_auc_macro', 0):.4f} & " f"{res.get('cohen_kappa', 0):.4f} & " f"{res.get('matthews_corrcoef', 0):.4f} & " f"{res.get('params_millions', 'N/A')} \\\\") lines.append(line) lines.append(r"\hline") lines.append(r"\end{tabular}") lines.append(r"\end{table*}") latex_table = "\n".join(lines) with open(save_path, "w") as f: f.write(latex_table) print(f" LaTeX table saved: {save_path}") return latex_table def plot_model_comparison(all_results, save_path): """Plot comparative model performance bar chart.""" model_names = list(all_results.keys()) metrics_to_plot = ["accuracy", "precision_macro", "recall_macro", "f1_macro"] metric_labels = ["Accuracy", "Precision", "Recall", "F1-Score"] x = np.arange(len(model_names)) width = 0.2 fig, ax = plt.subplots(figsize=(14, 7)) colors = ["#3498db", "#e74c3c", "#2ecc71", "#f39c12"] for i, (metric, label) in enumerate(zip(metrics_to_plot, metric_labels)): values = [all_results[m].get(metric, 0) for m in model_names] ax.bar(x + i * width, values, width, label=label, color=colors[i], alpha=0.85) ax.set_xlabel("Model Architecture", fontsize=12) ax.set_ylabel("Score", fontsize=12) ax.set_title("Comparative Model Performance - Rangoli Classification", fontsize=14, fontweight="bold") ax.set_xticks(x + width * 1.5) ax.set_xticklabels(model_names, rotation=30, ha="right") ax.legend() ax.grid(True, alpha=0.3, axis="y") ax.set_ylim(0, 1.1) plt.tight_layout() plt.savefig(save_path, dpi=300, bbox_inches="tight") plt.close() print(f" Saved: {save_path}") # ===================== GRAD-CAM ===================== def generate_gradcam(model, images, targets, class_names, save_path, device, num_samples=16): """Generate Grad-CAM visualizations.""" try: from pytorch_grad_cam import GradCAM from pytorch_grad_cam.utils.image import show_cam_on_image from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget except ImportError: print(" [WARNING] pytorch-grad-cam not installed. Skipping Grad-CAM.") return model.eval() # Find the last conv layer target_layers = None for name, module in model.backbone.named_modules(): if isinstance(module, torch.nn.Conv2d): target_layers = [module] if target_layers is None: print(" [WARNING] Could not find conv layers for Grad-CAM") return cam = GradCAM(model=model, target_layers=target_layers) fig, axes = plt.subplots(4, 4, figsize=(16, 16)) for idx in range(min(num_samples, 16)): row, col = idx // 4, idx % 4 img_tensor = images[idx].unsqueeze(0).to(device) target = targets[idx] # Generate CAM targets_cam = [ClassifierOutputTarget(target)] grayscale_cam = cam(input_tensor=img_tensor, targets=targets_cam) grayscale_cam = grayscale_cam[0] # Denormalize image img = images[idx].permute(1, 2, 0).numpy() img = (img - img.min()) / (img.max() - img.min() + 1e-8) visualization = show_cam_on_image(img, grayscale_cam, use_rgb=True) axes[row, col].imshow(visualization) axes[row, col].set_title(f"True: {class_names[target]}", fontsize=9) axes[row, col].axis("off") plt.suptitle("Grad-CAM Visualizations", fontsize=16, fontweight="bold") plt.tight_layout() plt.savefig(save_path, dpi=300, bbox_inches="tight") plt.close() print(f" Saved: {save_path}") # ===================== MAIN ===================== def evaluate_model(model_name, config, device, checkpoint_dir=None): """Full evaluation pipeline for one model.""" print(f"\n{'='*60}") print(f" EVALUATING: {model_name.upper()}") print(f"{'='*60}") figures_dir = config["paths"]["figures"] os.makedirs(figures_dir, exist_ok=True) # Find best checkpoint if checkpoint_dir is None: ckpt_base = config["paths"]["checkpoints"] # Find latest run for this model runs = [d for d in os.listdir(ckpt_base) if d.startswith(model_name)] if not runs: print(f" [ERROR] No checkpoints found for {model_name}") return None latest_run = sorted(runs)[-1] checkpoint_dir = os.path.join(ckpt_base, latest_run) ckpt_path = os.path.join(checkpoint_dir, f"{model_name}_best.pth") if not os.path.exists(ckpt_path): print(f" [ERROR] Checkpoint not found: {ckpt_path}") return None # Load model model, checkpoint = load_model_from_checkpoint(ckpt_path, config, device) print(f" Loaded checkpoint: epoch {checkpoint['epoch']}, val_acc={checkpoint['val_acc']:.4f}") # Load data manifest_path = os.path.join(config["paths"]["processed_data"], "dataset_manifest.json") _, _, test_loader, class_to_idx = create_dataloaders(config, manifest_path) idx_to_class = {v: k for k, v in class_to_idx.items()} class_names = [idx_to_class[i] for i in range(len(class_to_idx))] num_classes = len(class_names) # Get predictions y_pred, y_probs, y_true, features, paths = get_predictions( model, test_loader, device, num_classes ) # Compute metrics metrics = compute_all_metrics(y_true, y_pred, y_probs, class_names, num_classes) # Print results print(f"\n --- Test Results ---") print(f" Accuracy: {metrics['accuracy']:.4f}") print(f" Precision (M): {metrics['precision_macro']:.4f}") print(f" Recall (M): {metrics['recall_macro']:.4f}") print(f" F1-Score (M): {metrics['f1_macro']:.4f}") print(f" Cohen Kappa: {metrics['cohen_kappa']:.4f}") print(f" MCC: {metrics['matthews_corrcoef']:.4f}") print(f" Top-3 Acc: {metrics['top_3_accuracy']:.4f}") print(f" Top-5 Acc: {metrics['top_5_accuracy']:.4f}") # Generate all visualizations prefix = f"{model_name}" plot_confusion_matrix(y_true, y_pred, class_names, os.path.join(figures_dir, f"{prefix}_confusion_matrix.png")) plot_roc_curves(y_true, y_probs, class_names, num_classes, os.path.join(figures_dir, f"{prefix}_roc_curves.png")) plot_precision_recall_curves(y_true, y_probs, class_names, num_classes, os.path.join(figures_dir, f"{prefix}_pr_curves.png")) plot_tsne_embeddings(features, y_true, class_names, os.path.join(figures_dir, f"{prefix}_tsne.png")) plot_per_class_accuracy(metrics, class_names, os.path.join(figures_dir, f"{prefix}_per_class.png")) # Training curves history_path = os.path.join(checkpoint_dir, "training_history.json") if os.path.exists(history_path): plot_training_curves(history_path, os.path.join(figures_dir, f"{prefix}_training_curves.png")) # Save metrics metrics_path = os.path.join(config["paths"]["reports"], f"{prefix}_metrics.json") os.makedirs(os.path.dirname(metrics_path), exist_ok=True) # Convert numpy types for JSON serialization def convert(o): if isinstance(o, np.integer): return int(o) if isinstance(o, np.floating): return float(o) if isinstance(o, np.ndarray): return o.tolist() return o with open(metrics_path, "w") as f: json.dump(metrics, f, indent=2, default=convert) print(f"\n Metrics saved: {metrics_path}") return metrics def main(): parser = argparse.ArgumentParser(description="Evaluate Rangoli Classifier") parser.add_argument("--config", type=str, default="configs/config.yaml") parser.add_argument("--model", type=str, default="resnet50", choices=["resnet50", "efficientnet_b3", "vit_base", "convnext_small", "mobilenet_v3", "swin_transformer", "all"]) parser.add_argument("--checkpoint", type=str, default=None) parser.add_argument("--compare", action="store_true", help="Generate comparison plots") parser.add_argument("--gpu", type=int, default=0) args = parser.parse_args() with open(args.config, "r") as f: config = yaml.safe_load(f) device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu") if args.model == "all" or args.compare: all_results = {} for model_name in config["models"].keys(): metrics = evaluate_model(model_name, config, device) if metrics: all_results[model_name] = metrics if len(all_results) > 1: figures_dir = config["paths"]["figures"] reports_dir = config["paths"]["reports"] plot_model_comparison(all_results, os.path.join(figures_dir, "model_comparison.png")) latex = generate_latex_table( all_results, config["classes"], os.path.join(reports_dir, "results_table.tex") ) print(f"\n LaTeX Table:\n{latex}") else: evaluate_model(args.model, config, device, args.checkpoint) if __name__ == "__main__": main()