import torch
import torch.nn as nn
import numpy as np
from torcheval.metrics import MulticlassAccuracy
from torch.utils.data import DataLoader

# Select GPU when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)


def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    n_epochs: int = 4,
    lr: float = 1e-3,
    save_path: str = "best_model.pt",
    num_classes: int = 39,
    early_stop: int = 3,
) -> dict:
    """
    Train `model` with Adam + cross-entropy, validating after each epoch,
    saving the best-performing weights, and early-stopping when validation
    accuracy stops improving.

    Expected batch format:
        batch["image"] -> Tensor [B, C, H, W]
        batch["label"] -> Tensor [B] with class IDs (int64)
    Model output:
        outputs -> Tensor [B, num_classes] (logits)

    Args:
        model: network to train; moved to DEVICE in place.
        train_loader: dataloader for training batches (must be non-empty).
        val_loader: dataloader for validation batches.
        n_epochs: maximum number of epochs to run.
        lr: Adam learning rate.
        save_path: file path where the best model's state_dict is saved.
        num_classes: number of target classes for the accuracy metric.
        early_stop: stop after this many consecutive epochs without
            validation-accuracy improvement.

    Returns:
        dict with keys:
            "losses":         np.ndarray[n_epochs] mean training loss per epoch
            "accuracies":     np.ndarray[n_epochs] training accuracy per epoch
            "val_accuracies": np.ndarray[n_epochs] validation accuracy per epoch
            "best_accuracy":  float, highest validation accuracy achieved
        If early stopping triggers, entries after the stopping epoch remain 0.

    Raises:
        RuntimeError: if `train_loader` yields no batches.
    """
    # Move model to device
    model.to(DEVICE)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # might add momentum 0.9 later

    # Metric trackers.
    # NOTE: the metrics must live on the same device as the tensors passed to
    # update(); without device=DEVICE they default to CPU and update() fails
    # with a device mismatch when training on CUDA.
    train_accuracy_fn = MulticlassAccuracy(num_classes=num_classes, device=DEVICE)
    val_accuracy_fn = MulticlassAccuracy(num_classes=num_classes, device=DEVICE)

    # Guard against an empty loader so the per-epoch mean loss below
    # never divides by zero.
    num_batches = len(train_loader)
    if num_batches == 0:
        raise RuntimeError("UH OH!!!! empty train loader")

    # Store training losses and accuracies for every epoch
    training_losses = np.zeros(n_epochs)
    training_accuracies = np.zeros(n_epochs)
    # store validation accuracy for every epoch
    val_accuracies = np.zeros(n_epochs)

    # keep track of best validation accuracy and best model
    best_accuracy = 0.0
    # consecutive epochs without validation-accuracy improvement
    improv_counter = 0

    # ----------------------
    # training loop
    # ----------------------
    for epoch in range(n_epochs):
        model.train()
        train_accuracy_fn.reset()
        training_loss = 0.0

        # iterate over all the dataloader's mini-batches
        for i, batch in enumerate(train_loader):
            # move to GPU memory
            inputs = batch["image"].to(DEVICE)
            labels = batch["label"].to(DEVICE).long()

            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass
            loss.backward()
            # updates the parameters
            optimizer.step()

            # log the loss value for epoch
            training_loss += loss.item()
            # updates the accuracy computation with new data
            train_accuracy_fn.update(outputs, labels)

        # compute epoch-level training metrics
        training_losses[epoch] = training_loss / num_batches
        training_accuracies[epoch] = train_accuracy_fn.compute().item()
        print(f'Epoch {epoch + 1} training complete. Training Accuracy: {training_accuracies[epoch]:.4f}')

        # ----------------------
        # validation loop
        # ----------------------
        model.eval()
        val_accuracy_fn.reset()
        with torch.no_grad():
            for batch in val_loader:
                inputs = batch["image"].to(DEVICE)
                labels = batch["label"].to(DEVICE).long()
                outputs = model(inputs)
                val_accuracy_fn.update(outputs, labels)

        current_accuracy = val_accuracy_fn.compute().item()
        val_accuracies[epoch] = current_accuracy

        # keep track of best validation accuracy and save best model so far
        if current_accuracy > best_accuracy:
            best_accuracy = current_accuracy
            torch.save(model.state_dict(), save_path)
            improv_counter = 0  # Resets counter if accuracy improves
            print(f'Epoch {epoch + 1} (validation accuracy: {best_accuracy})')
        else:
            improv_counter += 1
            print(f'No improvement for {improv_counter} epoch')
            if improv_counter >= early_stop:
                print(f"Early stopping at epoch {epoch + 1}")
                break

        print(f'Epoch {epoch + 1} validation complete')

    print(f"\nTraining finished. Best val accuracy: {best_accuracy:.4f}")
    print(f"Best model weights saved to: {save_path}")

    training_metrics = {
        "losses": training_losses,
        "accuracies": training_accuracies,
        "val_accuracies": val_accuracies,
        "best_accuracy": best_accuracy,
    }
    return training_metrics