Spaces:
Runtime error
Runtime error
| """ | |
| ============================================================ | |
| Comprehensive Evaluation Pipeline for IEEE Paper | |
| ============================================================ | |
| Generates all metrics and visualizations needed for publication: | |
| - Accuracy, Precision, Recall, F1 (macro/micro/weighted) | |
| - Confusion Matrix Heatmap | |
| - ROC Curves (per-class + macro) | |
| - Precision-Recall Curves | |
| - Cohen's Kappa, MCC | |
| - t-SNE Feature Embeddings | |
| - Grad-CAM Visualizations | |
| - Per-Class Accuracy Bar Chart | |
| - Model Comparison Table (LaTeX) | |
| - Training Curves | |
| - Statistical Significance Tests | |
| - K-Fold Cross-Validation Results | |
| Usage: | |
| python scripts/evaluate.py --config configs/config.yaml --model resnet50 | |
| python scripts/evaluate.py --config configs/config.yaml --model all --compare | |
| ============================================================ | |
| """ | |
| import os | |
| import sys | |
| import json | |
| import yaml | |
| import argparse | |
| import numpy as np | |
| from pathlib import Path | |
| from collections import OrderedDict | |
| from datetime import datetime | |
| import torch | |
| import torch.nn.functional as F | |
| from torch.cuda.amp import autocast | |
| from tqdm import tqdm | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| from sklearn.metrics import ( | |
| accuracy_score, precision_score, recall_score, f1_score, | |
| confusion_matrix, classification_report, | |
| roc_curve, auc, roc_auc_score, | |
| precision_recall_curve, average_precision_score, | |
| cohen_kappa_score, matthews_corrcoef, | |
| top_k_accuracy_score, | |
| ) | |
| from sklearn.manifold import TSNE | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from dataset.rangoli_dataset import ( | |
| RangoliDataset, get_val_transforms, get_tta_transforms, create_dataloaders | |
| ) | |
| from models.classifier import build_model | |
def load_model_from_checkpoint(checkpoint_path, config, device):
    """Restore a trained classifier from a checkpoint file.

    The checkpoint is read with ``torch.load`` (storages mapped onto
    ``device``). Its ``"model_name"`` entry selects the architecture handed to
    ``build_model`` and its ``"state_dict"`` entry supplies the weights.

    Args:
        checkpoint_path: Path of the checkpoint file to read.
        config: Configuration object forwarded unchanged to ``build_model``.
        device: Device the model is moved to and the checkpoint is mapped to.

    Returns:
        A ``(model, checkpoint)`` tuple: the model on ``device`` switched to
        eval mode, and the loaded checkpoint object itself.
    """
    ckpt = torch.load(checkpoint_path, map_location=device)

    net = build_model(ckpt["model_name"], config)
    net = net.to(device)
    net.load_state_dict(ckpt["state_dict"])
    net.eval()

    return net, ckpt
