import copy

import torch
from torch.utils.data import DataLoader

from utils.DatasetHandler import FilteredImageDataset
from utils.Evaluator import ClassificationEvaluator
from utils.Trainer import model_train

def compare_models(
    models: list,
    train_loader: DataLoader,
    val_loader: DataLoader,
    test_loader: DataLoader,
    dataset: FilteredImageDataset,
    epochs: int = 20,
    names: list | None = None,
) -> None:
    """
    Compare multiple models on the validation and test datasets.

    Args:
        models (list): List of models to compare.
        train_loader (DataLoader): DataLoader for training data.
        val_loader (DataLoader): DataLoader for validation data.
        test_loader (DataLoader): DataLoader for test data.
        dataset (FilteredImageDataset): Dataset object containing class names.
        epochs (int): Number of training epochs.
        names (list | None): List of model names. If None, default names are used.
    """
    if names is None:
        names = [f"Model {i + 1}" for i in range(len(models))]

    val_results = {}
    test_results = {}
    best_model_obj = None
    best_accuracy = -1
    best_model_name = ""

    # Summary dictionaries for per-model metrics
    val_roc_auc_summary = {}
    test_roc_auc_summary = {}
    val_pr_auc_summary = {}
    test_pr_auc_summary = {}
    val_kappa_summary = {}
    test_kappa_summary = {}

    for i, (model, name) in enumerate(zip(models, names)):
        evaluator = ClassificationEvaluator(class_names=dataset.classes)

        print(f"\n\n{'#' * 30} Training {name} ({i + 1}/{len(models)}) {'#' * 30}\n")
        model_results = model_train(model, train_loader, val_loader, dataset, epochs)

        # Extract validation accuracy and metric summaries; absent keys become None
        accuracy = model_results.get("accuracy")
        val_results[name] = accuracy
        val_roc_auc_summary[name] = model_results.get("roc_auc", {}).get("micro")
        val_pr_auc_summary[name] = model_results.get("pr_auc", {}).get("micro")
        val_kappa_summary[name] = model_results.get("kappa")
        # Evaluate on the test set only if training produced a validation accuracy
        if accuracy is not None:
            print(f"\n{'=' * 20} Testing {name} on Test Set {'=' * 20}\n")
            test_model_results = evaluator.evaluate_model(model, test_loader)

            # Extract test accuracy and metric summaries; absent keys become None
            test_accuracy = test_model_results.get("accuracy")
            test_results[name] = test_accuracy
            test_roc_auc_summary[name] = test_model_results.get("roc_auc", {}).get(
                "micro"
            )
            test_pr_auc_summary[name] = test_model_results.get("pr_auc", {}).get(
                "micro"
            )
            test_kappa_summary[name] = test_model_results.get("kappa")

            # Track the best model by test accuracy
            if test_accuracy is not None and test_accuracy > best_accuracy:
                best_accuracy = test_accuracy
                best_model_obj = copy.deepcopy(model)
                best_model_name = name
    # Print comprehensive comparison table; separators match the full row
    # width (20 + 2 * 10 + 6 * 14 = 124 characters)
    table_width = 20 + 2 * 10 + 6 * 14
    print("\n\n" + "=" * table_width)
    print("COMPREHENSIVE MODEL COMPARISON")
    print("=" * table_width)
    print(
        f"{'Model':<20}{'Val Acc':<10}{'Test Acc':<10}"
        f"{'Val ROC AUC':<14}{'Test ROC AUC':<14}"
        f"{'Val PR AUC':<14}{'Test PR AUC':<14}"
        f"{'Val Kappa':<14}{'Test Kappa':<14}"
    )
    print("-" * table_width)
    for name in val_results:
        val_acc = val_results[name]
        test_acc = test_results.get(name)
        val_roc = val_roc_auc_summary.get(name)
        test_roc = test_roc_auc_summary.get(name)
        val_pr = val_pr_auc_summary.get(name)
        test_pr = test_pr_auc_summary.get(name)
        val_kappa = val_kappa_summary.get(name)
        test_kappa = test_kappa_summary.get(name)

        # Format values for display ("Failed" marks a model whose training
        # returned no accuracy; "N/A" marks a metric that was not computed)
        val_acc_str = f"{val_acc:.4f}" if val_acc is not None else "Failed"
        test_acc_str = f"{test_acc:.4f}" if test_acc is not None else "N/A"
        val_roc_str = f"{val_roc:.4f}" if val_roc is not None else "N/A"
        test_roc_str = f"{test_roc:.4f}" if test_roc is not None else "N/A"
        val_pr_str = f"{val_pr:.4f}" if val_pr is not None else "N/A"
        test_pr_str = f"{test_pr:.4f}" if test_pr is not None else "N/A"
        val_kappa_str = f"{val_kappa:.4f}" if val_kappa is not None else "N/A"
        test_kappa_str = f"{test_kappa:.4f}" if test_kappa is not None else "N/A"

        print(
            f"{name:<20}{val_acc_str:<10}{test_acc_str:<10}"
            f"{val_roc_str:<14}{test_roc_str:<14}"
            f"{val_pr_str:<14}{test_pr_str:<14}"
            f"{val_kappa_str:<14}{test_kappa_str:<14}"
        )
    # Identify the best model per test metric, skipping models without a score
    # (this also avoids formatting None when every test accuracy is missing)
    if test_results:
        # Best model by accuracy
        valid_acc = [(k, v) for k, v in test_results.items() if v is not None]
        if valid_acc:
            best_acc_model = max(valid_acc, key=lambda x: x[1])
            print(
                f"\nBest model by accuracy: {best_acc_model[0]} "
                f"(Test Accuracy: {best_acc_model[1]:.4f})"
            )

        # Best model by ROC AUC (if available)
        if any(v is not None for v in test_roc_auc_summary.values()):
            best_roc_model = max(
                [(k, v) for k, v in test_roc_auc_summary.items() if v is not None],
                key=lambda x: x[1],
            )
            print(
                f"Best model by ROC AUC: {best_roc_model[0]} "
                f"(Test ROC AUC: {best_roc_model[1]:.4f})"
            )

        # Best model by PR AUC (if available)
        if any(v is not None for v in test_pr_auc_summary.values()):
            best_pr_model = max(
                [(k, v) for k, v in test_pr_auc_summary.items() if v is not None],
                key=lambda x: x[1],
            )
            print(
                f"Best model by PR AUC: {best_pr_model[0]} "
                f"(Test PR AUC: {best_pr_model[1]:.4f})"
            )

        # Best model by Cohen's Kappa (if available)
        if any(v is not None for v in test_kappa_summary.values()):
            best_kappa_model = max(
                [(k, v) for k, v in test_kappa_summary.items() if v is not None],
                key=lambda x: x[1],
            )
            print(
                f"Best model by Cohen's Kappa: {best_kappa_model[0]} "
                f"(Test Kappa: {best_kappa_model[1]:.4f})"
            )
    # Save the best model (by accuracy); if none was tracked, no model made it
    # through testing
    if best_model_obj is not None:
        try:
            model_save_path = (
                f"best_model_{best_model_name.lower().replace(' ', '_')}.pth"
            )
            torch.save(best_model_obj.state_dict(), model_save_path)
            print(f"Best model saved to {model_save_path}")
        except Exception as save_error:
            print(f"Error saving best model: {save_error}")
    else:
        print("\nNo models successfully completed testing.")

    print("=" * table_width)