EyeDiseaseDetection / utils /Comparator.py
GilbertKrantz
FIX : ty Type Error
6e8e8fb
import copy
import torch
from torch.utils.data import DataLoader
from utils.Trainer import model_train
from utils.DatasetHandler import FilteredImageDataset
from utils.Evaluator import ClassificationEvaluator
def _micro_score(results: dict, key: str):
    """Return the ``"micro"`` entry of ``results[key]``, or ``None``.

    Tolerates a missing key, a ``None`` value, or a container without a
    ``"micro"`` entry instead of raising.
    """
    entry = results.get(key)
    if entry is None or "micro" not in entry:
        return None
    return entry["micro"]


def _extract_metrics(results: dict) -> dict:
    """Collect accuracy, micro ROC AUC, micro PR AUC and kappa from a result dict.

    Any metric that is absent is reported as ``None``.
    """
    return {
        "accuracy": results.get("accuracy"),
        "roc_auc": _micro_score(results, "roc_auc"),
        "pr_auc": _micro_score(results, "pr_auc"),
        "kappa": results.get("kappa"),
    }


def _format_metric(value, missing: str = "N/A") -> str:
    """Format ``value`` with 4 decimals, or return ``missing`` when it is ``None``."""
    return f"{value:.4f}" if value is not None else missing


def _best_entry(scores: dict):
    """Return the ``(name, score)`` pair with the highest non-``None`` score.

    Ties go to the earliest entry (dict insertion order, i.e. model order).
    Returns ``None`` when no score is available.
    """
    candidates = [(name, score) for name, score in scores.items() if score is not None]
    if not candidates:
        return None
    return max(candidates, key=lambda item: item[1])


def compare_models(
    models: list,
    train_loader: DataLoader,
    val_loader: DataLoader,
    test_loader: DataLoader,
    dataset: FilteredImageDataset,
    epochs: int = 20,
    names: list | None = None,
) -> None:
    """
    Compare multiple models on validation and test datasets.

    Each model is trained with ``model_train``, whose returned dict supplies the
    validation metrics. If validation produced an accuracy, the model is then
    evaluated on ``test_loader`` with a ``ClassificationEvaluator``. A comparison
    table and the best model per test metric are printed. The weights
    (``state_dict``) of the model with the highest test accuracy are saved to
    ``best_model_<name>.pth`` in the current working directory.

    Args:
        models (list): List of models to compare.
        train_loader (DataLoader): DataLoader for training data.
        val_loader (DataLoader): DataLoader for validation data.
        test_loader (DataLoader): DataLoader for test data.
        dataset (FilteredImageDataset): Dataset object containing class names.
        epochs (int): Number of epochs for training.
        names (list | None): List of model names, one per model. If None,
            "Model 1", "Model 2", ... are used.

    Raises:
        ValueError: If ``names`` is given and its length differs from ``models``.
    """
    if names is None:
        names = [f"Model {i+1}" for i in range(len(models))]
    elif len(names) != len(models):
        # zip() would otherwise silently skip the unmatched models/names.
        raise ValueError(
            f"Got {len(models)} models but {len(names)} names; they must match."
        )

    # Per-model metric dicts, keyed by model name (see _extract_metrics).
    val_metrics: dict[str, dict] = {}
    test_metrics: dict[str, dict] = {}
    best_state = None
    best_accuracy = -1
    best_model_name = ""

    for i, (model, name) in enumerate(zip(models, names)):
        evaluator = ClassificationEvaluator(
            class_names=dataset.classes,
        )
        print(f"\n\n{'#'*30} Training {name} ({i+1}/{len(models)}) {'#'*30}\n")
        # NOTE(review): a None/empty result is treated like a failed run
        # (no accuracy) rather than crashing the whole comparison.
        model_results = (
            model_train(model, train_loader, val_loader, dataset, epochs) or {}
        )
        val_metrics[name] = _extract_metrics(model_results)

        # Only models whose validation run produced an accuracy are tested.
        if val_metrics[name]["accuracy"] is None:
            continue

        print(f"\n{'='*20} Testing {name} on Test Set {'='*20}\n")
        test_model_results = evaluator.evaluate_model(model, test_loader) or {}
        test_metrics[name] = _extract_metrics(test_model_results)

        # Track the best model by test accuracy; strict '>' keeps the first on ties.
        test_accuracy = test_metrics[name]["accuracy"]
        if test_accuracy is not None and test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            # Snapshot only the weights: that is all torch.save needs below,
            # and it avoids deep-copying the entire module.
            best_state = copy.deepcopy(model.state_dict())
            best_model_name = name

    # Print comprehensive comparison
    header = (
        f"{'Model':<20}{'Val Acc':<10}{'Test Acc':<10}"
        f"{'Val ROC AUC':<14}{'Test ROC AUC':<14}"
        f"{'Val PR AUC':<14}{'Test PR AUC':<14}"
        f"{'Val Kappa':<14}{'Test Kappa':<14}"
    )
    width = len(header)  # keep rules aligned with the table columns
    print("\n\n" + "=" * width)
    print("COMPREHENSIVE MODEL COMPARISON")
    print("=" * width)
    print(header)
    print("-" * width)
    for name, val in val_metrics.items():
        test = test_metrics.get(name, {})
        print(
            f"{name:<20}"
            f"{_format_metric(val['accuracy'], 'Failed'):<10}"
            f"{_format_metric(test.get('accuracy')):<10}"
            f"{_format_metric(val['roc_auc']):<14}"
            f"{_format_metric(test.get('roc_auc')):<14}"
            f"{_format_metric(val['pr_auc']):<14}"
            f"{_format_metric(test.get('pr_auc')):<14}"
            f"{_format_metric(val['kappa']):<14}"
            f"{_format_metric(test.get('kappa')):<14}"
        )

    # Identify best model based on test metrics
    if test_metrics:
        # (line title, metric key, column label); metrics with no values are skipped.
        rankings = (
            ("\nBest model by accuracy", "accuracy", "Test Accuracy"),
            ("Best model by ROC AUC", "roc_auc", "Test ROC AUC"),
            ("Best model by PR AUC", "pr_auc", "Test PR AUC"),
            ("Best model by Cohen's Kappa", "kappa", "Test Kappa"),
        )
        for title, key, column in rankings:
            best = _best_entry({n: m[key] for n, m in test_metrics.items()})
            if best is not None:
                print(f"{title}: {best[0]} ({column}: {best[1]:.4f})")

        # Save the best model (by accuracy)
        if best_state is not None:
            model_save_path = (
                f"best_model_{best_model_name.lower().replace(' ', '_')}.pth"
            )
            try:
                torch.save(best_state, model_save_path)
                print(f"Best model saved to {model_save_path}")
            except Exception as save_error:
                # Best-effort: a failed save must not abort the comparison report.
                print(f"Error saving best model: {save_error}")
    else:
        print("\nNo models successfully completed testing.")
    print("=" * width)