def get_predictions(model, data_loader, device, num_classes, use_tta=False,
                    tta_transforms=None):
    """Run the model over a data loader and collect its outputs as arrays.

    Each batch must be either ``(images, targets)`` or
    ``(images, targets, paths)``. The model is called as
    ``model(images, return_features=True)`` and must return
    ``(logits, features)``.

    Args:
        model: Network to evaluate; switched to eval mode here.
        data_loader: Iterable of batches as described above.
        device: ``torch.device`` to run on; autocast is enabled only when
            ``device.type == "cuda"``.
        num_classes: Accepted for interface compatibility; not used here.
        use_tta: Accepted for interface compatibility; not used here
            (test-time augmentation is not implemented in this function).
        tta_transforms: Accepted for interface compatibility; not used here.

    Returns:
        ``(preds, probs, targets, features, paths)`` where ``preds`` is the
        arg-max of ``probs`` along axis 1, the three middle items are the
        batch outputs concatenated along axis 0, and ``paths`` is a list that
        stays empty when batches carry no paths.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    model.eval()
    prob_chunks = []
    target_chunks = []
    feature_chunks = []
    all_paths = []

    # Inference only: without no_grad the outputs require grad, which makes
    # ``.numpy()`` raise and keeps the autograd graph alive for every batch.
    with torch.no_grad():
        for batch in tqdm(data_loader, desc=" Predicting", leave=False):
            if len(batch) == 3:
                images, targets, paths = batch
                all_paths.extend(paths)
            else:
                images, targets = batch

            images = images.to(device, non_blocking=True)
            with autocast(enabled=device.type == "cuda"):
                logits, features = model(images, return_features=True)

            # Under autocast the outputs may be half precision; promote to
            # float32 so the downstream NumPy / scikit-learn code gets full
            # precision values.
            probs = F.softmax(logits.float(), dim=1)

            prob_chunks.append(probs.cpu().numpy())
            target_chunks.append(targets.cpu().numpy())
            feature_chunks.append(features.float().cpu().numpy())

    if not prob_chunks:
        raise ValueError("data_loader produced no batches; nothing to evaluate")

    all_probs = np.concatenate(prob_chunks, axis=0)
    all_targets = np.concatenate(target_chunks, axis=0)
    all_features = np.concatenate(feature_chunks, axis=0)
    all_preds = np.argmax(all_probs, axis=1)
    return all_preds, all_probs, all_targets, all_features, all_paths
def compute_all_metrics(y_true, y_pred, y_probs, class_names, num_classes):
    """Compute aggregate and per-class classification metrics.

    Args:
        y_true: Integer class indices of the ground truth, one per sample.
        y_pred: Integer class indices predicted, one per sample.
        y_probs: Per-sample class scores; column ``i`` is the score of class
            index ``i``.
        class_names: Class names; position ``i`` names class index ``i``.
        num_classes: Number of classes (columns of ``y_probs``).

    Returns:
        An ``OrderedDict`` with: accuracy; macro/weighted precision and
        recall; macro/weighted/micro F1; Cohen's kappa; Matthews correlation;
        top-3 and top-5 accuracy (``k`` capped at ``num_classes``); ROC-AUC
        (``roc_auc`` for two classes, otherwise ``roc_auc_macro`` and
        ``roc_auc_weighted``; ``roc_auc_note`` holds the error text when it
        cannot be computed); ``per_class`` one-vs-rest statistics;
        ``confusion_matrix`` as nested lists; and ``classification_report``
        as a dict.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    y_probs = np.asarray(y_probs)
    # Pin the label set explicitly so results stay aligned with class_names
    # even when a class never appears in this split.
    labels = np.arange(num_classes)

    metrics = OrderedDict()

    # Basic metrics
    metrics["accuracy"] = accuracy_score(y_true, y_pred)
    averaged_scores = (
        ("precision", precision_score, ("macro", "weighted")),
        ("recall", recall_score, ("macro", "weighted")),
        ("f1", f1_score, ("macro", "weighted", "micro")),
    )
    for prefix, scorer, averages in averaged_scores:
        for average in averages:
            metrics[f"{prefix}_{average}"] = scorer(
                y_true, y_pred, average=average, zero_division=0
            )

    # Advanced metrics
    metrics["cohen_kappa"] = cohen_kappa_score(y_true, y_pred)
    metrics["matthews_corrcoef"] = matthews_corrcoef(y_true, y_pred)

    # Top-K accuracy. For a two-class problem scikit-learn expects a 1-D score
    # vector for the positive class rather than a two-column matrix.
    topk_scores = y_probs[:, 1] if num_classes == 2 else y_probs
    for k in (3, 5):
        metrics[f"top_{k}_accuracy"] = top_k_accuracy_score(
            y_true, topk_scores, k=min(k, num_classes), labels=labels
        )

    # ROC-AUC (one-vs-rest). Best effort: the failure reason is recorded in
    # the result instead of aborting the whole evaluation.
    try:
        if num_classes == 2:
            metrics["roc_auc"] = roc_auc_score(y_true, y_probs[:, 1])
        else:
            metrics["roc_auc_macro"] = roc_auc_score(
                y_true, y_probs, multi_class="ovr", average="macro"
            )
            metrics["roc_auc_weighted"] = roc_auc_score(
                y_true, y_probs, multi_class="ovr", average="weighted"
            )
    except Exception as e:  # deliberately broad; the message is kept
        metrics["roc_auc_note"] = f"Could not compute: {str(e)}"

    # Per-class metrics, each computed on the one-vs-rest binarised problem.
    # Note that "accuracy" here is therefore binary (class vs. not-class)
    # accuracy, which also counts true negatives.
    per_class = {}
    for i, cls_name in enumerate(class_names):
        binary_true = (y_true == i).astype(int)
        binary_pred = (y_pred == i).astype(int)
        support = int(binary_true.sum())
        entry = {
            "precision": precision_score(binary_true, binary_pred, zero_division=0),
            "recall": recall_score(binary_true, binary_pred, zero_division=0),
            "f1": f1_score(binary_true, binary_pred, zero_division=0),
            "support": support,
            "accuracy": accuracy_score(binary_true, binary_pred),
        }
        # AUC is undefined when the class has no positives or no negatives;
        # 0.0 is stored as a placeholder in that case.
        entry["auc"] = 0.0
        if 0 < support < binary_true.size:
            try:
                entry["auc"] = roc_auc_score(binary_true, y_probs[:, i])
            except (ValueError, IndexError):
                entry["auc"] = 0.0
        per_class[cls_name] = entry
    metrics["per_class"] = per_class

    # Confusion Matrix
    metrics["confusion_matrix"] = confusion_matrix(
        y_true, y_pred, labels=labels
    ).tolist()

    # Classification Report
    metrics["classification_report"] = classification_report(
        y_true, y_pred, labels=labels, target_names=class_names,
        output_dict=True, zero_division=0,
    )
    return metrics
