File size: 7,845 Bytes
6e8e8fb
 
 
 
 
 
 
61c2d3f
 
 
975672a
 
 
 
6e8e8fb
975672a
6e8e8fb
975672a
 
 
 
 
 
 
 
6e8e8fb
975672a
6e8e8fb
975672a
61c2d3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e8e8fb
61c2d3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import copy
import torch
from torch.utils.data import DataLoader
from utils.Trainer import model_train

from utils.DatasetHandler import FilteredImageDataset
from utils.Evaluator import ClassificationEvaluator


def _micro_metric(results: dict, key: str) -> float | None:
    """Return results[key]["micro"] when present, else None.

    The training/evaluation result dicts nest micro-averaged AUC scores under
    a per-metric sub-dict; this centralizes the repeated two-level lookup.
    """
    nested = results.get(key)
    if isinstance(nested, dict) and "micro" in nested:
        return nested["micro"]
    return None


def _fmt_metric(value: float | None, missing: str = "N/A") -> str:
    """Format a metric to 4 decimal places, or *missing* when value is None."""
    return f"{value:.4f}" if value is not None else missing


def _print_best_by(metric_label: str, value_label: str, summary: dict) -> None:
    """Print the model with the highest non-None score in *summary*.

    No-op when every score is None — this also guards against formatting a
    None score with ``:.4f`` (a crash in the original implementation when
    all models failed).
    """
    scored = [(name, score) for name, score in summary.items() if score is not None]
    if not scored:
        return
    best_name, best_score = max(scored, key=lambda item: item[1])
    print(f"Best model by {metric_label}: {best_name} ({value_label}: {best_score:.4f})")


def compare_models(
    models: list,
    train_loader: DataLoader,
    val_loader: DataLoader,
    test_loader: DataLoader,
    dataset: FilteredImageDataset,
    epochs: int = 20,
    names: list | None = None,
) -> None:
    """
    Train and compare multiple models on validation and test datasets.

    Each model is trained via ``model_train``, then (if training produced an
    accuracy) evaluated on the test loader. A comparison table of accuracy,
    micro ROC AUC, micro PR AUC, and Cohen's kappa is printed, and the model
    with the best test accuracy is saved to disk as a state_dict.

    Args:
        models (list): List of models to compare.
        train_loader (DataLoader): DataLoader for training data.
        val_loader (DataLoader): DataLoader for validation data.
        test_loader (DataLoader): DataLoader for test data.
        dataset (FilteredImageDataset): Dataset object containing class names.
        epochs (int): Number of epochs for training.
        names (list | None): List of model names. If None, default names will be used.
    """
    if names is None:
        names = [f"Model {i+1}" for i in range(len(models))]

    val_results: dict = {}
    test_results: dict = {}
    best_model_obj = None
    best_accuracy = -1.0
    best_model_name = ""

    # Per-metric summaries, keyed by model name; None marks "unavailable".
    val_roc_auc_summary: dict = {}
    test_roc_auc_summary: dict = {}
    val_pr_auc_summary: dict = {}
    test_pr_auc_summary: dict = {}
    val_kappa_summary: dict = {}
    test_kappa_summary: dict = {}

    for i, (model, name) in enumerate(zip(models, names)):
        # Fresh evaluator per model so no state leaks between evaluations.
        evaluator = ClassificationEvaluator(
            class_names=dataset.classes,
        )

        print(f"\n\n{'#'*30} Training {name} ({i+1}/{len(models)}) {'#'*30}\n")
        model_results = model_train(model, train_loader, val_loader, dataset, epochs)

        # Validation-side metrics from the training results.
        accuracy = model_results.get("accuracy")
        val_results[name] = accuracy
        val_roc_auc_summary[name] = _micro_metric(model_results, "roc_auc")
        val_pr_auc_summary[name] = _micro_metric(model_results, "pr_auc")
        val_kappa_summary[name] = model_results.get("kappa")

        if accuracy is None:
            # Training did not produce an accuracy -> treat as failed; skip test eval.
            continue

        print(f"\n{'='*20} Testing {name} on Test Set {'='*20}\n")
        test_model_results = evaluator.evaluate_model(model, test_loader)

        test_accuracy = test_model_results.get("accuracy")
        test_results[name] = test_accuracy
        test_roc_auc_summary[name] = _micro_metric(test_model_results, "roc_auc")
        test_pr_auc_summary[name] = _micro_metric(test_model_results, "pr_auc")
        test_kappa_summary[name] = test_model_results.get("kappa")

        # Track best model by test accuracy; deep-copy so later training
        # epochs on the same object don't overwrite the snapshot.
        if test_accuracy is not None and test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            best_model_obj = copy.deepcopy(model)
            best_model_name = name

    # Table width matches the row format below: 20 + 2*10 + 6*14 = 124.
    # (The previous hard-coded 100 left the rules shorter than the rows.)
    width = 124

    print("\n\n" + "=" * width)
    print("COMPREHENSIVE MODEL COMPARISON")
    print("=" * width)
    print(
        f"{'Model':<20}{'Val Acc':<10}{'Test Acc':<10}{'Val ROC AUC':<14}{'Test ROC AUC':<14}{'Val PR AUC':<14}{'Test PR AUC':<14}{'Val Kappa':<14}{'Test Kappa':<14}"
    )
    print("-" * width)

    for name, val_acc in val_results.items():
        print(
            f"{name:<20}"
            f"{_fmt_metric(val_acc, 'Failed'):<10}"
            f"{_fmt_metric(test_results.get(name)):<10}"
            f"{_fmt_metric(val_roc_auc_summary.get(name)):<14}"
            f"{_fmt_metric(test_roc_auc_summary.get(name)):<14}"
            f"{_fmt_metric(val_pr_auc_summary.get(name)):<14}"
            f"{_fmt_metric(test_pr_auc_summary.get(name)):<14}"
            f"{_fmt_metric(val_kappa_summary.get(name)):<14}"
            f"{_fmt_metric(test_kappa_summary.get(name)):<14}"
        )

    if test_results:
        # Champions per metric; each printer skips silently when no model
        # produced that metric (avoids formatting None).
        print()
        _print_best_by("accuracy", "Test Accuracy", test_results)
        _print_best_by("ROC AUC", "Test ROC AUC", test_roc_auc_summary)
        _print_best_by("PR AUC", "Test PR AUC", test_pr_auc_summary)
        _print_best_by("Cohen's Kappa", "Test Kappa", test_kappa_summary)

        # Persist the best model (by test accuracy) as a state_dict.
        if best_model_obj is not None:
            model_save_path = (
                f"best_model_{best_model_name.lower().replace(' ', '_')}.pth"
            )
            try:
                torch.save(best_model_obj.state_dict(), model_save_path)
                print(f"Best model saved to {model_save_path}")
            except Exception as save_error:
                # Best-effort save: report but don't crash the comparison run.
                print(f"Error saving best model: {save_error}")
    else:
        print("\nNo models successfully completed testing.")

    print("=" * width)