Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| from torcheval.metrics import MulticlassAccuracy | |
| from torch.utils.data import DataLoader | |
| # fix errors in runtime | |
def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    n_epochs: int = 4,
    lr: float = 1e-3,
    save_path: str = "best_model.pt",
    flatten_input: bool = False,
    num_classes: int = 39,
):
    """
    Train ``model`` with Adam + cross-entropy and checkpoint the best weights.

    The model is trained for ``n_epochs`` epochs; after each epoch it is
    evaluated on ``val_loader``, and whenever the validation accuracy improves
    the model's ``state_dict`` is saved to ``save_path``.

    Expected batch format:
        batch["image"] -> Tensor [B, C, H, W]
        batch["label"] -> Tensor [B] with class IDs (int64)
    Model output:
        outputs -> Tensor [B, num_classes] (logits)

    Args:
        model: network to train (moved to ``device`` in place).
        train_loader: dataloader providing training batches.
        val_loader: dataloader providing validation batches.
        device: device to run training on.
        n_epochs: number of training epochs.
        lr: Adam learning rate.
        save_path: file path for the best checkpoint.
        flatten_input: if True, flatten each image to [B, C*H*W] before the
            forward pass (for MLP-style models).
        num_classes: number of target classes for the accuracy metrics.

    Returns:
        dict with:
            "losses":         np.ndarray, loss per training batch
            "accuracies":     np.ndarray, running train accuracy per batch
            "val_accuracies": np.ndarray, validation accuracy per epoch
            "best_accuracy":  float, highest validation accuracy achieved

    Raises:
        RuntimeError: if ``train_loader`` yields no batches.
    """
    # Move model parameters to the target device before building the optimizer.
    model.to(device)

    # Loss and optimizer.
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Metric trackers. The metrics must live on the same device as the tensors
    # passed to .update(); constructing them on CPU while training on CUDA
    # raises a runtime error inside torcheval, so pass `device` explicitly.
    train_accuracy_fn = MulticlassAccuracy(num_classes=num_classes, device=device)
    val_accuracy_fn = MulticlassAccuracy(num_classes=num_classes, device=device)

    # num_batches is the number of batches per epoch; the per-batch logs are
    # pre-sized to num_batches * n_epochs entries.
    num_batches = len(train_loader)
    if num_batches == 0:
        raise RuntimeError("UH OH!!!! empty train loader")

    # Per-batch training loss and running accuracy.
    training_losses = np.zeros(num_batches * n_epochs)
    training_accuracies = np.zeros(num_batches * n_epochs)
    # Per-epoch validation accuracy.
    val_accuracies = np.zeros(n_epochs)
    # Best validation accuracy seen so far (checkpoint criterion).
    best_accuracy = 0.0

    # ----------------------
    # training loop
    # ----------------------
    for epoch in range(n_epochs):
        model.train()
        train_accuracy_fn.reset()
        for i, batch in enumerate(train_loader):
            # Move the batch to the training device.
            inputs = batch["image"].to(device)
            labels = batch["label"].to(device).long()
            if flatten_input:
                # reshape (not view) handles non-contiguous batches safely.
                inputs = inputs.reshape(inputs.size(0), -1)

            optimizer.zero_grad()
            # Forward pass.
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            # Backward pass and parameter update.
            loss.backward()
            optimizer.step()

            # Log loss and running accuracy for this batch.
            training_losses[epoch * num_batches + i] = loss.item()
            train_accuracy_fn.update(outputs, labels)
            training_accuracies[epoch * num_batches + i] = train_accuracy_fn.compute().item()
        print(f'Epoch {epoch + 1} training complete')

        # ----------------------
        # validation loop
        # ----------------------
        model.eval()
        val_accuracy_fn.reset()
        with torch.no_grad():
            for batch in val_loader:
                inputs = batch["image"].to(device)
                labels = batch["label"].to(device).long()
                if flatten_input:
                    inputs = inputs.reshape(inputs.size(0), -1)
                outputs = model(inputs)
                val_accuracy_fn.update(outputs, labels)

        current_accuracy = val_accuracy_fn.compute().item()
        val_accuracies[epoch] = current_accuracy
        # Checkpoint whenever validation accuracy improves.
        if current_accuracy > best_accuracy:
            best_accuracy = current_accuracy
            torch.save(model.state_dict(), save_path)
            print(f'Epoch {epoch + 1} (validation accuracy: {best_accuracy})')
        print(f'Epoch {epoch + 1} validation complete')

    print(f"\nTraining finished. Best val accuracy: {best_accuracy:.4f}")
    print(f"Best model weights saved to: {save_path}")

    training_metrics = {
        "losses": training_losses,
        "accuracies": training_accuracies,
        "val_accuracies": val_accuracies,
        "best_accuracy": best_accuracy,
    }
    return training_metrics