| # ===================== VISUALIZATION FUNCTIONS ===================== | |
def plot_confusion_matrix(y_true, y_pred, class_names, save_path, normalize=True):
    """Save confusion-matrix heatmaps to ``save_path``.

    Args:
        y_true: Integer class indices of the ground truth.
        y_pred: Integer class indices predicted.
        class_names: Tick labels; position ``i`` names class index ``i``.
        save_path: Destination image path.
        normalize: When true (default) a second, row-normalised panel is drawn
            next to the raw counts; when false only the counts panel is drawn.
    """
    # Fix the label set so the matrix always has one row/column per name,
    # even when a class is absent from both y_true and y_pred.
    labels = np.arange(len(class_names))
    cm = confusion_matrix(y_true, y_pred, labels=labels)

    panels = [(cm, "d", "Blues", "Confusion Matrix (Counts)")]
    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # Rows with no true samples stay at zero instead of becoming NaN.
        cm_norm = np.divide(
            cm, row_totals,
            out=np.zeros(cm.shape, dtype=float),
            where=row_totals > 0,
        )
        panels.append((cm_norm, ".2f", "YlOrRd", "Confusion Matrix (Normalized)"))

    fig, axes = plt.subplots(
        1, len(panels), figsize=(10 * len(panels), 8), squeeze=False
    )
    for ax, (matrix, fmt, cmap, title) in zip(axes[0], panels):
        sns.heatmap(matrix, annot=True, fmt=fmt, cmap=cmap,
                    xticklabels=class_names, yticklabels=class_names, ax=ax)
        ax.set_title(title, fontsize=14, fontweight="bold")
        ax.set_xlabel("Predicted", fontsize=12)
        ax.set_ylabel("True", fontsize=12)
        ax.tick_params(axis="x", rotation=45)

    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print(f" Saved: {save_path}")
def plot_roc_curves(y_true, y_probs, class_names, num_classes, save_path):
    """Save one-vs-rest ROC curves per class plus their macro average.

    Args:
        y_true: Integer class indices of the ground truth.
        y_probs: Per-sample class scores; column ``i`` belongs to class ``i``.
        class_names: Legend labels; position ``i`` names class index ``i``.
        num_classes: Number of classes, used to size the colour palette.
        save_path: Destination image path.

    Classes that have no positive or no negative samples in ``y_true`` have
    an undefined ROC curve; they are skipped with a warning and excluded from
    the macro average so that it does not turn into NaN.
    """
    y_true = np.asarray(y_true)
    fig, ax = plt.subplots(figsize=(10, 8))
    n_colors = max(num_classes, len(class_names), 1)
    colors = plt.cm.Set3(np.linspace(0, 1, n_colors))

    curves = []  # (fpr, tpr) of every class with a defined ROC curve
    for i, cls_name in enumerate(class_names):
        binary_true = (y_true == i).astype(int)
        n_pos = int(binary_true.sum())
        if n_pos == 0 or n_pos == binary_true.size:
            print(f" [WARNING] ROC undefined for class '{cls_name}' "
                  f"(needs both positive and negative samples); skipping")
            continue
        fpr, tpr, _ = roc_curve(binary_true, y_probs[:, i])
        roc_auc_val = auc(fpr, tpr)
        curves.append((fpr, tpr))
        ax.plot(fpr, tpr, color=colors[i], lw=2,
                label=f"{cls_name} (AUC = {roc_auc_val:.3f})")

    # Macro average: interpolate every curve onto the union of all FPR
    # points, then average the TPR values.
    if curves:
        fpr_grid = np.unique(np.concatenate([fpr for fpr, _ in curves]))
        mean_tpr = np.zeros_like(fpr_grid)
        for fpr, tpr in curves:
            mean_tpr += np.interp(fpr_grid, fpr, tpr)
        mean_tpr /= len(curves)
        macro_auc = auc(fpr_grid, mean_tpr)
        ax.plot(fpr_grid, mean_tpr, color="navy", lw=3, linestyle="--",
                label=f"Macro Average (AUC = {macro_auc:.3f})")

    # Chance-level diagonal
    ax.plot([0, 1], [0, 1], "k--", lw=1, alpha=0.5)
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel("False Positive Rate", fontsize=12)
    ax.set_ylabel("True Positive Rate", fontsize=12)
    ax.set_title("ROC Curves - Rangoli Classification", fontsize=14, fontweight="bold")
    ax.legend(loc="lower right", fontsize=9)
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print(f" Saved: {save_path}")
def plot_precision_recall_curves(y_true, y_probs, class_names, num_classes, save_path):
    """Draw one-vs-rest precision-recall curves and save them to ``save_path``.

    For every entry of ``class_names`` (position ``i`` is class index ``i``)
    the ground truth is binarised as ``y_true == i`` and scored with column
    ``i`` of ``y_probs``. Each legend entry carries the class's average
    precision. ``num_classes`` sets the number of palette colours.
    """
    figure, axis = plt.subplots(figsize=(10, 8))
    palette = plt.cm.Set3(np.linspace(0, 1, num_classes))

    for class_idx, label in enumerate(class_names):
        is_positive = (y_true == class_idx).astype(int)
        class_scores = y_probs[:, class_idx]
        prec, rec, _thresholds = precision_recall_curve(is_positive, class_scores)
        avg_prec = average_precision_score(is_positive, class_scores)
        axis.plot(
            rec,
            prec,
            color=palette[class_idx],
            lw=2,
            label=f"{label} (AP = {avg_prec:.3f})",
        )

    axis.set_xlabel("Recall", fontsize=12)
    axis.set_ylabel("Precision", fontsize=12)
    axis.set_title("Precision-Recall Curves", fontsize=14, fontweight="bold")
    axis.legend(loc="lower left", fontsize=9)
    axis.grid(True, alpha=0.3)

    figure.tight_layout()
    figure.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.close(figure)
    print(f" Saved: {save_path}")
