Atheer Aljuraib (k23108174) commited on
Commit
0174d4f
·
unverified ·
1 Parent(s): 9676661

Add files via upload

Browse files

first training model draft

Files changed (1) hide show
  1. Training.py +149 -0
Training.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from torcheval.metrics import MulticlassAccuracy
5
+ #from torchvision import transforms
6
+
7
+
8
+
9
+ from torch.utils.data import DataLoader
10
+ #from torchvision.datasets import MNIST
11
+
12
+ #import torchvision.utils
13
+
14
+ # loss, optimizer, training loop, validation, best model saving
15
+
16
+
17
+ def train_model(
18
+ model: nn.Module,
19
+ train_loader: DataLoader,
20
+ val_loader: DataLoader,
21
+ device: torch.device,
22
+ n_epochs: int = 4,
23
+ lr: float = 1e-3,
24
+ save_path: str = "best_model.pt",
25
+ flatten_input = False,
26
+ num_classes : int = 39,
27
+
28
+ ):
29
+
30
+ # Move model to device
31
+ model.to(device)
32
+
33
+ # Loss and optimizer
34
+ criterion = nn.CrossEntropyLoss()
35
+ optimizer = torch.optim.Adam(model.parameters(), lr=lr ) # might add momentum 0.9 later
36
+
37
+ # Metric trackers
38
+ train_accuracy_fn = MulticlassAccuracy(num_classes=num_classes)
39
+ val_accuracy_fn = MulticlassAccuracy(num_classes=num_classes)
40
+
41
+ # Arrays to log metrics
42
+ num_batches = len(train_loader)
43
+
44
+ # Store training losses and accuracies for every batch
45
+ # num_batches is the number of batches for every epoch
46
+ training_losses = np.zeros(num_batches * n_epochs)
47
+ training_accuracies = np.zeros(num_batches * n_epochs)
48
+
49
+
50
+ # store validation accuracy for every epoch
51
+ val_accuracies = np.zeros(n_epochs)
52
+ # keep track of best validation accuracy and best model
53
+ best_accuracy = 0.0
54
+
55
+
56
+
57
+
58
+ # training loop
59
+ for epoch in range(n_epochs):
60
+ model.train()
61
+ train_accuracy_fn.reset()
62
+
63
+ # iterate over all the dataloader's mini-batches
64
+ for i, (inputs, labels) in enumerate(train_loader):
65
+
66
+ # move to GPU memory
67
+ inputs = inputs.to(device)
68
+ labels = labels.to(device)
69
+
70
+ # flatten if not cnn REVISE LATER
71
+ if flatten_input:
72
+ inputs = inputs.view(inputs.size(0), -1)
73
+
74
+
75
+ optimizer.zero_grad()
76
+
77
+
78
+ # Forward pass
79
+ outputs = model(inputs)
80
+ loss = criterion(outputs, labels)
81
+
82
+ # Backward pass
83
+ loss.backward()
84
+
85
+ # updates the parameters
86
+ optimizer.step()
87
+
88
+ # log the loss value
89
+ training_losses[epoch * num_batches + i] = loss.item()
90
+
91
+ # Compute accuracy of the batch.
92
+
93
+
94
+ #updates the accuracy computation with new data
95
+ train_accuracy_fn.update(outputs, labels)
96
+
97
+ #compute accuracy with the current data
98
+ training_accuracies[epoch * num_batches + i] = train_accuracy_fn.compute().item()
99
+
100
+
101
+ # display some progress (every 200 batches)
102
+ # optional, you can comment out
103
+ # if i % 200 == 0:
104
+ # print(f'Epoch {epoch + 1}, batch {i+1} of {len(train_loader)}')
105
+
106
+ print(f'Epoch {epoch + 1} training complete')
107
+
108
+ # Validation after each epoch
109
+ model.eval()
110
+ val_accuracy_fn.reset()
111
+
112
+
113
+ # The context 'torch.no_grad()' tells pytorch we are not interested in computing
114
+ # gradients here, so forward pass is more efficient
115
+ with torch.no_grad():
116
+ for i, (inputs, labels) in enumerate(val_loader):
117
+ inputs = inputs.to(device)
118
+ labels = labels.to(device)
119
+
120
+ # flatten if not cnn REVISE LATER
121
+ if flatten_input:
122
+ inputs = inputs.view(inputs.size(0), -1)
123
+
124
+
125
+ outputs = model(inputs)
126
+
127
+ val_accuracy_fn.update(outputs, labels)
128
+
129
+ current_accuracy = val_accuracy_fn.compute().item()
130
+ val_accuracies[epoch] = current_accuracy
131
+
132
+
133
+ # keep track of best validation accuracy and save best model so far
134
+ if current_accuracy > best_accuracy:
135
+ best_accuracy = current_accuracy
136
+ torch.save(model.state_dict(), save_path)
137
+ print(f'Epoch {epoch + 1} (validation accuracy: {best_accuracy})')
138
+ print(f'Epoch {epoch + 1} validation complete')
139
+
140
+ print(f"\nTraining finished. Best val accuracy: {best_accuracy:.4f}")
141
+ print(f"Best model weights saved to: {save_path}")
142
+
143
+ return training_losses, training_accuracies, val_accuracies, best_accuracy
144
+
145
+
146
+ #tweak later
147
+ #best_model = MNISTNet().to(device)
148
+ #best_model.load_state_dict(
149
+ # torch.load('mnist-torch-best_model.pt', map_location=device))