rangoli-classifier / scripts /evaluate.py
shashidharak99's picture
Upload 16 files
0b3dd07 verified
"""
============================================================
Comprehensive Evaluation Pipeline for IEEE Paper
============================================================
Generates all metrics and visualizations needed for publication:
- Accuracy, Precision, Recall, F1 (macro/micro/weighted)
- Confusion Matrix Heatmap
- ROC Curves (per-class + macro)
- Precision-Recall Curves
- Cohen's Kappa, MCC
- t-SNE Feature Embeddings
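- Grad-CAM Visualizations (generate_gradcam helper; not invoked by evaluate_model)
- Per-Class Precision/Recall/F1 Bar Chart
- Model Comparison Chart and Table (LaTeX)
- Training Curves
- Statistical Significance Tests (not yet implemented in this script)
- K-Fold Cross-Validation Results (not yet implemented in this script)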
Usage:
python scripts/evaluate.py --config configs/config.yaml --model resnet50
python scripts/evaluate.py --config configs/config.yaml --model all --compare
============================================================
"""
import os
import sys
import json
import yaml
import argparse
import numpy as np
from pathlib import Path
from collections import OrderedDict
from datetime import datetime
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
from tqdm import tqdm
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
accuracy_score, precision_score, recall_score, f1_score,
confusion_matrix, classification_report,
roc_curve, auc, roc_auc_score,
precision_recall_curve, average_precision_score,
cohen_kappa_score, matthews_corrcoef,
top_k_accuracy_score,
)
from sklearn.manifold import TSNE
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from dataset.rangoli_dataset import (
RangoliDataset, get_val_transforms, get_tta_transforms, create_dataloaders
)
from models.classifier import build_model
def load_model_from_checkpoint(checkpoint_path, config, device):
    """Rebuild a trained model from a checkpoint file.

    The checkpoint dict must contain ``"model_name"`` (passed to
    ``build_model`` together with ``config``) and ``"state_dict"``.

    Args:
        checkpoint_path: path to the file written during training.
        config: parsed configuration forwarded to ``build_model``.
        device: device the tensors are mapped to and the model is moved to.

    Returns:
        Tuple ``(model, checkpoint)``: the model in eval mode, and the raw
        checkpoint dict (so callers can read extra fields such as epoch).
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(checkpoint_path, map_location=device)
    net = build_model(ckpt["model_name"], config)
    net = net.to(device)
    net.load_state_dict(ckpt["state_dict"])
    net.eval()
    return net, ckpt
@torch.no_grad()
def get_predictions(model, data_loader, device, num_classes, use_tta=False,
                    tta_transforms=None):
    """Run inference over a data loader and collect predictions.

    Args:
        model: network whose forward, when called with
            ``return_features=True``, returns ``(logits, features)``.
        data_loader: iterable of ``(images, targets)`` or
            ``(images, targets, paths)`` batches; ``targets`` is a tensor.
        device: ``torch.device`` used for inference. Mixed precision
            (autocast) is enabled only when ``device.type == "cuda"``.
        num_classes: unused; kept for interface compatibility.
        use_tta: test-time augmentation is NOT implemented. When this is
            true a ``RuntimeWarning`` is emitted and plain inference is run.
        tta_transforms: unused; kept for interface compatibility.

    Returns:
        Tuple ``(preds, probs, targets, features, paths)`` of numpy arrays
        (``paths`` is a list): ``preds`` is the argmax of ``probs``,
        ``probs`` are softmax scores, ``features`` are float32, and
        ``paths`` is empty unless the batches carry paths.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    import warnings

    if use_tta:
        # Previously this flag was silently ignored; make that visible.
        warnings.warn(
            "get_predictions: use_tta=True but test-time augmentation is not "
            "implemented; running plain inference.",
            RuntimeWarning,
        )

    model.eval()
    amp_enabled = device.type == "cuda"
    all_probs, all_targets, all_features, all_paths = [], [], [], []

    for batch in tqdm(data_loader, desc=" Predicting", leave=False):
        if len(batch) == 3:
            images, targets, paths = batch
            all_paths.extend(paths)
        else:
            images, targets = batch
        images = images.to(device, non_blocking=True)
        with autocast(enabled=amp_enabled):
            logits, features = model(images, return_features=True)
            # Compute probabilities in float32 regardless of AMP precision.
            probs = F.softmax(logits.float(), dim=1)
        all_probs.append(probs.cpu().numpy())
        # Targets are only collected, never used on the device, so they are
        # not copied to the GPU and back.
        all_targets.append(targets.cpu().numpy())
        # Under AMP the features may be float16; store float32 for
        # downstream numerical work (e.g. t-SNE).
        all_features.append(features.float().cpu().numpy())

    if not all_probs:
        raise ValueError("get_predictions: data_loader yielded no batches")

    probs_arr = np.concatenate(all_probs, axis=0)
    targets_arr = np.concatenate(all_targets, axis=0)
    features_arr = np.concatenate(all_features, axis=0)
    preds_arr = np.argmax(probs_arr, axis=1)
    return preds_arr, probs_arr, targets_arr, features_arr, all_paths
def compute_all_metrics(y_true, y_pred, y_probs, class_names, num_classes):
    """Compute the evaluation metrics for one model.

    Args:
        y_true: true class indices, one per sample.
        y_pred: predicted class indices, one per sample.
        y_probs: per-class scores, shape (n_samples, num_classes). Rows
            should sum to 1 for the multiclass ROC-AUC.
        class_names: class names indexed by class id.
        num_classes: number of classes the model predicts.

    Returns:
        OrderedDict containing:
        - accuracy, and macro/weighted precision, recall and F1, plus micro F1;
        - Cohen's kappa, MCC, and top-3/top-5 accuracy;
        - ROC-AUC, or ``roc_auc_note`` if it could not be computed;
        - ``per_class`` stats. A per-class ``auc`` is ``None`` when it is
          undefined (class absent or present in every sample);
        - the confusion matrix as nested lists;
        - the sklearn classification report as a dict.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    y_probs = np.asarray(y_probs)
    # Fixed label set so results keep the same shape even when the test split
    # happens to miss some classes.
    all_labels = list(range(num_classes))

    metrics = OrderedDict()
    # Basic metrics
    metrics["accuracy"] = accuracy_score(y_true, y_pred)
    metrics["precision_macro"] = precision_score(y_true, y_pred, average="macro", zero_division=0)
    metrics["precision_weighted"] = precision_score(y_true, y_pred, average="weighted", zero_division=0)
    metrics["recall_macro"] = recall_score(y_true, y_pred, average="macro", zero_division=0)
    metrics["recall_weighted"] = recall_score(y_true, y_pred, average="weighted", zero_division=0)
    metrics["f1_macro"] = f1_score(y_true, y_pred, average="macro", zero_division=0)
    metrics["f1_weighted"] = f1_score(y_true, y_pred, average="weighted", zero_division=0)
    metrics["f1_micro"] = f1_score(y_true, y_pred, average="micro", zero_division=0)

    # Advanced metrics
    metrics["cohen_kappa"] = cohen_kappa_score(y_true, y_pred)
    metrics["matthews_corrcoef"] = matthews_corrcoef(y_true, y_pred)

    # Top-K accuracy. sklearn rejects a 2-D score matrix whenever y_true looks
    # binary, so:
    # - for a genuine 2-class model, pass the positive-class column;
    # - otherwise, pass `labels` so a split containing only two classes is
    #   still treated as multiclass.
    topk_scores = y_probs[:, 1] if num_classes == 2 else y_probs
    metrics["top_3_accuracy"] = top_k_accuracy_score(
        y_true, topk_scores, k=min(3, num_classes), labels=all_labels)
    metrics["top_5_accuracy"] = top_k_accuracy_score(
        y_true, topk_scores, k=min(5, num_classes), labels=all_labels)

    # ROC-AUC (One-vs-Rest)
    try:
        if num_classes == 2:
            metrics["roc_auc"] = roc_auc_score(y_true, y_probs[:, 1])
        else:
            metrics["roc_auc_macro"] = roc_auc_score(y_true, y_probs, multi_class="ovr", average="macro")
            metrics["roc_auc_weighted"] = roc_auc_score(y_true, y_probs, multi_class="ovr", average="weighted")
    except ValueError as e:
        # e.g. a class missing from y_true, or scores not summing to 1.
        metrics["roc_auc_note"] = f"Could not compute: {str(e)}"

    # Per-class metrics (one-vs-rest binarisation)
    per_class = {}
    for i, cls_name in enumerate(class_names):
        binary_true = (y_true == i).astype(int)
        binary_pred = (y_pred == i).astype(int)
        entry = {
            "precision": precision_score(binary_true, binary_pred, zero_division=0),
            "recall": recall_score(binary_true, binary_pred, zero_division=0),
            "f1": f1_score(binary_true, binary_pred, zero_division=0),
            "support": int(binary_true.sum()),
            "accuracy": accuracy_score(binary_true, binary_pred),
            # AUC needs both positives and negatives; None means "undefined".
            # Reporting 0.0 here would look like a worse-than-random model.
            "auc": None,
        }
        if np.unique(binary_true).size == 2:
            try:
                entry["auc"] = roc_auc_score(binary_true, y_probs[:, i])
            except ValueError:
                entry["auc"] = None
        per_class[cls_name] = entry
    metrics["per_class"] = per_class

    # Confusion matrix, always num_classes x num_classes.
    metrics["confusion_matrix"] = confusion_matrix(y_true, y_pred, labels=all_labels).tolist()

    # Classification report. `labels` is required when target_names lists
    # classes that do not occur in y_true / y_pred.
    metrics["classification_report"] = classification_report(
        y_true, y_pred, labels=all_labels, target_names=class_names,
        output_dict=True, zero_division=0,
    )
    return metrics
# ===================== VISUALIZATION FUNCTIONS =====================
def plot_confusion_matrix(y_true, y_pred, class_names, save_path, normalize=True):
    """Plot a confusion-matrix heatmap and save it as an image.

    Args:
        y_true: true class indices.
        y_pred: predicted class indices.
        class_names: tick labels; class ``i`` is ``class_names[i]``.
        save_path: output image path (written at 300 dpi).
        normalize: if True (default), draw a second panel with row-normalised
            rates next to the raw counts. If False, draw only the counts.
    """
    # Fix the label set so the matrix is always len(class_names) square and
    # lines up with the tick labels, even when a class is absent.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))
    n_panels = 2 if normalize else 1
    fig, axes = plt.subplots(1, n_panels, figsize=(10 * n_panels, 8), squeeze=False)
    axes = axes[0]

    def _label_axis(ax, title):
        # Shared title/axis styling for both panels.
        ax.set_title(title, fontsize=14, fontweight="bold")
        ax.set_xlabel("Predicted", fontsize=12)
        ax.set_ylabel("True", fontsize=12)
        ax.tick_params(axis="x", rotation=45)

    # Raw counts
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=class_names, yticklabels=class_names, ax=axes[0])
    _label_axis(axes[0], "Confusion Matrix (Counts)")

    if normalize:
        # Row-normalise. Rows for classes with no true samples stay 0
        # instead of becoming NaN.
        row_sums = cm.sum(axis=1, keepdims=True)
        cm_norm = np.divide(cm.astype(float), row_sums,
                            out=np.zeros(cm.shape, dtype=float),
                            where=row_sums > 0)
        sns.heatmap(cm_norm, annot=True, fmt=".2f", cmap="YlOrRd",
                    xticklabels=class_names, yticklabels=class_names, ax=axes[1])
        _label_axis(axes[1], "Confusion Matrix (Normalized)")

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.close()
    print(f" Saved: {save_path}")
def plot_roc_curves(y_true, y_probs, class_names, num_classes, save_path):
    """Plot one-vs-rest ROC curves per class plus their macro average.

    Classes with no positive or no negative samples in ``y_true`` are
    skipped, with a warning: their ROC curve is undefined and would
    otherwise inject NaNs into the macro average.

    Args:
        y_true: true class indices.
        y_probs: per-class scores, shape (n_samples, num_classes).
        class_names: legend labels indexed by class id.
        num_classes: number of colours to allocate.
        save_path: output image path (written at 300 dpi).
    """
    y_true = np.asarray(y_true)
    fig, ax = plt.subplots(figsize=(10, 8))
    colors = plt.cm.Set3(np.linspace(0, 1, num_classes))

    curves = {}
    for i, cls_name in enumerate(class_names):
        binary_true = (y_true == i).astype(int)
        n_pos = int(binary_true.sum())
        if n_pos == 0 or n_pos == len(binary_true):
            print(f" [WARNING] ROC undefined for class '{cls_name}' "
                  f"(needs both positive and negative samples); skipped.")
            continue
        fpr, tpr, _ = roc_curve(binary_true, y_probs[:, i])
        curves[i] = (fpr, tpr)
        ax.plot(fpr, tpr, color=colors[i], lw=2,
                label=f"{cls_name} (AUC = {auc(fpr, tpr):.3f})")

    # Macro average: interpolate every curve onto the union of FPR points,
    # then average the TPRs.
    if curves:
        fpr_grid = np.unique(np.concatenate([fpr for fpr, _ in curves.values()]))
        mean_tpr = np.mean(
            [np.interp(fpr_grid, fpr, tpr) for fpr, tpr in curves.values()], axis=0)
        macro_auc = auc(fpr_grid, mean_tpr)
        ax.plot(fpr_grid, mean_tpr, color="navy", lw=3, linestyle="--",
                label=f"Macro Average (AUC = {macro_auc:.3f})")

    ax.plot([0, 1], [0, 1], "k--", lw=1, alpha=0.5)
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel("False Positive Rate", fontsize=12)
    ax.set_ylabel("True Positive Rate", fontsize=12)
    ax.set_title("ROC Curves - Rangoli Classification", fontsize=14, fontweight="bold")
    ax.legend(loc="lower right", fontsize=9)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.close()
    print(f" Saved: {save_path}")
def plot_precision_recall_curves(y_true, y_probs, class_names, num_classes, save_path):
    """Draw a one-vs-rest precision-recall curve for every class.

    The legend shows each class's average precision (AP).

    Args:
        y_true: true class indices.
        y_probs: per-class scores, indexed ``y_probs[:, class_id]``.
        class_names: legend labels indexed by class id.
        num_classes: number of colours to allocate.
        save_path: output image path (written at 300 dpi).
    """
    fig, ax = plt.subplots(figsize=(10, 8))
    palette = plt.cm.Set3(np.linspace(0, 1, num_classes))

    for cls_idx, cls_name in enumerate(class_names):
        is_cls = (y_true == cls_idx).astype(int)
        cls_scores = y_probs[:, cls_idx]
        prec, rec, _thresholds = precision_recall_curve(is_cls, cls_scores)
        avg_prec = average_precision_score(is_cls, cls_scores)
        ax.plot(rec, prec, color=palette[cls_idx], lw=2,
                label=f"{cls_name} (AP = {avg_prec:.3f})")

    ax.set_xlabel("Recall", fontsize=12)
    ax.set_ylabel("Precision", fontsize=12)
    ax.set_title("Precision-Recall Curves", fontsize=14, fontweight="bold")
    ax.legend(loc="lower left", fontsize=9)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.close()
    print(f" Saved: {save_path}")
def plot_tsne_embeddings(features, labels, class_names, save_path, perplexity=30,
                         max_samples=2000, seed=42):
    """Plot a 2-D t-SNE projection of learned features, coloured by class.

    Args:
        features: feature matrix, one row per sample.
        labels: class index per row of ``features``.
        class_names: legend labels indexed by class id.
        save_path: output image path (written at 300 dpi).
        perplexity: requested t-SNE perplexity. It is clamped below the number
            of samples, because sklearn requires ``perplexity < n_samples``.
        max_samples: randomly subsample to at most this many points to
            bound t-SNE runtime.
        seed: seed for both the subsampling and t-SNE, so plots are
            reproducible.
    """
    import inspect

    print(" Computing t-SNE embeddings (this may take a minute)...")
    features = np.asarray(features)
    labels = np.asarray(labels)

    # Subsample if too many points (seeded, so it is reproducible).
    if len(features) > max_samples:
        rng = np.random.default_rng(seed)
        keep = rng.choice(len(features), max_samples, replace=False)
        features = features[keep]
        labels = labels[keep]

    n_samples = len(features)
    if n_samples < 3:
        print(f" [WARNING] Too few samples ({n_samples}) for t-SNE; skipping plot.")
        return

    tsne_kwargs = dict(
        n_components=2,
        perplexity=min(float(perplexity), n_samples - 1),
        random_state=seed,
        learning_rate="auto",
        init="pca",
    )
    # sklearn 1.5 renamed `n_iter` to `max_iter`, and later releases removed
    # `n_iter` entirely. Pick whichever this installed version accepts.
    iter_param = "max_iter" if "max_iter" in inspect.signature(TSNE).parameters else "n_iter"
    tsne_kwargs[iter_param] = 1000
    embeddings = TSNE(**tsne_kwargs).fit_transform(features)

    fig, ax = plt.subplots(figsize=(12, 10))
    colors = plt.cm.Set3(np.linspace(0, 1, len(class_names)))
    for i, cls_name in enumerate(class_names):
        mask = labels == i
        ax.scatter(embeddings[mask, 0], embeddings[mask, 1],
                   c=[colors[i]], s=30, alpha=0.7, label=cls_name,
                   edgecolors="white", linewidths=0.5)
    ax.set_title("t-SNE Feature Visualization", fontsize=14, fontweight="bold")
    ax.legend(loc="best", fontsize=10, markerscale=2)
    ax.grid(True, alpha=0.2)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.close()
    print(f" Saved: {save_path}")
def plot_per_class_accuracy(metrics, class_names, save_path):
    """Draw grouped per-class bars for precision, recall and F1-score.

    Despite the function name, the chart shows precision/recall/F1 read from
    ``metrics["per_class"]``. Each F1 bar is annotated with its value.

    Args:
        metrics: dict containing a ``"per_class"`` mapping of class name to
            ``{"precision", "recall", "f1", ...}``.
        class_names: classes to plot, in x-axis order.
        save_path: output image path (written at 300 dpi).
    """
    stats = metrics["per_class"]
    prec_vals = [stats[name]["precision"] for name in class_names]
    rec_vals = [stats[name]["recall"] for name in class_names]
    f1_vals = [stats[name]["f1"] for name in class_names]

    positions = np.arange(len(class_names))
    bar_w = 0.25
    fig, ax = plt.subplots(figsize=(14, 6))

    # (x offset, heights, legend label, colour) for each bar group.
    series = (
        (-bar_w, prec_vals, "Precision", "#3498db"),
        (0.0, rec_vals, "Recall", "#e74c3c"),
        (bar_w, f1_vals, "F1-Score", "#2ecc71"),
    )
    for offset, heights, legend_label, colour in series:
        ax.bar(positions + offset, heights, bar_w, label=legend_label, color=colour, alpha=0.8)

    ax.set_xlabel("Rangoli Class", fontsize=12)
    ax.set_ylabel("Score", fontsize=12)
    ax.set_title("Per-Class Classification Metrics", fontsize=14, fontweight="bold")
    ax.set_xticks(positions)
    ax.set_xticklabels(class_names, rotation=45, ha="right")
    ax.legend()
    ax.grid(True, alpha=0.3, axis="y")
    ax.set_ylim(0, 1.1)

    # Annotate the F1 bars with their values.
    for pos, f1_val in enumerate(f1_vals):
        ax.text(pos + bar_w, f1_val + 0.02, f"{f1_val:.2f}", ha="center", fontsize=8)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.close()
    print(f" Saved: {save_path}")
def plot_training_curves(history_path, save_path):
    """Plot train/validation loss and accuracy side by side.

    Args:
        history_path: JSON file with lists under ``train_loss``, ``val_loss``,
            ``train_acc`` and ``val_acc`` (one entry per epoch).
        save_path: output image path (written at 300 dpi).
    """
    with open(history_path) as fh:
        hist = json.load(fh)

    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(14, 5))
    epoch_axis = range(1, len(hist["train_loss"]) + 1)

    # (axes, train key, val key, y label, train label, val label, title)
    panels = (
        (loss_ax, "train_loss", "val_loss", "Loss",
         "Train Loss", "Val Loss", "Training & Validation Loss"),
        (acc_ax, "train_acc", "val_acc", "Accuracy",
         "Train Accuracy", "Val Accuracy", "Training & Validation Accuracy"),
    )
    for ax, train_key, val_key, y_label, train_label, val_label, title in panels:
        ax.plot(epoch_axis, hist[train_key], "b-", label=train_label, linewidth=2)
        ax.plot(epoch_axis, hist[val_key], "r-", label=val_label, linewidth=2)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.close()
    print(f" Saved: {save_path}")
def _latex_escape(text):
    """Escape LaTeX special characters so *text* is safe inside a tabular cell."""
    specials = {
        "\\": r"\textbackslash{}",
        "&": r"\&",
        "%": r"\%",
        "$": r"\$",
        "#": r"\#",
        "_": r"\_",
        "{": r"\{",
        "}": r"\}",
        "~": r"\textasciitilde{}",
        "^": r"\textasciicircum{}",
    }
    # Per-character mapping avoids double-escaping backslashes inserted by
    # earlier replacements.
    return "".join(specials.get(ch, ch) for ch in str(text))


def generate_latex_table(all_results, class_names, save_path):
    """Write a LaTeX results table (``table*``) comparing models.

    Rows are sorted by accuracy, highest first. Model names and the params
    cell are LaTeX-escaped: names such as ``efficientnet_b3`` contain ``_``,
    which would otherwise break compilation.

    The AUC column uses ``roc_auc_macro`` and falls back to the binary-case
    ``roc_auc`` key. Missing numeric metrics are printed as 0.

    Args:
        all_results: mapping of model name to its metrics dict.
        class_names: unused; kept for interface compatibility.
        save_path: path of the ``.tex`` file to write.

    Returns:
        The LaTeX table as a string (also written to ``save_path``).
    """
    lines = [
        r"\begin{table*}[htbp]",
        r"\centering",
        r"\caption{Comparative Performance of Deep Learning Models for Rangoli Classification}",
        r"\label{tab:results}",
        r"\begin{tabular}{lcccccccc}",
        r"\hline",
        r"\textbf{Model} & \textbf{Accuracy} & \textbf{Precision} & \textbf{Recall} & "
        r"\textbf{F1-Score} & \textbf{AUC-ROC} & \textbf{Kappa} & \textbf{MCC} & \textbf{Params (M)} \\",
        r"\hline",
    ]
    ranked = sorted(all_results.items(),
                    key=lambda item: item[1].get("accuracy", 0), reverse=True)
    for model_name, res in ranked:
        numeric_cells = [
            res.get("accuracy", 0),
            res.get("precision_macro", 0),
            res.get("recall_macro", 0),
            res.get("f1_macro", 0),
            res.get("roc_auc_macro", res.get("roc_auc", 0)),
            res.get("cohen_kappa", 0),
            res.get("matthews_corrcoef", 0),
        ]
        cells = [_latex_escape(model_name)]
        cells.extend(f"{value:.4f}" for value in numeric_cells)
        cells.append(_latex_escape(res.get("params_millions", "N/A")))
        lines.append(" & ".join(cells) + r" \\")
    lines.append(r"\hline")
    lines.append(r"\end{tabular}")
    lines.append(r"\end{table*}")

    latex_table = "\n".join(lines)
    with open(save_path, "w", encoding="utf-8") as f:
        f.write(latex_table)
    print(f" LaTeX table saved: {save_path}")
    return latex_table
def plot_model_comparison(all_results, save_path):
    """Draw grouped bars comparing models on accuracy, precision, recall and F1.

    Missing metrics are plotted as 0.

    Args:
        all_results: mapping of model name to its metrics dict; iteration
            order determines the x-axis order.
        save_path: output image path (written at 300 dpi).
    """
    names = list(all_results)
    # (metrics key, legend label, bar colour) in drawing order.
    series = [
        ("accuracy", "Accuracy", "#3498db"),
        ("precision_macro", "Precision", "#e74c3c"),
        ("recall_macro", "Recall", "#2ecc71"),
        ("f1_macro", "F1-Score", "#f39c12"),
    ]
    positions = np.arange(len(names))
    bar_w = 0.2

    fig, ax = plt.subplots(figsize=(14, 7))
    for slot, (key, legend_label, colour) in enumerate(series):
        heights = [all_results[name].get(key, 0) for name in names]
        ax.bar(positions + slot * bar_w, heights, bar_w,
               label=legend_label, color=colour, alpha=0.85)

    ax.set_xlabel("Model Architecture", fontsize=12)
    ax.set_ylabel("Score", fontsize=12)
    ax.set_title("Comparative Model Performance - Rangoli Classification",
                 fontsize=14, fontweight="bold")
    # Centre each tick under its group of four bars.
    ax.set_xticks(positions + bar_w * 1.5)
    ax.set_xticklabels(names, rotation=30, ha="right")
    ax.legend()
    ax.grid(True, alpha=0.3, axis="y")
    ax.set_ylim(0, 1.1)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.close()
    print(f" Saved: {save_path}")
# ===================== GRAD-CAM =====================
def generate_gradcam(model, images, targets, class_names, save_path, device, num_samples=16):
    """Render Grad-CAM overlays for up to 16 images on a 4x4 grid.

    The target layer is the last ``torch.nn.Conv2d`` found while walking
    ``model.backbone.named_modules()``.

    NOTE(review): for transformer backbones that is likely just the patch
    embedding, so the maps may not be meaningful. Confirm per architecture.

    Args:
        model: classifier with a ``backbone`` submodule.
        images: indexable batch of CHW image tensors (CPU or GPU).
        targets: class index per image, used both as the CAM target and as
            the title.
        class_names: names indexed by class id.
        save_path: output image path (written at 300 dpi).
        device: device the CAM forward/backward passes run on.
        num_samples: maximum number of images to show (capped at 16 by the
            grid and at ``len(images)``).

    Returns:
        None. Returns early, with a warning, if pytorch-grad-cam is missing,
        no conv layer is found, or no images are supplied.
    """
    try:
        from pytorch_grad_cam import GradCAM
        from pytorch_grad_cam.utils.image import show_cam_on_image
        from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
    except ImportError:
        print(" [WARNING] pytorch-grad-cam not installed. Skipping Grad-CAM.")
        return
    model.eval()

    # Find the last conv layer (the last Conv2d in module traversal order).
    target_layers = None
    for _name, module in model.backbone.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            target_layers = [module]
    if target_layers is None:
        print(" [WARNING] Could not find conv layers for Grad-CAM")
        return

    # Never index past the supplied images (previously an IndexError).
    n_show = min(num_samples, 16, len(images))
    if n_show == 0:
        print(" [WARNING] No images supplied for Grad-CAM")
        return

    cam = GradCAM(model=model, target_layers=target_layers)
    fig, axes = plt.subplots(4, 4, figsize=(16, 16))
    flat_axes = axes.ravel()
    for idx in range(n_show):
        ax = flat_axes[idx]
        img_tensor = images[idx].unsqueeze(0).to(device)
        target = int(targets[idx])

        # Generate CAM
        grayscale_cam = cam(input_tensor=img_tensor,
                            targets=[ClassifierOutputTarget(target)])[0]

        # Min-max rescale to [0, 1] for display. Moving to CPU first lets
        # CUDA tensors be converted to numpy.
        img = images[idx].detach().cpu().permute(1, 2, 0).numpy()
        img = (img - img.min()) / (img.max() - img.min() + 1e-8)
        visualization = show_cam_on_image(img, grayscale_cam, use_rgb=True)

        ax.imshow(visualization)
        ax.set_title(f"True: {class_names[target]}", fontsize=9)
        ax.axis("off")

    # Hide empty grid cells when fewer than 16 images are shown.
    for ax in flat_axes[n_show:]:
        ax.axis("off")

    plt.suptitle("Grad-CAM Visualizations", fontsize=16, fontweight="bold")
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.close()
    print(f" Saved: {save_path}")
# ===================== MAIN =====================
def _find_latest_run_dir(ckpt_base, model_name):
    """Return the last-sorting run directory under ``ckpt_base`` for ``model_name``, or None."""
    if not os.path.isdir(ckpt_base):
        return None
    # Only directories count as runs; stray files with a matching prefix
    # (e.g. exported weights) must not be picked.
    runs = [d for d in os.listdir(ckpt_base)
            if d.startswith(model_name) and os.path.isdir(os.path.join(ckpt_base, d))]
    if not runs:
        return None
    # NOTE(review): "latest" assumes run names sort chronologically
    # (e.g. a timestamp suffix). Confirm against the training script.
    return os.path.join(ckpt_base, max(runs))


def _json_default(o):
    """``json.dump`` default hook: convert numpy scalars/arrays to Python types."""
    if isinstance(o, np.integer):
        return int(o)
    if isinstance(o, np.floating):
        return float(o)
    if isinstance(o, np.bool_):
        return bool(o)
    if isinstance(o, np.ndarray):
        return o.tolist()
    # Returning `o` unchanged would make json report a misleading
    # "Circular reference detected"; raise the conventional TypeError.
    raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable")


def evaluate_model(model_name, config, device, checkpoint_dir=None):
    """Run the full evaluation pipeline for one model on the test split.

    The pipeline:
    1. Locates the best checkpoint (``{model_name}_best.pth``).
    2. Predicts on the test loader and computes metrics.
    3. Writes figures to ``config["paths"]["figures"]``.
    4. Writes metrics JSON to ``config["paths"]["reports"]``.

    Args:
        model_name: model key; also the checkpoint file/run-dir prefix.
        config: parsed configuration dict.
        device: torch device for inference.
        checkpoint_dir: run directory holding the checkpoint. When None, the
            last-sorting run directory under ``config["paths"]["checkpoints"]``
            whose name starts with ``model_name`` is used.

    Returns:
        The metrics dict from ``compute_all_metrics``, or None when no
        checkpoint could be found.
    """
    print(f"\n{'='*60}")
    print(f" EVALUATING: {model_name.upper()}")
    print(f"{'='*60}")
    figures_dir = config["paths"]["figures"]
    os.makedirs(figures_dir, exist_ok=True)

    # Find best checkpoint
    if checkpoint_dir is None:
        checkpoint_dir = _find_latest_run_dir(config["paths"]["checkpoints"], model_name)
        if checkpoint_dir is None:
            print(f" [ERROR] No checkpoints found for {model_name}")
            return None
    ckpt_path = os.path.join(checkpoint_dir, f"{model_name}_best.pth")
    if not os.path.exists(ckpt_path):
        print(f" [ERROR] Checkpoint not found: {ckpt_path}")
        return None

    # Load model
    model, checkpoint = load_model_from_checkpoint(ckpt_path, config, device)
    print(f" Loaded checkpoint: epoch {checkpoint['epoch']}, val_acc={checkpoint['val_acc']:.4f}")

    # Load data; class_names[i] is the name of class index i.
    manifest_path = os.path.join(config["paths"]["processed_data"], "dataset_manifest.json")
    _, _, test_loader, class_to_idx = create_dataloaders(config, manifest_path)
    idx_to_class = {v: k for k, v in class_to_idx.items()}
    class_names = [idx_to_class[i] for i in range(len(class_to_idx))]
    num_classes = len(class_names)

    # Get predictions
    y_pred, y_probs, y_true, features, _paths = get_predictions(
        model, test_loader, device, num_classes
    )

    # Compute metrics
    metrics = compute_all_metrics(y_true, y_pred, y_probs, class_names, num_classes)

    # Print results
    print("\n --- Test Results ---")
    print(f" Accuracy: {metrics['accuracy']:.4f}")
    print(f" Precision (M): {metrics['precision_macro']:.4f}")
    print(f" Recall (M): {metrics['recall_macro']:.4f}")
    print(f" F1-Score (M): {metrics['f1_macro']:.4f}")
    print(f" Cohen Kappa: {metrics['cohen_kappa']:.4f}")
    print(f" MCC: {metrics['matthews_corrcoef']:.4f}")
    print(f" Top-3 Acc: {metrics['top_3_accuracy']:.4f}")
    print(f" Top-5 Acc: {metrics['top_5_accuracy']:.4f}")

    # Generate all visualizations
    prefix = f"{model_name}"
    plot_confusion_matrix(y_true, y_pred, class_names,
                          os.path.join(figures_dir, f"{prefix}_confusion_matrix.png"))
    plot_roc_curves(y_true, y_probs, class_names, num_classes,
                    os.path.join(figures_dir, f"{prefix}_roc_curves.png"))
    plot_precision_recall_curves(y_true, y_probs, class_names, num_classes,
                                 os.path.join(figures_dir, f"{prefix}_pr_curves.png"))
    plot_tsne_embeddings(features, y_true, class_names,
                         os.path.join(figures_dir, f"{prefix}_tsne.png"))
    plot_per_class_accuracy(metrics, class_names,
                            os.path.join(figures_dir, f"{prefix}_per_class.png"))

    # Training curves (only if the run saved a history file)
    history_path = os.path.join(checkpoint_dir, "training_history.json")
    if os.path.exists(history_path):
        plot_training_curves(history_path,
                             os.path.join(figures_dir, f"{prefix}_training_curves.png"))

    # Save metrics; numpy types are converted by _json_default.
    metrics_path = os.path.join(config["paths"]["reports"], f"{prefix}_metrics.json")
    os.makedirs(os.path.dirname(metrics_path), exist_ok=True)
    with open(metrics_path, "w") as f:
        json.dump(metrics, f, indent=2, default=_json_default)
    print(f"\n Metrics saved: {metrics_path}")
    return metrics
def main():
    """Command-line entry point.

    Evaluates a single model, or every model under ``config["models"]``
    when ``--model all`` or ``--compare`` is given. In the multi-model case,
    if more than one model produced results, it also writes the comparison
    chart and the LaTeX results table.
    """
    model_choices = ["resnet50", "efficientnet_b3", "vit_base",
                     "convnext_small", "mobilenet_v3", "swin_transformer", "all"]
    parser = argparse.ArgumentParser(description="Evaluate Rangoli Classifier")
    parser.add_argument("--config", type=str, default="configs/config.yaml")
    parser.add_argument("--model", type=str, default="resnet50", choices=model_choices)
    parser.add_argument("--checkpoint", type=str, default=None)
    parser.add_argument("--compare", action="store_true", help="Generate comparison plots")
    parser.add_argument("--gpu", type=int, default=0)
    args = parser.parse_args()

    with open(args.config, "r") as cfg_file:
        config = yaml.safe_load(cfg_file)

    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")

    run_all = args.model == "all" or args.compare
    if not run_all:
        # Single-model mode honours an explicit --checkpoint directory.
        evaluate_model(args.model, config, device, args.checkpoint)
        return

    all_results = {}
    for model_name in config["models"].keys():
        result = evaluate_model(model_name, config, device)
        if result:
            all_results[model_name] = result

    # A comparison only makes sense with at least two evaluated models.
    if len(all_results) <= 1:
        return

    figures_dir = config["paths"]["figures"]
    reports_dir = config["paths"]["reports"]
    plot_model_comparison(all_results,
                          os.path.join(figures_dir, "model_comparison.png"))
    latex = generate_latex_table(
        all_results, config["classes"],
        os.path.join(reports_dir, "results_table.tex")
    )
    print(f"\n LaTeX Table:\n{latex}")


if __name__ == "__main__":
    main()