def plot_tsne_embeddings(features, labels, class_names, save_path, perplexity=30):
    """Project feature vectors to 2-D with t-SNE and save a scatter plot.

    Args:
        features: Array with one feature vector per sample (first axis).
        labels: Integer class index per sample, aligned with ``features``.
        class_names: Legend labels; position ``i`` names class index ``i``.
        save_path: Destination image path.
        perplexity: Requested t-SNE perplexity. It is lowered automatically
            when the number of samples is too small for the requested value.

    At most 2000 samples are embedded; larger inputs are subsampled with a
    fixed seed so repeated runs produce the same figure. With fewer than
    three samples the plot is skipped with a warning.
    """
    print(" Computing t-SNE embeddings (this may take a minute)...")
    features = np.asarray(features)
    labels = np.asarray(labels)

    if len(features) < 3:
        print(" [WARNING] Too few samples for t-SNE. Skipping.")
        return

    # Subsample if too many points. A local, seeded generator keeps the plot
    # reproducible without touching NumPy's global random state.
    max_samples = 2000
    if len(features) > max_samples:
        rng = np.random.default_rng(42)
        idx = rng.choice(len(features), max_samples, replace=False)
        features = features[idx]
        labels = labels[idx]

    # scikit-learn requires perplexity < n_samples.
    effective_perplexity = min(perplexity, len(features) - 1)

    # The iteration count is left at the library default (1000): the keyword
    # for it was renamed from ``n_iter`` to ``max_iter`` across scikit-learn
    # releases, so naming either one breaks on some versions.
    tsne = TSNE(n_components=2, perplexity=effective_perplexity, random_state=42,
                learning_rate="auto", init="pca")
    embeddings = tsne.fit_transform(features)

    fig, ax = plt.subplots(figsize=(12, 10))
    colors = plt.cm.Set3(np.linspace(0, 1, len(class_names)))
    for i, cls_name in enumerate(class_names):
        mask = labels == i
        ax.scatter(embeddings[mask, 0], embeddings[mask, 1],
                   c=[colors[i]], s=30, alpha=0.7, label=cls_name,
                   edgecolors="white", linewidths=0.5)
    ax.set_title("t-SNE Feature Visualization", fontsize=14, fontweight="bold")
    ax.legend(loc="best", fontsize=10, markerscale=2)
    ax.grid(True, alpha=0.2)

    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print(f" Saved: {save_path}")
