import torch
import torch.nn as nn
import numpy as np
from torcheval.metrics import MulticlassAccuracy
from torch.utils.data import DataLoader
def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    n_epochs: int = 4,
    lr: float = 1e-3,
    save_path: str = "best_model.pt",
    flatten_input: bool = False,
    num_classes: int = 39,
):
"""
Trains the given model and returns:
- training_losses: numpy array of loss per batch
- training_accuracies: numpy array of running accuracy per batch
- val_accuracies: numpy array of accuracy per epoch
- best_accuracy: highest validation accuracy achieved
Expected batch format:
batch["image"] → Tensor [B, C, H, W]
batch["label"] → Tensor [B] with class IDs (int64)
Model output:
outputs → Tensor [B, num_classes] (logits)
"""
    # Move model to device
    model.to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # might switch to SGD with momentum=0.9 later (Adam has no momentum arg)

    # Metric trackers, created on the same device as the model so that
    # updates with GPU tensors don't fail on a device mismatch
    train_accuracy_fn = MulticlassAccuracy(num_classes=num_classes, device=device)
    val_accuracy_fn = MulticlassAccuracy(num_classes=num_classes, device=device)
    # Arrays to log metrics
    num_batches = len(train_loader)
    if num_batches == 0:
        raise RuntimeError("train_loader is empty; nothing to train on")

    # Store training loss and running accuracy for every batch
    # (num_batches batches per epoch, over n_epochs epochs)
    training_losses = np.zeros(num_batches * n_epochs)
    training_accuracies = np.zeros(num_batches * n_epochs)

    # Store validation accuracy for every epoch
    val_accuracies = np.zeros(n_epochs)

    # Keep track of the best validation accuracy seen so far
    best_accuracy = 0.0
    # ----------------------
    # training loop
    # ----------------------
    for epoch in range(n_epochs):
        model.train()
        train_accuracy_fn.reset()

        # Iterate over all the dataloader's mini-batches
        for i, batch in enumerate(train_loader):
            # Move the batch to the training device
            inputs = batch["image"].to(device)
            labels = batch["label"].to(device).long()

            # Flatten [B, C, H, W] -> [B, C*H*W] for non-convolutional models
            if flatten_input:
                inputs = inputs.flatten(1)

            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass
            loss.backward()

            # Update the parameters
            optimizer.step()

            # Log the loss value for this batch
            training_losses[epoch * num_batches + i] = loss.item()

            # Update the accuracy metric with this batch's predictions
            # (detached so the metric state doesn't hold onto the graph)
            train_accuracy_fn.update(outputs.detach(), labels)

            # Record the running accuracy over the epoch so far
            training_accuracies[epoch * num_batches + i] = train_accuracy_fn.compute().item()

        print(f'Epoch {epoch + 1} training complete')
        # ----------------------
        # validation loop
        # ----------------------
        model.eval()
        val_accuracy_fn.reset()
        with torch.no_grad():
            for batch in val_loader:
                inputs = batch["image"].to(device)
                labels = batch["label"].to(device).long()

                # Flatten [B, C, H, W] -> [B, C*H*W] for non-convolutional models
                if flatten_input:
                    inputs = inputs.flatten(1)

                outputs = model(inputs)
                val_accuracy_fn.update(outputs, labels)

        current_accuracy = val_accuracy_fn.compute().item()
        val_accuracies[epoch] = current_accuracy

        # Save the best model so far whenever validation accuracy improves
        if current_accuracy > best_accuracy:
            best_accuracy = current_accuracy
            torch.save(model.state_dict(), save_path)
            print(f'Epoch {epoch + 1} new best validation accuracy: {best_accuracy:.4f}')
        print(f'Epoch {epoch + 1} validation complete (accuracy: {current_accuracy:.4f})')
print(f"\nTraining finished. Best val accuracy: {best_accuracy:.4f}")
print(f"Best model weights saved to: {save_path}")
training_metrics = {
"losses": training_losses,
"accuracies": training_accuracies,
"val_accuracies": val_accuracies,
"best_accuracy": best_accuracy,
}
return training_metrics
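

# --------------------------------------------------------------------------
# ClearML reporting sketch. The commit this file came from mentions ClearML
# training metrics, but the function above only computes them; this helper is
# an assumption about how they *could* be pushed to ClearML via the standard
# Task.init() / Logger.report_scalar() API. The project and task names are
# placeholders, not part of the original code.
# --------------------------------------------------------------------------
def report_metrics_to_clearml(metrics: dict, project_name: str = "demo", task_name: str = "training"):
    from clearml import Task

    task = Task.init(project_name=project_name, task_name=task_name)
    logger = task.get_logger()

    # One scalar point per training batch
    for i, loss in enumerate(metrics["losses"]):
        logger.report_scalar(title="train", series="loss", value=float(loss), iteration=i)
    for i, acc in enumerate(metrics["accuracies"]):
        logger.report_scalar(title="train", series="accuracy", value=float(acc), iteration=i)

    # One scalar point per epoch
    for epoch, acc in enumerate(metrics["val_accuracies"]):
        logger.report_scalar(title="val", series="accuracy", value=float(acc), iteration=epoch)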
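

# --------------------------------------------------------------------------
# Usage sketch. The tiny MLP and synthetic dataset below are hypothetical
# stand-ins, not part of the original code; they only illustrate the batch
# format train_model expects ({"image": [B, C, H, W], "label": [B]}).
# --------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import Dataset

    class RandomImageDataset(Dataset):
        """Synthetic dataset emitting samples in the expected dict format."""

        def __init__(self, n_samples: int = 64, num_classes: int = 39):
            self.images = torch.randn(n_samples, 3, 32, 32)
            self.labels = torch.randint(0, num_classes, (n_samples,))

        def __len__(self):
            return len(self.images)

        def __getitem__(self, idx):
            return {"image": self.images[idx], "label": self.labels[idx]}

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # A minimal MLP, hence flatten_input=True below
    model = nn.Sequential(nn.Linear(3 * 32 * 32, 128), nn.ReLU(), nn.Linear(128, 39))

    train_loader = DataLoader(RandomImageDataset(), batch_size=8, shuffle=True)
    val_loader = DataLoader(RandomImageDataset(), batch_size=8)

    metrics = train_model(model, train_loader, val_loader, device, n_epochs=2, flatten_input=True)
    print(metrics["best_accuracy"])
    # report_metrics_to_clearml(metrics)  # uncomment if ClearML is configured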