File size: 4,594 Bytes
e6d94e8
0174d4f
 
 
 
 
 
 
 
e6d94e8
 
 
0174d4f
e6d94e8
 
 
 
 
 
 
 
 
 
0174d4f
e6d94e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f597d2e
f61bf9a
e6d94e8
 
 
 
 
0174d4f
25fbc07
e6d94e8
 
0174d4f
 
e6d94e8
 
 
0174d4f
25fbc07
e6d94e8
25fbc07
 
e6d94e8
 
 
f597d2e
0174d4f
e6d94e8
 
25fbc07
f597d2e
0174d4f
25fbc07
e6d94e8
 
 
 
0174d4f
25fbc07
e6d94e8
0174d4f
 
e6d94e8
0174d4f
 
e6d94e8
 
25fbc07
 
e6d94e8
 
 
 
 
 
0174d4f
 
e6d94e8
 
 
25fbc07
0174d4f
e6d94e8
 
 
f597d2e
0174d4f
 
 
e6d94e8
0174d4f
 
e6d94e8
 
25fbc07
e6d94e8
 
 
 
 
 
0174d4f
e6d94e8
25fbc07
e6d94e8
0174d4f
 
c638d1e
0174d4f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
import torch
import torch.nn as nn
import numpy as np
from torcheval.metrics import MulticlassAccuracy
from torch.utils.data import DataLoader




DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)

def train_model(
   model: nn.Module,
   train_loader: DataLoader,
   val_loader: DataLoader,
   n_epochs: int = 4,
   lr: float = 1e-3,
   save_path: str = "best_model.pt",
   num_classes : int = 39,
   early_stop : int = 3,


):
   """
   Trains the given model and returns:
   - training_losses: numpy array of loss per epoch
   - training_accuracies: numpy array of running accuracy per epoch
   - val_accuracies: numpy array of accuracy per epoch
   - best_accuracy: highest validation accuracy achieved


   Expected batch format:
       batch["image"] → Tensor [B, C, H, W]
       batch["label"] → Tensor [B] with class IDs (int64)
   Model output:
       outputs → Tensor [B, num_classes] (logits)
   """


   # Move model to device
   model.to(DEVICE)


   # Loss and optimizer
   criterion = nn.CrossEntropyLoss()
   optimizer = torch.optim.Adam(model.parameters(), lr=lr ) # might add momentum 0.9 later


   # Metric trackers
   train_accuracy_fn = MulticlassAccuracy(num_classes=num_classes)
   val_accuracy_fn = MulticlassAccuracy(num_classes=num_classes)


   # Arrays to log metrics
   num_batches = len(train_loader)


   if num_batches == 0:
       raise RuntimeError("UH OH!!!! empty train loader")


   # Store training losses and accuracies for every epoch
   training_losses = np.zeros(n_epochs)
   training_accuracies = np.zeros(n_epochs)


   # store validation accuracy for every epoch
   val_accuracies = np.zeros(n_epochs)


   # keep track of best validation accuracy and best model
   best_accuracy = 0.0


   # keep track of accuracy improvement
   improv_counter = 0


   #----------------------
   # training loop
   #----------------------
  
   for epoch in range(n_epochs):
       model.train()
       train_accuracy_fn.reset()


       training_loss = 0.0


       # iterate over all the dataloader's mini-batches
       for i, batch in enumerate(train_loader):


           # move to GPU memory
           inputs = batch["image"].to(DEVICE)
           labels = batch["label"].to(DEVICE).long()




           optimizer.zero_grad()


           # Forward pass
           outputs = model(inputs)
           loss = criterion(outputs, labels)
          
           # Backward pass
           loss.backward()


           # updates the parameters
           optimizer.step()
          
           # log the loss value for epoch
           training_loss += loss.item()


           #updates the accuracy computation with new data
           train_accuracy_fn.update(outputs, labels)


       # compute epoch-level training metrics
       training_losses[epoch] = training_loss / num_batches
       training_accuracies[epoch] = train_accuracy_fn.compute().item()


       print(f'Epoch {epoch + 1} training complete. Training Accuracy: {training_accuracies[epoch]:.4f}')


       # ----------------------
       # validation loop
       # ----------------------


       model.eval()
       val_accuracy_fn.reset()




       with torch.no_grad():
           for batch in val_loader:
               inputs = batch["image"].to(DEVICE)
               labels = batch["label"].to(DEVICE).long()


               outputs = model(inputs)


               val_accuracy_fn.update(outputs, labels)


       current_accuracy = val_accuracy_fn.compute().item()
       val_accuracies[epoch] = current_accuracy


       # keep track of best validation accuracy and save best model so far
       if current_accuracy > best_accuracy:
           best_accuracy = current_accuracy
           torch.save(model.state_dict(), save_path)
           improv_counter = 0  #Resets coounter if accuracy improves
           print(f'Epoch {epoch + 1} (validation accuracy: {best_accuracy})')


       else:
           improv_counter +=1
           print(f'No improvement for {improv_counter} epoch')


           if improv_counter >= early_stop:
               print (f"Early stopping at epoch {epoch +1}")
               break




       print(f'Epoch {epoch + 1} validation complete')


   print(f"\nTraining finished. Best val accuracy: {best_accuracy:.4f}")
   print(f"Best model weights saved to: {save_path}")

   
   training_metrics = {
       "losses": training_losses,
       "accuracies": training_accuracies,
       "val_accuracies": val_accuracies,
       "best_accuracy": best_accuracy

   }

   return training_metrics