Spaces:
Sleeping
Sleeping
File size: 7,845 Bytes
6e8e8fb 61c2d3f 975672a 6e8e8fb 975672a 6e8e8fb 975672a 6e8e8fb 975672a 6e8e8fb 975672a 61c2d3f 6e8e8fb 61c2d3f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 |
import copy
import torch
from torch.utils.data import DataLoader
from utils.Trainer import model_train
from utils.DatasetHandler import FilteredImageDataset
from utils.Evaluator import ClassificationEvaluator
def _micro_avg(results: dict, key: str) -> float | None:
    """Return results[key]["micro"] if present and well-formed, else None."""
    metric = results.get(key)
    if isinstance(metric, dict):
        return metric.get("micro")
    return None


def _fmt(value: float | None, missing: str = "N/A") -> str:
    """Format a metric to 4 decimal places, or *missing* when unavailable."""
    return f"{value:.4f}" if value is not None else missing


def _report_best(summary: dict, metric_label: str, value_label: str, prefix: str = "") -> None:
    """Print the model with the highest non-None value in *summary* (silent if none)."""
    candidates = [(name, value) for name, value in summary.items() if value is not None]
    if not candidates:
        return
    best_name, best_value = max(candidates, key=lambda item: item[1])
    print(f"{prefix}Best model by {metric_label}: {best_name} ({value_label}: {best_value:.4f})")


def compare_models(
    models: list,
    train_loader: DataLoader,
    val_loader: DataLoader,
    test_loader: DataLoader,
    dataset: FilteredImageDataset,
    epochs: int = 20,
    names: list | None = None,
) -> None:
    """
    Train, evaluate and compare multiple models on validation and test datasets.

    Each model is trained via ``model_train`` and, if training produced an
    accuracy, evaluated on the test set with ``ClassificationEvaluator``.
    A comparison table (accuracy, micro-averaged ROC AUC / PR AUC, Cohen's
    kappa) is printed, the best model per test metric is reported, and the
    ``state_dict`` of the model with the highest test accuracy is saved to
    ``best_model_<name>.pth``.

    Args:
        models (list): List of models to compare.
        train_loader (DataLoader): DataLoader for training data.
        val_loader (DataLoader): DataLoader for validation data.
        test_loader (DataLoader): DataLoader for test data.
        dataset (FilteredImageDataset): Dataset object containing class names.
        epochs (int): Number of epochs for training.
        names (list | None): List of model names. If None, default names
            ("Model 1", "Model 2", ...) are used.
    """
    if names is None:
        names = [f"Model {i+1}" for i in range(len(models))]

    # Per-model summaries: {model_name: metric-value-or-None}.
    val_results = {}
    test_results = {}
    val_roc_auc_summary = {}
    test_roc_auc_summary = {}
    val_pr_auc_summary = {}
    test_pr_auc_summary = {}
    val_kappa_summary = {}
    test_kappa_summary = {}

    best_model_obj = None
    best_accuracy = -1.0
    best_model_name = ""

    for i, (model, name) in enumerate(zip(models, names)):
        evaluator = ClassificationEvaluator(
            class_names=dataset.classes,
        )
        print(f"\n\n{'#'*30} Training {name} ({i+1}/{len(models)}) {'#'*30}\n")
        model_results = model_train(model, train_loader, val_loader, dataset, epochs)

        # Validation metrics (None marks "missing / training failed").
        accuracy = model_results.get("accuracy")
        val_results[name] = accuracy
        val_roc_auc_summary[name] = _micro_avg(model_results, "roc_auc")
        val_pr_auc_summary[name] = _micro_avg(model_results, "pr_auc")
        val_kappa_summary[name] = model_results.get("kappa")

        if accuracy is None:
            # Training failed — skip test evaluation for this model.
            continue

        print(f"\n{'='*20} Testing {name} on Test Set {'='*20}\n")
        test_model_results = evaluator.evaluate_model(model, test_loader)
        test_accuracy = test_model_results.get("accuracy")
        test_results[name] = test_accuracy
        test_roc_auc_summary[name] = _micro_avg(test_model_results, "roc_auc")
        test_pr_auc_summary[name] = _micro_avg(test_model_results, "pr_auc")
        test_kappa_summary[name] = test_model_results.get("kappa")

        # Deep-copy the current best model so later training can't mutate it.
        if test_accuracy is not None and test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            best_model_obj = copy.deepcopy(model)
            best_model_name = name

    # Comprehensive comparison table.
    print("\n\n" + "=" * 100)
    print("COMPREHENSIVE MODEL COMPARISON")
    print("=" * 100)
    print(
        f"{'Model':<20}{'Val Acc':<10}{'Test Acc':<10}{'Val ROC AUC':<14}{'Test ROC AUC':<14}{'Val PR AUC':<14}{'Test PR AUC':<14}{'Val Kappa':<14}{'Test Kappa':<14}"
    )
    print("-" * 100)
    for name in val_results:
        print(
            f"{name:<20}"
            f"{_fmt(val_results[name], 'Failed'):<10}"
            f"{_fmt(test_results.get(name)):<10}"
            f"{_fmt(val_roc_auc_summary.get(name)):<14}"
            f"{_fmt(test_roc_auc_summary.get(name)):<14}"
            f"{_fmt(val_pr_auc_summary.get(name)):<14}"
            f"{_fmt(test_pr_auc_summary.get(name)):<14}"
            f"{_fmt(val_kappa_summary.get(name)):<14}"
            f"{_fmt(test_kappa_summary.get(name)):<14}"
        )

    if test_results:
        # BUGFIX: the original could pick a None accuracy (key fell back to -1)
        # and then crash formatting it with ":.4f". _report_best filters None
        # values up front and stays silent when no valid value exists.
        _report_best(test_results, "accuracy", "Test Accuracy", prefix="\n")
        _report_best(test_roc_auc_summary, "ROC AUC", "Test ROC AUC")
        _report_best(test_pr_auc_summary, "PR AUC", "Test PR AUC")
        _report_best(test_kappa_summary, "Cohen's Kappa", "Test Kappa")

    # Save the best model (by test accuracy), best-effort.
    if best_model_obj is not None:
        try:
            model_save_path = (
                f"best_model_{best_model_name.lower().replace(' ', '_')}.pth"
            )
            torch.save(best_model_obj.state_dict(), model_save_path)
            print(f"Best model saved to {model_save_path}")
        except Exception as save_error:
            # Report but do not abort the comparison summary on a failed save.
            print(f"Error saving best model: {save_error}")
    else:
        print("\nNo models successfully completed testing.")
    print("=" * 100)
|