# Atheer Aljuraib (k23108174)
# Update training loop and fixed training metrics
# commit e6d94e8
import torch
import torch.nn as nn
import numpy as np
from torcheval.metrics import MulticlassAccuracy
from torch.utils.data import DataLoader
# Pick the compute device once at import time; prefer CUDA when available.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", DEVICE)
def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    n_epochs: int = 4,
    lr: float = 1e-3,
    save_path: str = "best_model.pt",
    num_classes: int = 39,
    early_stop: int = 3,
):
    """
    Train ``model`` with Adam + cross-entropy, early-stopping on validation
    accuracy and checkpointing the best weights.

    Expected batch format:
        batch["image"] -> Tensor [B, C, H, W]
        batch["label"] -> Tensor [B] with class IDs (int64)
    Model output:
        outputs -> Tensor [B, num_classes] (logits)

    Parameters
    ----------
    model : nn.Module
        Classifier producing logits of shape [B, num_classes].
    train_loader, val_loader : DataLoader
        Loaders yielding dict batches in the format above.
    n_epochs : int
        Maximum number of epochs to run.
    lr : float
        Adam learning rate.
    save_path : str
        File the best (by validation accuracy) ``state_dict`` is saved to.
    num_classes : int
        Number of target classes for the accuracy metrics.
    early_stop : int
        Stop after this many consecutive epochs without a new best
        validation accuracy.

    Returns
    -------
    dict
        "losses"         : np.ndarray, mean training loss per epoch run
        "accuracies"     : np.ndarray, training accuracy per epoch run
        "val_accuracies" : np.ndarray, validation accuracy per epoch run
        "best_accuracy"  : float, highest validation accuracy achieved
        Arrays are trimmed to the number of epochs actually run, so early
        stopping no longer leaves misleading trailing zeros.

    Raises
    ------
    RuntimeError
        If ``train_loader`` yields no batches.
    """
    model.to(DEVICE)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # might add momentum 0.9 later

    # BUGFIX: build the metrics on the training device. torcheval metrics keep
    # their state on CPU by default and raise when updated with CUDA tensors.
    train_accuracy_fn = MulticlassAccuracy(num_classes=num_classes, device=DEVICE)
    val_accuracy_fn = MulticlassAccuracy(num_classes=num_classes, device=DEVICE)

    num_batches = len(train_loader)
    if num_batches == 0:
        raise RuntimeError("UH OH!!!! empty train loader")

    # Per-epoch logs; sliced to `epochs_run` before returning.
    training_losses = np.zeros(n_epochs)
    training_accuracies = np.zeros(n_epochs)
    val_accuracies = np.zeros(n_epochs)

    best_accuracy = 0.0  # highest validation accuracy seen so far
    improv_counter = 0   # consecutive epochs without improvement
    epochs_run = 0       # epochs actually executed (< n_epochs on early stop)

    for epoch in range(n_epochs):
        epochs_run = epoch + 1

        # ----------------------
        # training loop
        # ----------------------
        model.train()
        train_accuracy_fn.reset()
        training_loss = 0.0
        for batch in train_loader:
            inputs = batch["image"].to(DEVICE)
            labels = batch["label"].to(DEVICE).long()

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            training_loss += loss.item()
            train_accuracy_fn.update(outputs, labels)

        # Epoch-level training metrics.
        training_losses[epoch] = training_loss / num_batches
        training_accuracies[epoch] = train_accuracy_fn.compute().item()
        print(f'Epoch {epoch + 1} training complete. Training Accuracy: {training_accuracies[epoch]:.4f}')

        # ----------------------
        # validation loop
        # ----------------------
        model.eval()
        val_accuracy_fn.reset()
        with torch.no_grad():
            for batch in val_loader:
                inputs = batch["image"].to(DEVICE)
                labels = batch["label"].to(DEVICE).long()
                val_accuracy_fn.update(model(inputs), labels)
        current_accuracy = val_accuracy_fn.compute().item()
        val_accuracies[epoch] = current_accuracy

        if current_accuracy > best_accuracy:
            # New best: checkpoint the weights and reset the patience counter.
            best_accuracy = current_accuracy
            torch.save(model.state_dict(), save_path)
            improv_counter = 0
            print(f'Epoch {epoch + 1} (validation accuracy: {best_accuracy})')
        else:
            improv_counter += 1
            print(f'No improvement for {improv_counter} epoch(s)')
            if improv_counter >= early_stop:
                print(f"Early stopping at epoch {epoch + 1}")
                break
        print(f'Epoch {epoch + 1} validation complete')

    print(f"\nTraining finished. Best val accuracy: {best_accuracy:.4f}")
    print(f"Best model weights saved to: {save_path}")

    # BUGFIX: trim the logs to the epochs that actually ran so an early stop
    # does not pad the returned metrics with zeros.
    return {
        "losses": training_losses[:epochs_run],
        "accuracies": training_accuracies[:epochs_run],
        "val_accuracies": val_accuracies[:epochs_run],
        "best_accuracy": best_accuracy,
    }