def plot_per_class_accuracy(metrics, class_names, save_path):
    """Save a grouped bar chart of per-class precision, recall and F1.

    Despite the function name, the values plotted are the ``"precision"``,
    ``"recall"`` and ``"f1"`` entries of ``metrics["per_class"][<class>]``;
    the F1 bars additionally get a numeric label above them.
    """
    stats = metrics["per_class"]
    f1_values = [stats[name]["f1"] for name in class_names]
    precision_values = [stats[name]["precision"] for name in class_names]
    recall_values = [stats[name]["recall"] for name in class_names]

    positions = np.arange(len(class_names))
    bar_w = 0.25
    figure, axis = plt.subplots(figsize=(14, 6))

    # One bar series per metric, laid out left-to-right within each group.
    series = (
        (positions - bar_w, precision_values, "Precision", "#3498db"),
        (positions, recall_values, "Recall", "#e74c3c"),
        (positions + bar_w, f1_values, "F1-Score", "#2ecc71"),
    )
    for bar_x, heights, legend_label, bar_color in series:
        axis.bar(bar_x, heights, bar_w, label=legend_label, color=bar_color, alpha=0.8)

    axis.set_xlabel("Rangoli Class", fontsize=12)
    axis.set_ylabel("Score", fontsize=12)
    axis.set_title("Per-Class Classification Metrics", fontsize=14, fontweight="bold")
    axis.set_xticks(positions)
    axis.set_xticklabels(class_names, rotation=45, ha="right")
    axis.legend()
    axis.grid(True, alpha=0.3, axis="y")
    axis.set_ylim(0, 1.1)

    # Numeric labels above the F1 bars
    for group_idx, score in enumerate(f1_values):
        axis.text(group_idx + bar_w, score + 0.02, f"{score:.2f}",
                  ha="center", fontsize=8)

    figure.tight_layout()
    figure.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.close(figure)
    print(f" Saved: {save_path}")
def plot_training_curves(history_path, save_path):
    """Plot loss and accuracy histories side by side and save the figure.

    ``history_path`` is a JSON file holding the per-epoch lists
    ``"train_loss"``, ``"val_loss"``, ``"train_acc"`` and ``"val_acc"``.
    The epoch axis starts at 1 and its length is taken from ``"train_loss"``.
    """
    with open(history_path) as handle:
        history = json.load(handle)

    figure, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(14, 5))
    epoch_axis = range(1, len(history["train_loss"]) + 1)

    # (axes, train key, val key, train label, val label, y label, title)
    panel_specs = (
        (loss_ax, "train_loss", "val_loss", "Train Loss", "Val Loss",
         "Loss", "Training & Validation Loss"),
        (acc_ax, "train_acc", "val_acc", "Train Accuracy", "Val Accuracy",
         "Accuracy", "Training & Validation Accuracy"),
    )
    for panel, train_key, val_key, train_label, val_label, y_label, title in panel_specs:
        panel.plot(epoch_axis, history[train_key], "b-", label=train_label, linewidth=2)
        panel.plot(epoch_axis, history[val_key], "r-", label=val_label, linewidth=2)
        panel.set_xlabel("Epoch")
        panel.set_ylabel(y_label)
        panel.set_title(title)
        panel.legend()
        panel.grid(True, alpha=0.3)

    figure.tight_layout()
    figure.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.close(figure)
    print(f" Saved: {save_path}")
def _latex_escape(text):
    """Escape characters that are special in LaTeX text mode."""
    specials = {"&": r"\&", "%": r"\%", "$": r"\$", "#": r"\#",
                "_": r"\_", "{": r"\{", "}": r"\}"}
    return "".join(specials.get(ch, ch) for ch in str(text))


