Atheer Aljuraib (k23108174) committed on
Commit
f597d2e
·
unverified ·
1 Parent(s): 0d36ad3

Update Training.py

Browse files
Files changed (1) hide show
  1. trainingModel/Training.py +25 -35
trainingModel/Training.py CHANGED
@@ -2,16 +2,10 @@ import torch
2
  import torch.nn as nn
3
  import numpy as np
4
  from torcheval.metrics import MulticlassAccuracy
5
- #from torchvision import transforms
6
-
7
-
8
-
9
  from torch.utils.data import DataLoader
10
- #from torchvision.datasets import MNIST
11
 
12
- #import torchvision.utils
13
 
14
- # loss, optimizer, training loop, validation, best model saving
15
 
16
 
17
  def train_model(
@@ -26,7 +20,19 @@ def train_model(
26
  num_classes : int = 39,
27
 
28
  ):
29
-
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
 
32
  # Move model to device
@@ -43,19 +49,20 @@ def train_model(
43
  # Arrays to log metrics
44
  num_batches = len(train_loader)
45
 
 
 
 
46
  # Store training losses and accuracies for every batch
47
  # num_batches is the number of batches for every epoch
48
  training_losses = np.zeros(num_batches * n_epochs)
49
  training_accuracies = np.zeros(num_batches * n_epochs)
50
 
51
-
52
  # store validation accuracy for every epoch
53
  val_accuracies = np.zeros(n_epochs)
 
54
  # keep track of best validation accuracy and best model
55
  best_accuracy = 0.0
56
 
57
-
58
-
59
  #----------------------
60
  # training loop
61
  #----------------------
@@ -69,16 +76,14 @@ def train_model(
69
 
70
  # move to GPU memory
71
  inputs = batch["image"].to(device)
72
- labels = batch["label"].to(device)
73
 
74
  # flatten if not cnn REVISE LATER
75
  if flatten_input:
76
  inputs = inputs.view(inputs.size(0), -1)
77
 
78
-
79
  optimizer.zero_grad()
80
 
81
-
82
  # Forward pass
83
  outputs = model(inputs)
84
  loss = criterion(outputs, labels)
@@ -92,40 +97,31 @@ def train_model(
92
  # log the loss value
93
  training_losses[epoch * num_batches + i] = loss.item()
94
 
95
- # Compute accuracy of the batch.
96
-
97
-
98
  #updates the accuracy computation with new data
99
  train_accuracy_fn.update(outputs, labels)
100
 
101
  #compute accuracy with the current data
102
  training_accuracies[epoch * num_batches + i] = train_accuracy_fn.compute().item()
103
 
104
-
105
- # display some progress (every 200 batches)
106
- # optional, you can comment out
107
- # if i % 200 == 0:
108
- # print(f'Epoch {epoch + 1}, batch {i+1} of {len(train_loader)}')
109
-
110
  print(f'Epoch {epoch + 1} training complete')
111
 
112
- # Validation after each epoch
 
 
 
113
  model.eval()
114
  val_accuracy_fn.reset()
115
 
116
 
117
- # The context 'torch.no_grad()' tells pytorch we are not interested in computing
118
- # gradients here, so forward pass is more efficient
119
  with torch.no_grad():
120
- for i, batch in enumerate(val_loader):
121
  inputs = batch["image"].to(device)
122
- labels = batch["label"].to(device)
123
 
124
  # flatten if not cnn REVISE LATER
125
  if flatten_input:
126
  inputs = inputs.view(inputs.size(0), -1)
127
 
128
-
129
  outputs = model(inputs)
130
 
131
  val_accuracy_fn.update(outputs, labels)
@@ -133,7 +129,6 @@ def train_model(
133
  current_accuracy = val_accuracy_fn.compute().item()
134
  val_accuracies[epoch] = current_accuracy
135
 
136
-
137
  # keep track of best validation accuracy and save best model so far
138
  if current_accuracy > best_accuracy:
139
  best_accuracy = current_accuracy
@@ -146,8 +141,3 @@ def train_model(
146
 
147
  return training_losses, training_accuracies, val_accuracies, best_accuracy
148
 
149
-
150
- #tweak later
151
- #best_model = MNISTNet().to(device)
152
- #best_model.load_state_dict(
153
- # torch.load('mnist-torch-best_model.pt', map_location=device))
 
2
  import torch.nn as nn
3
  import numpy as np
4
  from torcheval.metrics import MulticlassAccuracy
 
 
 
 
5
  from torch.utils.data import DataLoader
 
6
 
 
7
 
8
+ # fix errors in runtime
9
 
10
 
11
  def train_model(
 
20
  num_classes : int = 39,
21
 
22
  ):
23
+ """
24
+ Trains the given model and returns:
25
+ - training_losses: numpy array of loss per batch
26
+ - training_accuracies: numpy array of running accuracy per batch
27
+ - val_accuracies: numpy array of accuracy per epoch
28
+ - best_accuracy: highest validation accuracy achieved
29
+
30
+ Expected batch format:
31
+ batch["image"] → Tensor [B, C, H, W]
32
+ batch["label"] → Tensor [B] with class IDs (int64)
33
+ Model output:
34
+ outputs → Tensor [B, num_classes] (logits)
35
+ """
36
 
37
 
38
  # Move model to device
 
49
  # Arrays to log metrics
50
  num_batches = len(train_loader)
51
 
52
+ if num_batches == 0:
53
+ raise RuntimeError("UH OH!!!! empty train loader")
54
+
55
  # Store training losses and accuracies for every batch
56
  # num_batches is the number of batches for every epoch
57
  training_losses = np.zeros(num_batches * n_epochs)
58
  training_accuracies = np.zeros(num_batches * n_epochs)
59
 
 
60
  # store validation accuracy for every epoch
61
  val_accuracies = np.zeros(n_epochs)
62
+
63
  # keep track of best validation accuracy and best model
64
  best_accuracy = 0.0
65
 
 
 
66
  #----------------------
67
  # training loop
68
  #----------------------
 
76
 
77
  # move to GPU memory
78
  inputs = batch["image"].to(device)
79
+ labels = batch["label"].to(device).long()
80
 
81
  # flatten if not cnn REVISE LATER
82
  if flatten_input:
83
  inputs = inputs.view(inputs.size(0), -1)
84
 
 
85
  optimizer.zero_grad()
86
 
 
87
  # Forward pass
88
  outputs = model(inputs)
89
  loss = criterion(outputs, labels)
 
97
  # log the loss value
98
  training_losses[epoch * num_batches + i] = loss.item()
99
 
 
 
 
100
  #updates the accuracy computation with new data
101
  train_accuracy_fn.update(outputs, labels)
102
 
103
  #compute accuracy with the current data
104
  training_accuracies[epoch * num_batches + i] = train_accuracy_fn.compute().item()
105
 
 
 
 
 
 
 
106
  print(f'Epoch {epoch + 1} training complete')
107
 
108
+ # ----------------------
109
+ # validation loop
110
+ # ----------------------
111
+
112
  model.eval()
113
  val_accuracy_fn.reset()
114
 
115
 
 
 
116
  with torch.no_grad():
117
+ for batch in val_loader:
118
  inputs = batch["image"].to(device)
119
+ labels = batch["label"].to(device).long()
120
 
121
  # flatten if not cnn REVISE LATER
122
  if flatten_input:
123
  inputs = inputs.view(inputs.size(0), -1)
124
 
 
125
  outputs = model(inputs)
126
 
127
  val_accuracy_fn.update(outputs, labels)
 
129
  current_accuracy = val_accuracy_fn.compute().item()
130
  val_accuracies[epoch] = current_accuracy
131
 
 
132
  # keep track of best validation accuracy and save best model so far
133
  if current_accuracy > best_accuracy:
134
  best_accuracy = current_accuracy
 
141
 
142
  return training_losses, training_accuracies, val_accuracies, best_accuracy
143