File size: 9,273 Bytes
61c2d3f
 
 
 
 
6e8e8fb
61c2d3f
 
 
6e8e8fb
 
61c2d3f
 
 
975672a
 
 
6e8e8fb
975672a
 
 
 
 
 
61c2d3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e8e8fb
 
 
 
 
 
 
 
61c2d3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e8e8fb
 
 
 
 
 
 
 
 
61c2d3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
975672a
 
 
 
6e8e8fb
975672a
6e8e8fb
61c2d3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
import gc

from utils.Evaluator import ClassificationEvaluator
from utils.Callback import EarlyStopping


def _run_training_epoch(model, criterion, optimizer, loader, device):
    """Run one optimisation pass over ``loader``.

    Returns:
        ``(mean_loss, accuracy)`` averaged over the samples seen, or
        ``(0.0, 0.0)`` when the loader yields no samples.
    """
    model.train()
    loss_sum = 0.0
    num_correct = 0
    num_seen = 0

    for inputs, labels in tqdm(loader, desc="Training"):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by batch size so the epoch value is a
        # per-sample mean even when the last batch is smaller.
        loss_sum += loss.item() * inputs.size(0)
        _, predicted = torch.max(outputs, 1)
        num_seen += labels.size(0)
        num_correct += (predicted == labels).sum().item()

    if num_seen == 0:
        print("Warning: No training samples found. Skipping training.")
        return 0.0, 0.0
    return loss_sum / num_seen, num_correct / num_seen