def generate_latex_table(all_results, class_names, save_path):
    """Write a LaTeX comparison table of model metrics and return its text.

    Args:
        all_results: Mapping of model name to its metrics dict. The keys read
            are ``accuracy``, ``precision_macro``, ``recall_macro``,
            ``f1_macro``, ``roc_auc_macro`` (falling back to ``roc_auc``, the
            key used for two-class problems), ``cohen_kappa``,
            ``matthews_corrcoef`` and ``params_millions``. Missing numeric
            metrics are shown as ``0.0000`` and a missing parameter count as
            ``N/A``.
        class_names: Accepted for interface compatibility; not used.
        save_path: Destination ``.tex`` file. Its parent directory is created
            when missing.

    Returns:
        The table source as a single string. Rows are ordered by descending
        accuracy.
    """
    lines = [
        r"\begin{table*}[htbp]",
        r"\centering",
        r"\caption{Comparative Performance of Deep Learning Models for Rangoli Classification}",
        r"\label{tab:results}",
        r"\begin{tabular}{lcccccccc}",
        r"\hline",
        r"\textbf{Model} & \textbf{Accuracy} & \textbf{Precision} & \textbf{Recall} & "
        r"\textbf{F1-Score} & \textbf{AUC-ROC} & \textbf{Kappa} & \textbf{MCC} & \textbf{Params (M)} \\",
        r"\hline",
    ]

    ranked = sorted(all_results.items(),
                    key=lambda item: item[1].get("accuracy", 0),
                    reverse=True)
    for model_name, res in ranked:
        auc_value = res.get("roc_auc_macro", res.get("roc_auc", 0))
        cells = [
            # Model names such as "vit_base" contain underscores, which
            # LaTeX rejects outside math mode unless escaped.
            _latex_escape(model_name),
            f"{res.get('accuracy', 0):.4f}",
            f"{res.get('precision_macro', 0):.4f}",
            f"{res.get('recall_macro', 0):.4f}",
            f"{res.get('f1_macro', 0):.4f}",
            f"{auc_value:.4f}",
            f"{res.get('cohen_kappa', 0):.4f}",
            f"{res.get('matthews_corrcoef', 0):.4f}",
            f"{res.get('params_millions', 'N/A')}",
        ]
        lines.append(" & ".join(cells) + " \\\\")

    lines.append(r"\hline")
    lines.append(r"\end{tabular}")
    lines.append(r"\end{table*}")
    latex_table = "\n".join(lines)

    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(save_path, "w", encoding="utf-8") as f:
        f.write(latex_table)
    print(f" LaTeX table saved: {save_path}")
    return latex_table
def plot_model_comparison(all_results, save_path):
    """Save a grouped bar chart comparing headline metrics across models.

    ``all_results`` maps model name to a metrics dict. For every model the
    chart shows ``accuracy``, ``precision_macro``, ``recall_macro`` and
    ``f1_macro``; a metric missing from a model's dict is drawn as 0.
    """
    names = list(all_results.keys())
    # (metrics key, legend label, bar colour) for each series, in draw order
    series_specs = [
        ("accuracy", "Accuracy", "#3498db"),
        ("precision_macro", "Precision", "#e74c3c"),
        ("recall_macro", "Recall", "#2ecc71"),
        ("f1_macro", "F1-Score", "#f39c12"),
    ]

    group_pos = np.arange(len(names))
    bar_w = 0.2
    figure, axis = plt.subplots(figsize=(14, 7))

    for slot, (key, legend_label, bar_color) in enumerate(series_specs):
        heights = [all_results[name].get(key, 0) for name in names]
        axis.bar(group_pos + slot * bar_w, heights, bar_w,
                 label=legend_label, color=bar_color, alpha=0.85)

    axis.set_xlabel("Model Architecture", fontsize=12)
    axis.set_ylabel("Score", fontsize=12)
    axis.set_title("Comparative Model Performance - Rangoli Classification",
                   fontsize=14, fontweight="bold")
    # Centre each tick under its group of four bars
    axis.set_xticks(group_pos + bar_w * 1.5)
    axis.set_xticklabels(names, rotation=30, ha="right")
    axis.legend()
    axis.grid(True, alpha=0.3, axis="y")
    axis.set_ylim(0, 1.1)

    figure.tight_layout()
    figure.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.close(figure)
    print(f" Saved: {save_path}")
