Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| from torcheval.metrics import MulticlassAccuracy | |
| from torch.utils.data import DataLoader | |
# Pick the compute device once at import time: the GPU when CUDA is usable,
# otherwise the CPU. The choice is echoed so it shows up in the run log.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print("Using device:", DEVICE)
def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    n_epochs: int = 4,
    lr: float = 1e-3,
    save_path: str = "best_model.pt",
    num_classes: int = 39,
    early_stop: int = 3,
) -> dict:
    """Train ``model`` and checkpoint the weights with the best validation accuracy.

    Each epoch runs one pass over ``train_loader`` (cross-entropy loss, Adam
    optimizer) followed by one evaluation pass over ``val_loader``. Whenever
    the validation accuracy beats the best value seen so far, the model's
    ``state_dict`` is written to ``save_path``. Training stops early once
    ``early_stop`` consecutive epochs bring no improvement.

    Expected batch format:
        batch["image"] -> Tensor [B, C, H, W]
        batch["label"] -> Tensor [B] with class IDs (cast to int64 here)
    Model output:
        outputs -> Tensor [B, num_classes] (logits)

    Args:
        model: Network to train; it is moved to ``DEVICE`` in place.
        train_loader: Loader yielding training batches; must be non-empty.
        val_loader: Loader yielding validation batches.
        n_epochs: Maximum number of epochs to run.
        lr: Learning rate passed to Adam.
        save_path: File the best ``state_dict`` is saved to.
        num_classes: Number of classes given to the accuracy metrics.
        early_stop: Number of consecutive non-improving epochs that triggers
            early stopping.

    Returns:
        A dict with the keys:
            "losses": numpy array (length ``n_epochs``) of mean training loss
                per epoch.
            "accuracies": numpy array (length ``n_epochs``) of training
                accuracy per epoch.
            "val_accuracies": numpy array (length ``n_epochs``) of validation
                accuracy per epoch.
            "best_accuracy": highest validation accuracy achieved (0.0 when
                no epoch produced a checkpoint).
            "epochs_run": number of epochs for which metrics were recorded.
                When training stops early, array entries at index
                ``epochs_run`` and beyond are untouched zeros, not results.

    Raises:
        RuntimeError: If ``train_loader`` has no batches.
    """
    model.to(DEVICE)

    # Loss and optimizer.
    criterion = nn.CrossEntropyLoss()
    # Adam takes no ``momentum`` argument; ``betas`` is the knob to tune instead.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Separate metric objects so training and validation counts never mix.
    train_accuracy_fn = MulticlassAccuracy(num_classes=num_classes)
    val_accuracy_fn = MulticlassAccuracy(num_classes=num_classes)

    num_batches = len(train_loader)
    if num_batches == 0:
        raise RuntimeError("UH OH!!!! empty train loader")

    # Per-epoch logs, pre-sized to the maximum number of epochs.
    training_losses = np.zeros(n_epochs)
    training_accuracies = np.zeros(n_epochs)
    val_accuracies = np.zeros(n_epochs)

    # Start below any reachable accuracy so the first validated epoch always
    # produces a checkpoint, even when its accuracy is exactly 0.0. Starting
    # at 0.0 with a strict ">" would leave save_path unwritten in that case.
    best_accuracy = float("-inf")
    checkpoint_saved = False
    improv_counter = 0  # consecutive epochs without improvement
    epochs_run = 0

    for epoch in range(n_epochs):
        # ----------------------
        # training pass
        # ----------------------
        model.train()
        train_accuracy_fn.reset()
        training_loss = 0.0

        for batch in train_loader:
            inputs = batch["image"].to(DEVICE)
            labels = batch["label"].to(DEVICE).long()

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Accumulate the batch loss and feed the running accuracy metric.
            training_loss += loss.item()
            train_accuracy_fn.update(outputs, labels)

        # Epoch-level training metrics: mean batch loss and accuracy.
        training_losses[epoch] = training_loss / num_batches
        training_accuracies[epoch] = train_accuracy_fn.compute().item()
        print(f'Epoch {epoch + 1} training complete. Training Accuracy: {training_accuracies[epoch]:.4f}')

        # ----------------------
        # validation pass
        # ----------------------
        model.eval()
        val_accuracy_fn.reset()
        with torch.no_grad():
            for batch in val_loader:
                inputs = batch["image"].to(DEVICE)
                labels = batch["label"].to(DEVICE).long()
                outputs = model(inputs)
                val_accuracy_fn.update(outputs, labels)

        current_accuracy = val_accuracy_fn.compute().item()
        val_accuracies[epoch] = current_accuracy
        epochs_run = epoch + 1

        # Checkpoint on improvement; otherwise count towards early stopping.
        if current_accuracy > best_accuracy:
            best_accuracy = current_accuracy
            torch.save(model.state_dict(), save_path)
            checkpoint_saved = True
            improv_counter = 0
            print(f'Epoch {epoch + 1} (validation accuracy: {best_accuracy})')
        else:
            improv_counter += 1
            print(f'No improvement for {improv_counter} epoch')
            if improv_counter >= early_stop:
                print(f"Early stopping at epoch {epoch + 1}")
                break

        print(f'Epoch {epoch + 1} validation complete')

    if not checkpoint_saved:
        # Keep the historical return value for runs that never improved.
        best_accuracy = 0.0

    print(f"\nTraining finished. Best val accuracy: {best_accuracy:.4f}")
    if checkpoint_saved:
        print(f"Best model weights saved to: {save_path}")
    else:
        # Do not claim a checkpoint exists when none was written in this run.
        print(f"No checkpoint was written to: {save_path}")

    training_metrics = {
        "losses": training_losses,
        "accuracies": training_accuracies,
        "val_accuracies": val_accuracies,
        "best_accuracy": best_accuracy,
        "epochs_run": epochs_run,
    }
    return training_metrics