Atheer Aljuraib (k23108174) committed
Commit e6d94e8 · 1 Parent(s): 728c1f9

Update training loop and fix training metrics

trainingModel/helpers/Training.py CHANGED
@@ -1,182 +1,199 @@
- import torch
  import torch.nn as nn
  import numpy as np
  from torcheval.metrics import MulticlassAccuracy
  from torch.utils.data import DataLoader


- # fix errors in runtime


  def train_model(
-     model: nn.Module,
-     train_loader: DataLoader,
-     val_loader: DataLoader,
-     device: torch.device,
-     n_epochs: int = 4,
-     lr: float = 1e-3,
-     num_classes: int = 39,
-     optimizer_type: str = "adam",
-     flatten_input: bool = False,
-     save_path: str = "best_model.pt",
  ):
-     """
-     Trains the given model and returns:
-         - training_losses: numpy array of loss per batch
-         - training_accuracies: numpy array of running accuracy per batch
-         - val_accuracies: numpy array of accuracy per epoch
-         - best_accuracy: highest validation accuracy achieved
-
-     Expected batch format:
-         batch["image"] → Tensor [B, C, H, W]
-         batch["label"] → Tensor [B] with class IDs (int64)
-     Model output:
-         outputs → Tensor [B, num_classes] (logits)
-     """
-
-     # Move model to device
-     model.to(device)
-
-     # Loss and optimizer
-     criterion = nn.CrossEntropyLoss()
-
-     if optimizer_type.lower() == "adam":
-         optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # might add momentum 0.9 later
-     else:
-         optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
-
-     # Metric trackers
-     train_accuracy_fn = MulticlassAccuracy(num_classes=num_classes)
-     val_accuracy_fn = MulticlassAccuracy(num_classes=num_classes)
-
-     # Arrays to log metrics
-     num_batches = len(train_loader)
-
-     # Batch-level logs
-     batch_losses = []
-     batch_accuracies = []
-
-     # Epoch-level logs
-     epoch_losses = np.zeros(n_epochs)
-     epoch_accuracies = np.zeros(n_epochs)
-     val_accuracies = np.zeros(n_epochs)
-
-     if num_batches == 0:
-         raise RuntimeError("UH OH!!!! empty train loader")
-
-     # Store training losses and accuracies for every batch
-     # num_batches is the number of batches for every epoch
-     #training_losses = np.zeros(num_batches * n_epochs)
-     #training_accuracies = np.zeros(num_batches * n_epochs)
-
-     # store validation accuracy for every epoch
-
-     # keep track of best validation accuracy and best model
-     best_accuracy = 0.0
-
-     #----------------------
-     # training loop
-     #----------------------
-
-     for epoch in range(n_epochs):
-         model.train()
-         train_accuracy_fn.reset()
-
-         running_loss = 0.0
-         running_correct = 0
-         running_total = 0
-
-         # iterate over all the dataloader's mini-batches
-         for batch in train_loader:
-
-             # move to GPU memory
-             inputs = batch["image"].to(device)
-             labels = batch["label"].to(device).long()
-
-             # flatten if not cnn REVISE LATER
-             if flatten_input:
-                 inputs = inputs.view(inputs.size(0), -1)
-
-             optimizer.zero_grad()
-
-             # Forward pass
-             outputs = model(inputs)
-             loss = criterion(outputs, labels)
-
-             # Backward pass & update params
-             loss.backward()
-             optimizer.step()
-
-             # Log batch-level metrics
-             batch_losses.append(loss.item())
-             batch_acc = (outputs.argmax(dim=1) == labels).float().mean().item()
-             batch_accuracies.append(batch_acc)
-
-             # Sum epoch stats
-             running_loss += loss.item() * inputs.size(0)
-             running_correct += (outputs.argmax(dim=1) == labels).sum().item()
-             running_total += labels.size(0)
-
-         # Epoch-level metrics (average over all batches)
-         epoch_loss_avg = running_loss / running_total
-         epoch_acc_avg = running_correct / running_total
-
-         epoch_losses[epoch] = epoch_loss_avg
-         epoch_accuracies[epoch] = epoch_acc_avg
-
-         print(f"\n--- Epoch {epoch + 1}: ---")
-         print(f'Train loss={epoch_loss_avg:.4f}\nTrain accuracy={epoch_acc_avg:.4f}\n')
-
-         # ----------------------
-         # validation loop
-         # ----------------------
-
-         model.eval()
-         val_accuracy_fn.reset()
-
-         with torch.no_grad():
-             for batch in val_loader:
-                 inputs = batch["image"].to(device)
-                 labels = batch["label"].to(device).long()
-
-                 # flatten if not cnn REVISE LATER
-                 if flatten_input:
-                     inputs = inputs.view(inputs.size(0), -1)
-
-                 outputs = model(inputs)
-                 val_accuracy_fn.update(outputs, labels)
-
-         current_val_accuracy = val_accuracy_fn.compute().item()
-         val_accuracies[epoch] = current_val_accuracy
-
-         print(f"\nEpoch {epoch+1}: val acc={current_val_accuracy:.4f}")
-
-         # keep track of best validation accuracy and save best model so far
-         if current_val_accuracy > best_accuracy:
-             best_accuracy = current_val_accuracy
-             torch.save(model.state_dict(), save_path)
-
-         print(f'Epoch {epoch + 1} validation complete\n')
-
-     print(f"\nTraining finished. Best val accuracy: {best_accuracy:.4f}")
-     print(f"Best model weights saved to: {save_path}")
-
-     training_metrics = {
-         "batch_losses": np.array(batch_losses),
-         "batch_accuracies": np.array(batch_accuracies),
-         "epoch_losses": epoch_losses,
-         "epoch_accuracies": epoch_accuracies,
-         "val_accuracies": val_accuracies,
-         "best_accuracy": best_accuracy,
-     }
-
-     return training_metrics
+ import torch
  import torch.nn as nn
  import numpy as np
  from torcheval.metrics import MulticlassAccuracy
  from torch.utils.data import DataLoader


+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print("Using device:", DEVICE)
+
  def train_model(
+     model: nn.Module,
+     train_loader: DataLoader,
+     val_loader: DataLoader,
+     n_epochs: int = 4,
+     lr: float = 1e-3,
+     save_path: str = "best_model.pt",
+     num_classes: int = 39,
+     early_stop: int = 3,
  ):
+     """
+     Trains the given model and returns a dict of training metrics:
+         - "losses": numpy array of average training loss per epoch
+         - "accuracies": numpy array of training accuracy per epoch
+         - "val_accuracies": numpy array of validation accuracy per epoch
+         - "best_accuracy": highest validation accuracy achieved
+
+     Expected batch format:
+         batch["image"] → Tensor [B, C, H, W]
+         batch["label"] → Tensor [B] with class IDs (int64)
+     Model output:
+         outputs → Tensor [B, num_classes] (logits)
+     """
+
+     # Move model to device
+     model.to(DEVICE)
+
+     # Loss and optimizer
+     criterion = nn.CrossEntropyLoss()
+     optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # might add momentum 0.9 later
+
+     # Metric trackers
+     train_accuracy_fn = MulticlassAccuracy(num_classes=num_classes)
+     val_accuracy_fn = MulticlassAccuracy(num_classes=num_classes)
+
+     num_batches = len(train_loader)
+
+     if num_batches == 0:
+         raise RuntimeError("Empty train loader")
+
+     # Store training losses and accuracies for every epoch
+     training_losses = np.zeros(n_epochs)
+     training_accuracies = np.zeros(n_epochs)
+
+     # store validation accuracy for every epoch
+     val_accuracies = np.zeros(n_epochs)
+
+     # keep track of best validation accuracy and best model
+     best_accuracy = 0.0
+
+     # keep track of epochs without accuracy improvement
+     improv_counter = 0
+
+     # ----------------------
+     # training loop
+     # ----------------------
+
+     for epoch in range(n_epochs):
+         model.train()
+         train_accuracy_fn.reset()
+
+         training_loss = 0.0
+
+         # iterate over all the dataloader's mini-batches
+         for i, batch in enumerate(train_loader):
+
+             # move to GPU memory
+             inputs = batch["image"].to(DEVICE)
+             labels = batch["label"].to(DEVICE).long()
+
+             optimizer.zero_grad()
+
+             # Forward pass
+             outputs = model(inputs)
+             loss = criterion(outputs, labels)
+
+             # Backward pass
+             loss.backward()
+
+             # update the parameters
+             optimizer.step()
+
+             # accumulate the loss for this epoch
+             training_loss += loss.item()
+
+             # update the accuracy computation with the new batch
+             train_accuracy_fn.update(outputs, labels)
+
+         # compute epoch-level training metrics
+         training_losses[epoch] = training_loss / num_batches
+         training_accuracies[epoch] = train_accuracy_fn.compute().item()
+
+         print(f'Epoch {epoch + 1} training complete. Training Accuracy: {training_accuracies[epoch]:.4f}')
+
+         # ----------------------
+         # validation loop
+         # ----------------------
+
+         model.eval()
+         val_accuracy_fn.reset()
+
+         with torch.no_grad():
+             for batch in val_loader:
+                 inputs = batch["image"].to(DEVICE)
+                 labels = batch["label"].to(DEVICE).long()
+
+                 outputs = model(inputs)
+                 val_accuracy_fn.update(outputs, labels)
+
+         current_accuracy = val_accuracy_fn.compute().item()
+         val_accuracies[epoch] = current_accuracy
+
+         # keep track of best validation accuracy and save best model so far
+         if current_accuracy > best_accuracy:
+             best_accuracy = current_accuracy
+             torch.save(model.state_dict(), save_path)
+             improv_counter = 0  # reset counter if accuracy improves
+             print(f'Epoch {epoch + 1} (validation accuracy: {best_accuracy:.4f})')
+         else:
+             improv_counter += 1
+             print(f'No improvement for {improv_counter} epoch(s)')
+
+             if improv_counter >= early_stop:
+                 print(f"Early stopping at epoch {epoch + 1}")
+                 break
+
+         print(f'Epoch {epoch + 1} validation complete')
+
+     print(f"\nTraining finished. Best val accuracy: {best_accuracy:.4f}")
+     print(f"Best model weights saved to: {save_path}")
+
+     training_metrics = {
+         "losses": training_losses,
+         "accuracies": training_accuracies,
+         "val_accuracies": val_accuracies,
+         "best_accuracy": best_accuracy,
+     }
+
+     return training_metrics
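For reference, a minimal smoke test of the updated train_model, assuming the module is importable as trainingModel.helpers.Training and using a hypothetical ToyPlantDataset that yields the {"image", "label"} batches the docstring documents; the model, shapes, and sample counts are illustrative stand-ins, not part of the repository. One caveat worth checking: the torcheval metrics are constructed without a device argument, so on a CUDA machine their update() calls receive GPU tensors while the metric state lives on the CPU; if that raises a device-mismatch error, passing device=DEVICE to MulticlassAccuracy is the usual fix.

    # Hypothetical smoke test -- names, shapes, and the dataset are illustrative only.
    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, Dataset

    from trainingModel.helpers.Training import train_model


    class ToyPlantDataset(Dataset):
        """Yields the batch format train_model expects: image [C, H, W], int64 label."""

        def __init__(self, n_samples=64, num_classes=39):
            self.images = torch.randn(n_samples, 3, 64, 64)
            self.labels = torch.randint(0, num_classes, (n_samples,))

        def __len__(self):
            return len(self.images)

        def __getitem__(self, idx):
            return {"image": self.images[idx], "label": self.labels[idx]}


    # Tiny CNN emitting [B, 39] logits, matching the default num_classes.
    model = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(8, 39),
    )

    loader = DataLoader(ToyPlantDataset(), batch_size=16, shuffle=True)
    metrics = train_model(model, train_loader=loader, val_loader=loader,
                          n_epochs=2, save_path="toy_model.pt")
    print(metrics["best_accuracy"])

The default collate function assembles the per-sample dicts into batch["image"] of shape [B, C, H, W] and batch["label"] of shape [B], which is exactly the contract the training loop relies on.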
trainingModel/run_training.py CHANGED
@@ -53,33 +53,27 @@ training_metrics = train_model(
      model=model,
      train_loader=subset_loaders['train'],
      val_loader=subset_loaders['val'],
-     device=device,
      n_epochs=training_config["n_epochs"],
      lr=training_config["learning_rate"],
      num_classes=training_config["num_classes"],
-     optimizer_type=training_config["optimizer"],
      save_path=training_config["save_path"],
  )

  # ----------- Log metrics to ClearML -----------
- # Per-batch training losses and accuracies
- for i, loss in enumerate(training_metrics["batch_losses"]):
-     training_logger.report_scalar("training batch loss", "loss", value=loss, iteration=i)
-
- for i, acc in enumerate(training_metrics["batch_accuracies"]):
-     training_logger.report_scalar("training batch accuracy", "accuracy", value=acc, iteration=i)
-
  # Per-epoch training losses and accuracies
- epoch_metrics = zip(training_metrics["epoch_losses"], training_metrics["epoch_accuracies"])
- for epoch, (loss, acc) in enumerate(epoch_metrics):
-     training_logger.report_scalar("training epoch loss", "loss", loss, iteration=epoch)
-     training_logger.report_scalar("training epoch accuracy", "accuracy", acc, iteration=epoch)

  # Per-epoch validation accuracies
  for epoch, acc in enumerate(training_metrics["val_accuracies"]):
      training_logger.report_scalar("validation epoch accuracy", "accuracy", value=acc, iteration=epoch)

  training_logger.report_single_value("best_val_accuracy", training_metrics["best_accuracy"])

  # Upload best model as artifact
      model=model,
      train_loader=subset_loaders['train'],
      val_loader=subset_loaders['val'],
      n_epochs=training_config["n_epochs"],
      lr=training_config["learning_rate"],
      num_classes=training_config["num_classes"],
      save_path=training_config["save_path"],
+     early_stop=3,
  )

  # ----------- Log metrics to ClearML -----------
  # Per-epoch training losses and accuracies
+ for epoch, loss in enumerate(training_metrics["losses"]):
+     training_logger.report_scalar("training epoch loss", "loss", value=loss, iteration=epoch)
+
+ for epoch, acc in enumerate(training_metrics["accuracies"]):
+     training_logger.report_scalar("training epoch accuracy", "accuracy", value=acc, iteration=epoch)

  # Per-epoch validation accuracies
  for epoch, acc in enumerate(training_metrics["val_accuracies"]):
      training_logger.report_scalar("validation epoch accuracy", "accuracy", value=acc, iteration=epoch)

+ # Best validation accuracy
  training_logger.report_single_value("best_val_accuracy", training_metrics["best_accuracy"])

  # Upload best model as artifact
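One side effect of the new early stopping: the arrays returned by train_model are preallocated with np.zeros(n_epochs), so epochs that never ran are reported above as literal 0.0 data points in ClearML. A hedged sketch of trimming that padding before the reporting loops; it assumes a genuinely trained epoch never ends with an average loss of exactly zero, and `completed` is a local helper variable, not something train_model currently returns:

    import numpy as np

    # Epochs that actually ran have non-zero average loss; trailing zeros are
    # padding left over from the np.zeros(n_epochs) preallocation after an early stop.
    completed = int(np.count_nonzero(training_metrics["losses"]))

    for epoch, loss in enumerate(training_metrics["losses"][:completed]):
        training_logger.report_scalar("training epoch loss", "loss", value=loss, iteration=epoch)

    for epoch, acc in enumerate(training_metrics["accuracies"][:completed]):
        training_logger.report_scalar("training epoch accuracy", "accuracy", value=acc, iteration=epoch)

    for epoch, acc in enumerate(training_metrics["val_accuracies"][:completed]):
        training_logger.report_scalar("validation epoch accuracy", "accuracy", value=acc, iteration=epoch)

A more robust alternative would be to have train_model return only the epochs it completed (e.g. slicing its arrays before building training_metrics), which would make the logging loops above work unchanged.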