| # ===================== GRAD-CAM ===================== | |
def generate_gradcam(model, images, targets, class_names, save_path, device, num_samples=16):
    """Save a 4x4 grid of Grad-CAM overlays for the first few samples.

    Args:
        model: Network exposing a ``backbone`` attribute; the last
            ``torch.nn.Conv2d`` found in ``model.backbone`` is used as the
            Grad-CAM target layer.
        images: Indexable collection of image tensors laid out
            channels-first (each one is permuted to channels-last for
            display).
        targets: Class index per image (plain ints or scalar tensors).
        class_names: Names used in the panel titles.
        save_path: Destination image path.
        device: Device the input tensors are moved to.
        num_samples: Number of samples to draw; capped at 16 (the grid size)
            and at the number of images/targets available.

    The function returns without writing anything when ``pytorch-grad-cam``
    is not installed, when no convolution layer is found, or when there are
    no samples to draw.
    """
    try:
        from pytorch_grad_cam import GradCAM
        from pytorch_grad_cam.utils.image import show_cam_on_image
        from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
    except ImportError:
        print(" [WARNING] pytorch-grad-cam not installed. Skipping Grad-CAM.")
        return

    model.eval()

    # Find the last conv layer (later matches overwrite earlier ones)
    target_layers = None
    for _name, module in model.backbone.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            target_layers = [module]
    if target_layers is None:
        print(" [WARNING] Could not find conv layers for Grad-CAM")
        return

    # Never index past the data that was actually supplied.
    n_show = min(num_samples, 16, len(images), len(targets))
    if n_show <= 0:
        print(" [WARNING] No samples available for Grad-CAM")
        return

    cam = GradCAM(model=model, target_layers=target_layers)
    fig, axes = plt.subplots(4, 4, figsize=(16, 16))
    panels = axes.ravel()

    for idx in range(n_show):
        img_tensor = images[idx].unsqueeze(0).to(device)
        # Scalar tensors are converted so they work both as a CAM target and
        # as a list index.
        target = int(targets[idx])

        # Generate CAM
        targets_cam = [ClassifierOutputTarget(target)]
        grayscale_cam = cam(input_tensor=img_tensor, targets=targets_cam)[0]

        # Rescale the image to [0, 1] for display; move it to the CPU first
        # because ``.numpy()` is not available on CUDA tensors.
        img = images[idx].detach().cpu().permute(1, 2, 0).numpy().astype(np.float32)
        img = (img - img.min()) / (img.max() - img.min() + 1e-8)

        visualization = show_cam_on_image(img, grayscale_cam, use_rgb=True)
        panels[idx].imshow(visualization)
        panels[idx].set_title(f"True: {class_names[target]}", fontsize=9)
        panels[idx].axis("off")

    # Blank out grid cells that received no image.
    for ax in panels[n_show:]:
        ax.axis("off")

    plt.suptitle("Grad-CAM Visualizations", fontsize=16, fontweight="bold")
    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print(f" Saved: {save_path}")
| # ===================== MAIN ===================== | |
def evaluate_model(model_name, config, device, checkpoint_dir=None):
    """Run the full evaluation pipeline for one model.

    Steps: locate and load ``<model_name>_best.pth``, build the test loader,
    collect predictions, compute metrics, print a summary, write all figures
    into ``config["paths"]["figures"]`` and dump the metrics as JSON into
    ``config["paths"]["reports"]``.

    Args:
        model_name: Architecture name; also the prefix of run directories,
            the checkpoint file and every output file.
        config: Configuration dict; reads ``paths.figures``,
            ``paths.checkpoints``, ``paths.processed_data`` and
            ``paths.reports`` and is forwarded to the model/data builders.
        device: ``torch.device`` to evaluate on.
        checkpoint_dir: Directory holding the checkpoint. When ``None`` the
            run directory under ``paths.checkpoints`` whose name starts with
            ``model_name`` and sorts last is used.

    Returns:
        The metrics dict, or ``None`` when no checkpoint could be found.
    """
    print(f"\n{'='*60}")
    print(f" EVALUATING: {model_name.upper()}")
    print(f"{'='*60}")

    figures_dir = config["paths"]["figures"]
    os.makedirs(figures_dir, exist_ok=True)

    # Find best checkpoint
    if checkpoint_dir is None:
        ckpt_base = config["paths"]["checkpoints"]
        # Candidate runs: sub-directories named after this model. A missing
        # base directory is treated the same as "no runs" instead of raising.
        runs = []
        if os.path.isdir(ckpt_base):
            runs = sorted(
                entry for entry in os.listdir(ckpt_base)
                if entry.startswith(model_name)
                and os.path.isdir(os.path.join(ckpt_base, entry))
            )
        if not runs:
            print(f" [ERROR] No checkpoints found for {model_name}")
            return None
        # Latest run = last name in sorted order
        checkpoint_dir = os.path.join(ckpt_base, runs[-1])

    ckpt_path = os.path.join(checkpoint_dir, f"{model_name}_best.pth")
    if not os.path.exists(ckpt_path):
        print(f" [ERROR] Checkpoint not found: {ckpt_path}")
        return None

    # Load model
    model, checkpoint = load_model_from_checkpoint(ckpt_path, config, device)
    print(f" Loaded checkpoint: epoch {checkpoint['epoch']}, val_acc={checkpoint['val_acc']:.4f}")

    # Load data
    manifest_path = os.path.join(config["paths"]["processed_data"], "dataset_manifest.json")
    _, _, test_loader, class_to_idx = create_dataloaders(config, manifest_path)
    idx_to_class = {v: k for k, v in class_to_idx.items()}
    class_names = [idx_to_class[i] for i in range(len(class_to_idx))]
    num_classes = len(class_names)

    # Get predictions (sample paths are returned but not needed here)
    y_pred, y_probs, y_true, features, _paths = get_predictions(
        model, test_loader, device, num_classes
    )

    # Compute metrics
    metrics = compute_all_metrics(y_true, y_pred, y_probs, class_names, num_classes)

    # Print results
    print(f"\n --- Test Results ---")
    print(f" Accuracy: {metrics['accuracy']:.4f}")
    print(f" Precision (M): {metrics['precision_macro']:.4f}")
    print(f" Recall (M): {metrics['recall_macro']:.4f}")
    print(f" F1-Score (M): {metrics['f1_macro']:.4f}")
    print(f" Cohen Kappa: {metrics['cohen_kappa']:.4f}")
    print(f" MCC: {metrics['matthews_corrcoef']:.4f}")
    print(f" Top-3 Acc: {metrics['top_3_accuracy']:.4f}")
    print(f" Top-5 Acc: {metrics['top_5_accuracy']:.4f}")

    # Generate all visualizations
    prefix = f"{model_name}"
    plot_confusion_matrix(y_true, y_pred, class_names,
                          os.path.join(figures_dir, f"{prefix}_confusion_matrix.png"))
    plot_roc_curves(y_true, y_probs, class_names, num_classes,
                    os.path.join(figures_dir, f"{prefix}_roc_curves.png"))
    plot_precision_recall_curves(y_true, y_probs, class_names, num_classes,
                                 os.path.join(figures_dir, f"{prefix}_pr_curves.png"))
    plot_tsne_embeddings(features, y_true, class_names,
                         os.path.join(figures_dir, f"{prefix}_tsne.png"))
    plot_per_class_accuracy(metrics, class_names,
                            os.path.join(figures_dir, f"{prefix}_per_class.png"))

    # Training curves (only when the run saved a history file)
    history_path = os.path.join(checkpoint_dir, "training_history.json")
    if os.path.exists(history_path):
        plot_training_curves(history_path,
                             os.path.join(figures_dir, f"{prefix}_training_curves.png"))

    # Save metrics
    metrics_path = os.path.join(config["paths"]["reports"], f"{prefix}_metrics.json")
    os.makedirs(os.path.dirname(metrics_path), exist_ok=True)

    # Convert numpy types for JSON serialization
    def convert(o):
        if isinstance(o, np.integer):
            return int(o)
        if isinstance(o, np.floating):
            return float(o)
        if isinstance(o, np.ndarray):
            return o.tolist()
        return o

    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2, default=convert)
    print(f"\n Metrics saved: {metrics_path}")
    return metrics