def _run_validation_epoch(model, criterion, loader, device):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Returns:
        ``(mean_loss, accuracy, labels, preds, scores)`` where ``labels`` and
        ``preds`` are flat Python lists and ``scores`` is the row-wise stack
        of the per-batch softmax outputs (an empty array when the loader
        yields no samples).
    """
    model.eval()
    loss_sum = 0.0
    num_correct = 0
    num_seen = 0
    labels_seen = []
    preds_seen = []
    score_batches = []

    with torch.no_grad():
        for inputs, labels in tqdm(loader, desc="Validation"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            loss_sum += loss.item() * inputs.size(0)
            probs = F.softmax(outputs, dim=1)
            _, predicted = torch.max(outputs, 1)
            num_seen += labels.size(0)
            num_correct += (predicted == labels).sum().item()

            labels_seen.extend(labels.cpu().numpy().tolist())
            preds_seen.extend(predicted.cpu().numpy().tolist())
            score_batches.append(probs.cpu().numpy())

    scores = np.vstack(score_batches) if score_batches else np.array([])

    # Guard against division by zero on an empty validation set.
    if num_seen == 0:
        print("Warning: No validation samples found. Skipping validation.")
        return 0.0, 0.0, labels_seen, preds_seen, scores
    return (
        loss_sum / num_seen,
        num_correct / num_seen,
        labels_seen,
        preds_seen,
        scores,
    )


def train_model(
    model: nn.Module,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    scheduler,
    train_loader: DataLoader,
    val_loader: DataLoader,
    early_stopping: EarlyStopping,
    epochs: int = 15,
    use_ddp: bool = False,
) -> tuple:
    """
    Train the model and validate it after every epoch.

    With two or more visible GPUs the model is wrapped in either
    DataParallel (DP) or DistributedDataParallel (DDP); otherwise it runs on
    a single GPU, or on the CPU when CUDA is unavailable.

    Args:
        model: Model to train
        criterion: Loss function
        optimizer: Optimizer for training
        scheduler: Learning rate scheduler; ``step`` is called with the
            epoch validation loss
        train_loader: DataLoader for training data
        val_loader: DataLoader for validation data
        early_stopping: Early stopping handler; called with the epoch
            validation loss, then its ``early_stop`` attribute is checked
        epochs: Maximum number of epochs to train
        use_ddp: Whether to use DistributedDataParallel (True) or DataParallel (False)

    Returns:
        Tuple ``(model, train_losses, val_losses, train_accs, val_accs,
        val_labels, val_preds, val_scores)``. ``model`` is the (possibly
        DP/DDP-wrapped) trained model, the four history lists hold one entry
        per completed epoch, and the last three items are the validation
        labels, predictions and softmax scores of the last epoch that ran.
    """
    # Check available GPUs
    num_gpus = torch.cuda.device_count()
    if num_gpus < 2:
        print(
            f"Warning: Requested multi-GPU training but only {num_gpus} GPU(s) available. Continuing with available resources."
        )
    else:
        print(f"Using {num_gpus} GPUs for training")

    ddp_active = use_ddp and num_gpus >= 2
    # Only tear down a process group that this call created, so a group set
    # up by the caller survives the function.
    owns_process_group = False
    if ddp_active:
        import torch.distributed as dist

        if not dist.is_initialized():
            dist.init_process_group(backend="nccl")
            owns_process_group = True

    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []

    # Validation predictions and labels of the most recent epoch
    all_val_labels = []
    all_val_preds = []
    all_val_scores = []

    try:
        # Setup device and model
        if ddp_active:
            import os

            from torch.nn.parallel import DistributedDataParallel as DDP

            # The device index must be the rank *within this node*. The
            # global rank only coincides with it on single-node jobs, so
            # prefer LOCAL_RANK (exported by torchrun) and fall back to the
            # global rank folded onto the local GPU count.
            local_rank = int(
                os.environ.get("LOCAL_RANK", dist.get_rank() % num_gpus)
            )
            torch.cuda.set_device(local_rank)
            device = torch.device(f"cuda:{local_rank}")

            model = model.to(device)
            model = DDP(model, device_ids=[local_rank])
        elif num_gpus >= 2:
            # DataParallel (simpler to use)
            device = torch.device("cuda:0")
            model = model.to(device)
            model = torch.nn.DataParallel(model)
        else:
            # Single GPU, or CPU when CUDA is unavailable
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
            model = model.to(device)

        train_sampler = getattr(train_loader, "sampler", None)

        for epoch in range(epochs):
            print(f"Epoch {epoch+1}/{epochs}")

            # Distributed samplers reshuffle only when told the epoch number;
            # without this every epoch sees the same ordering.
            if ddp_active and hasattr(train_sampler, "set_epoch"):
                train_sampler.set_epoch(epoch)

            # Training phase
            epoch_train_loss, epoch_train_acc = _run_training_epoch(
                model, criterion, optimizer, train_loader, device
            )
            train_losses.append(epoch_train_loss)
            train_accs.append(epoch_train_acc)

            # Validation phase
            (
                epoch_val_loss,
                epoch_val_acc,
                all_val_labels,
                all_val_preds,
                all_val_scores,
            ) = _run_validation_epoch(model, criterion, val_loader, device)
            val_losses.append(epoch_val_loss)
            val_accs.append(epoch_val_acc)

            # Update learning rate scheduler
            scheduler.step(epoch_val_loss)

            print(f"Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.4f}")
            print(f"Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.4f}")
            print(f"Learning rate: {optimizer.param_groups[0]['lr']:.6f}")

            # Check early stopping
            early_stopping(epoch_val_loss)
            if early_stopping.early_stop:
                print("Early stopping triggered!")
                break

            # Free up memory between epochs
            gc.collect()
            torch.cuda.empty_cache()
    finally:
        # Clean up DDP even when training raised, so the process group does
        # not leak into later calls.
        if owns_process_group:
            dist.destroy_process_group()

    return (
        model,
        train_losses,
        val_losses,
        train_accs,
        val_accs,
        all_val_labels,
        all_val_preds,
        all_val_scores,
    )


def model_train(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    dataset,
    epochs: int = 20,
    learning_rate: float = 0.001,
) -> dict:
    """
    Train ``model`` with Adam and cross-entropy loss, then evaluate it.

    Training is delegated to ``train_model`` (DataParallel mode) with a
    ReduceLROnPlateau scheduler and early stopping. Afterwards the training
    history, confusion matrix and per-class accuracy are plotted and metrics
    are computed through ``ClassificationEvaluator`` on the validation
    outputs of the last epoch that ran.

    Args:
        model: Model to train
        train_loader: DataLoader for training data
        val_loader: DataLoader for validation data
        dataset: Object whose ``classes`` attribute provides the class names
        epochs: Maximum number of epochs to train
        learning_rate: Initial learning rate passed to Adam

    Returns:
        On success, a dict with the keys ``"accuracy"``, ``"report"``,
        ``"roc_auc"``, ``"pr_auc"`` and ``"kappa"``. If training or evaluation
        raises, the error and traceback are printed and
        ``{"accuracy": None}`` is returned instead.
    """
    import traceback

    model_name = type(model).__name__
    # The attribute may exist but be None or empty; only use it when it
    # actually carries a name.
    pretrained_cfg = getattr(model, "pretrained_cfg", None)
    if pretrained_cfg and "name" in pretrained_cfg:
        model_name = pretrained_cfg["name"]

    print(f"\n{'='*20} Training {model_name} {'='*20}\n")

    class_names = dataset.classes

    try:
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.1, patience=3
        )
        early_stopping = EarlyStopping(patience=5)

        (
            model,
            train_losses,
            val_losses,
            train_accs,
            val_accs,
            val_labels,
            val_preds,
            val_scores,
        ) = train_model(
            model,
            nn.CrossEntropyLoss(),
            optimizer,
            scheduler,
            train_loader,
            val_loader,
            early_stopping,
            epochs=epochs,
            use_ddp=False,
        )

        print(f"\n{'='*20} Evaluation for {model_name} {'='*20}\n")
        evaluator = ClassificationEvaluator(
            class_names=class_names,
        )

        evaluator.plot_training_history(train_losses, val_losses, train_accs, val_accs)
        # Plotting/metric failures are reported separately from training
        # failures so the log shows which stage broke.
        try:
            evaluator.plot_confusion_matrix(val_labels, val_preds)
            evaluator.plot_per_class_accuracy(val_labels, val_preds)

            accuracy, report_dict, roc_auc_dict, pr_auc_dict, kappa = (
                evaluator.compute_metrics(
                    val_labels,
                    val_preds,
                    val_scores,
                    model_name,
                )
            )

            return {
                "accuracy": accuracy,
                "report": report_dict,
                "roc_auc": roc_auc_dict,
                "pr_auc": pr_auc_dict,
                "kappa": kappa,
            }
        except Exception as viz_error:
            print(f"Error in visualization: {viz_error}")
            traceback.print_exc()
            return {"accuracy": None}

    except Exception as e:
        print(f"Error occurred when training {model_name}: {e}")
        traceback.print_exc()
        return {"accuracy": None}
    finally:
        # Drop this frame's references to the training state *before*
        # collecting, otherwise gc/empty_cache cannot reclaim it. Rebinding
        # works whether or not the name was ever assigned.
        optimizer = scheduler = early_stopping = None
        train_losses = val_losses = train_accs = val_accs = None
        val_labels = val_preds = val_scores = None

        gc.collect()
        torch.cuda.empty_cache()