import torch
import torch.nn as nn
import numpy as np
from torcheval.metrics import MulticlassAccuracy
from torch.utils.data import DataLoader


def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    n_epochs: int = 4,
    lr: float = 1e-3,
    save_path: str = "best_model.pt",
    flatten_input: bool = False,
    num_classes: int = 39,
):
    """
    Trains the given model and returns a dict with:
        - "losses": numpy array of loss per training batch
        - "accuracies": numpy array of running training accuracy per batch
        - "val_accuracies": numpy array of validation accuracy per epoch
        - "best_accuracy": highest validation accuracy achieved

    Expected batch format:
        batch["image"] -> Tensor [B, C, H, W]
        batch["label"] -> Tensor [B] with class IDs (int64)

    Model output:
        outputs -> Tensor [B, num_classes] (logits)
    """
    # Move model to device
    model.to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    # Note: Adam has no `momentum` argument; its moving averages are
    # controlled via `betas` (default (0.9, 0.999))
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Metric trackers; keep metric state on `device` so updates with
    # GPU tensors do not hit a device mismatch
    train_accuracy_fn = MulticlassAccuracy(num_classes=num_classes, device=device)
    val_accuracy_fn = MulticlassAccuracy(num_classes=num_classes, device=device)

    # Arrays to log metrics
    num_batches = len(train_loader)  # batches per epoch
    if num_batches == 0:
        raise RuntimeError("train_loader is empty; nothing to train on")

    # Per-batch training losses and running accuracies across all epochs
    training_losses = np.zeros(num_batches * n_epochs)
    training_accuracies = np.zeros(num_batches * n_epochs)

    # Per-epoch validation accuracy
    val_accuracies = np.zeros(n_epochs)

    # Track the best validation accuracy; the best model is saved to save_path
    best_accuracy = 0.0

    # ----------------------
    # training loop
    # ----------------------
    for epoch in range(n_epochs):
        model.train()
        train_accuracy_fn.reset()

        # Iterate over the dataloader's mini-batches
        for i, batch in enumerate(train_loader):
            # Move the batch to the target device
            inputs = batch["image"].to(device)
            labels = batch["label"].to(device).long()

            # Flatten spatial dims for fully-connected models (not needed for CNNs)
            if flatten_input:
                inputs = inputs.view(inputs.size(0), -1)

            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass
            loss.backward()

            # Update the parameters
            optimizer.step()

            # Log the loss value
            training_losses[epoch * num_batches + i] = loss.item()

            # Update the accuracy metric with the new batch ...
            train_accuracy_fn.update(outputs, labels)
            # ... and log the running accuracy so far this epoch
            training_accuracies[epoch * num_batches + i] = train_accuracy_fn.compute().item()

        print(f"Epoch {epoch + 1} training complete")

        # ----------------------
        # validation loop
        # ----------------------
        model.eval()
        val_accuracy_fn.reset()

        with torch.no_grad():
            for batch in val_loader:
                inputs = batch["image"].to(device)
                labels = batch["label"].to(device).long()

                # Flatten spatial dims for fully-connected models (not needed for CNNs)
                if flatten_input:
                    inputs = inputs.view(inputs.size(0), -1)

                outputs = model(inputs)
                val_accuracy_fn.update(outputs, labels)

        current_accuracy = val_accuracy_fn.compute().item()
        val_accuracies[epoch] = current_accuracy

        # Save the best model seen so far
        if current_accuracy > best_accuracy:
            best_accuracy = current_accuracy
            torch.save(model.state_dict(), save_path)
            print(f"Epoch {epoch + 1}: new best validation accuracy: {best_accuracy:.4f}")

        print(f"Epoch {epoch + 1} validation complete")

    print(f"\nTraining finished. Best val accuracy: {best_accuracy:.4f}")
    print(f"Best model weights saved to: {save_path}")

    training_metrics = {
        "losses": training_losses,
        "accuracies": training_accuracies,
        "val_accuracies": val_accuracies,
        "best_accuracy": best_accuracy,
    }
    return training_metrics
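

# ----------------------------------------------------------------------
# Usage sketch (illustrative only): a minimal end-to-end call, assuming a
# toy dataset that yields {"image", "label"} dicts matching the batch
# format documented in the docstring above. `ToyDataset` and the linear
# model below are hypothetical stand-ins, not part of this utility.
# ----------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import Dataset

    class ToyDataset(Dataset):
        """Synthetic data mimicking the expected {"image", "label"} batches."""

        def __init__(self, n_samples: int = 128, num_classes: int = 39):
            self.images = torch.randn(n_samples, 3, 32, 32)
            self.labels = torch.randint(0, num_classes, (n_samples,))

        def __len__(self):
            return len(self.images)

        def __getitem__(self, idx):
            # The default collate_fn stacks these dicts into batched tensors
            return {"image": self.images[idx], "label": self.labels[idx]}

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # A plain linear classifier over flattened 3x32x32 images
    model = nn.Linear(3 * 32 * 32, 39)

    train_loader = DataLoader(ToyDataset(), batch_size=16, shuffle=True)
    val_loader = DataLoader(ToyDataset(n_samples=64), batch_size=16)

    metrics = train_model(
        model,
        train_loader,
        val_loader,
        device,
        n_epochs=2,
        flatten_input=True,  # the linear model needs [B, C*H*W] inputs
    )
    print(f"Best accuracy on toy data: {metrics['best_accuracy']:.4f}")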