def main():
    """Command-line entry point.

    Parses the arguments, loads the YAML config and selects the device. With
    ``--model all`` or ``--compare`` every model listed under
    ``config["models"]`` is evaluated and, when more than one produced
    metrics, a comparison chart and a LaTeX table are written. Otherwise only
    the requested model is evaluated, optionally from ``--checkpoint``.
    """
    cli = argparse.ArgumentParser(description="Evaluate Rangoli Classifier")
    cli.add_argument("--config", type=str, default="configs/config.yaml")
    cli.add_argument(
        "--model",
        type=str,
        default="resnet50",
        choices=["resnet50", "efficientnet_b3", "vit_base",
                 "convnext_small", "mobilenet_v3", "swin_transformer", "all"],
    )
    cli.add_argument("--checkpoint", type=str, default=None)
    cli.add_argument("--compare", action="store_true", help="Generate comparison plots")
    cli.add_argument("--gpu", type=int, default=0)
    args = cli.parse_args()

    with open(args.config, "r") as config_file:
        config = yaml.safe_load(config_file)

    if torch.cuda.is_available():
        device = torch.device(f"cuda:{args.gpu}")
    else:
        device = torch.device("cpu")

    # Single-model mode
    if args.model != "all" and not args.compare:
        evaluate_model(args.model, config, device, args.checkpoint)
        return

    # Multi-model mode: keep only the models that yielded metrics
    all_results = {}
    for candidate in config["models"]:
        candidate_metrics = evaluate_model(candidate, config, device)
        if candidate_metrics:
            all_results[candidate] = candidate_metrics

    if len(all_results) <= 1:
        return

    figures_dir = config["paths"]["figures"]
    reports_dir = config["paths"]["reports"]
    comparison_png = os.path.join(figures_dir, "model_comparison.png")
    plot_model_comparison(all_results, comparison_png)

    latex = generate_latex_table(
        all_results,
        config["classes"],
        os.path.join(reports_dir, "results_table.tex"),
    )
    print(f"\n LaTeX Table:\n{latex}")


if __name__ == "__main__":
    main()