import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader
from torcheval.metrics import MulticlassAccuracy

# Loss, optimizer, training loop, per-epoch validation, and best-model saving.
def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    n_epochs: int = 4,
    lr: float = 1e-3,
    save_path: str = "best_model.pt",
    flatten_input: bool = False,
    num_classes: int = 39,
):
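    """Train `model` on `train_loader`, validating on `val_loader` after each epoch.

    Batches are expected to be dicts with "image" and "label" keys. The state
    dict of the model with the best validation accuracy is saved to `save_path`.

    Returns:
        training_losses: per-batch cross-entropy loss, length n_epochs * len(train_loader)
        training_accuracies: running training accuracy logged after each batch
        val_accuracies: validation accuracy for each epoch
        best_accuracy: best validation accuracy observed
    """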
    # Move model to device
    model.to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # might switch to SGD with momentum 0.9 later

    # Metric trackers, created on the same device as the model so that
    # updates with GPU tensors don't hit a device mismatch
    train_accuracy_fn = MulticlassAccuracy(num_classes=num_classes, device=device)
    val_accuracy_fn = MulticlassAccuracy(num_classes=num_classes, device=device)
    # Arrays to log metrics; num_batches is the number of batches per epoch
    num_batches = len(train_loader)

    # Store training loss and running accuracy for every batch
    training_losses = np.zeros(num_batches * n_epochs)
    training_accuracies = np.zeros(num_batches * n_epochs)

    # Store validation accuracy for every epoch
    val_accuracies = np.zeros(n_epochs)

    # Keep track of the best validation accuracy (and save the best model)
    best_accuracy = 0.0
    # ----------------------
    # Training loop
    # ----------------------
    for epoch in range(n_epochs):
        model.train()
        train_accuracy_fn.reset()

        # Iterate over all of the dataloader's mini-batches
        for i, batch in enumerate(train_loader):
            # Move the batch to the target device
            inputs = batch["image"].to(device)
            labels = batch["label"].to(device)

            # Flatten images to vectors for non-convolutional models (e.g. an MLP)
            if flatten_input:
                inputs = inputs.view(inputs.size(0), -1)
            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass and parameter update
            loss.backward()
            optimizer.step()

            # Log this batch's loss
            training_losses[epoch * num_batches + i] = loss.item()

            # Update the metric with this batch, then log the running
            # training accuracy over the epoch so far
            train_accuracy_fn.update(outputs, labels)
            training_accuracies[epoch * num_batches + i] = train_accuracy_fn.compute().item()

            # Optional: display some progress every 200 batches
            # if i % 200 == 0:
            #     print(f'Epoch {epoch + 1}, batch {i + 1} of {len(train_loader)}')

        print(f'Epoch {epoch + 1} training complete')
        # Validation after each epoch
        model.eval()
        val_accuracy_fn.reset()

        # 'torch.no_grad()' tells PyTorch we are not computing gradients here,
        # so the forward pass is faster and uses less memory
        with torch.no_grad():
            for batch in val_loader:
                inputs = batch["image"].to(device)
                labels = batch["label"].to(device)

                # Flatten images for non-convolutional models, as above
                if flatten_input:
                    inputs = inputs.view(inputs.size(0), -1)

                outputs = model(inputs)
                val_accuracy_fn.update(outputs, labels)

        current_accuracy = val_accuracy_fn.compute().item()
        val_accuracies[epoch] = current_accuracy

        # Save the model whenever the validation accuracy improves
        if current_accuracy > best_accuracy:
            best_accuracy = current_accuracy
            torch.save(model.state_dict(), save_path)
            print(f'Epoch {epoch + 1} (new best validation accuracy: {best_accuracy:.4f})')

        print(f'Epoch {epoch + 1} validation complete')
print(f"\nTraining finished. Best val accuracy: {best_accuracy:.4f}")
print(f"Best model weights saved to: {save_path}")
return training_losses, training_accuracies, val_accuracies, best_accuracy
# Tweak later: reloading the best checkpoint (MNISTNet and the filename are
# leftovers from an earlier MNIST experiment; substitute your own model class
# and the `save_path` used above):
# best_model = MNISTNet().to(device)
# best_model.load_state_dict(
#     torch.load('mnist-torch-best_model.pt', map_location=device))
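
# ---------------------------------------------------------------------------
# Minimal self-contained usage sketch. Everything below is illustrative and
# not part of the original training code: the tiny model, the random images,
# and the dict-style dataset only demonstrate the batch format ("image" /
# "label" keys) that train_model expects.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import Dataset

    class RandomImageDataset(Dataset):
        """Fake 39-class image dataset matching the expected batch format."""

        def __init__(self, n_samples: int = 256):
            self.images = torch.randn(n_samples, 3, 32, 32)
            self.labels = torch.randint(0, 39, (n_samples,))

        def __len__(self):
            return len(self.images)

        def __getitem__(self, idx):
            return {"image": self.images[idx], "label": self.labels[idx]}

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Tiny placeholder model producing (batch, 39) logits
    model = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(16, 39),
    )

    train_loader = DataLoader(RandomImageDataset(), batch_size=32, shuffle=True)
    val_loader = DataLoader(RandomImageDataset(64), batch_size=32)

    losses, train_accs, val_accs, best = train_model(
        model, train_loader, val_loader, device, n_epochs=1
    )
    print(f"Sketch run finished; best validation accuracy: {best:.4f}")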