"""Per-epoch training and evaluation loops for an image classifier."""

import torch
from torch import nn
from torch.amp import autocast, GradScaler
import wandb

# Per-channel statistics used to undo input normalization before logging
# misclassified images. These match the commonly used ImageNet values.
# NOTE(review): assumes the data pipeline normalized with these — confirm.
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)


def train_one_epoch(epoch_id, model, data_loader, loss_fn, optimizer, device, scaler):
    """Run one mixed-precision training pass over ``data_loader``.

    Args:
        epoch_id: Index of the current epoch. Not used by this function; kept
            so the signature stays compatible with existing callers.
        model: Module to train; switched to training mode.
        data_loader: Iterable yielding ``(X, y)`` batches.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar loss tensor.
        optimizer: Optimizer stepped through ``scaler``.
        device: Target device for batches that are not already on CUDA.
        scaler: Gradient scaler used for loss scaling and the optimizer step.

    Returns:
        ``(mean_batch_loss, accuracy)``. The loss is the unweighted mean of the
        per-batch losses; accuracy is correct predictions over samples seen.
        Both are ``0.0`` when the loader yields no batches.
    """
    model.train()
    training_loss = 0.0
    running_correct = 0
    total_samples = 0
    num_batches = 0

    for X, y in data_loader:
        if not X.is_cuda:
            X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)
            # NOTE(review): the original formatting was lost, so it is unclear
            # whether this conversion was meant to run for every batch or only
            # for batches moved here. Memory format does not change values.
            X = X.to(memory_format=torch.channels_last)

        optimizer.zero_grad(set_to_none=True)
        with autocast('cuda', dtype=torch.float16):
            pred = model(X)
            loss = loss_fn(pred, y)

        # Scale the loss before backward so small fp16 gradients do not
        # underflow; the scaler unscales them again before the optimizer step.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        pred_ids = pred.argmax(1)
        running_correct += (pred_ids == y).sum().item()
        total_samples += y.size(0)
        training_loss += loss.item()
        num_batches += 1

    # Count the batches actually seen instead of calling len(data_loader):
    # this also works for loaders without __len__, and the guards avoid a
    # ZeroDivisionError on an empty loader.
    train_epoch_loss = training_loss / num_batches if num_batches else 0.0
    train_epoch_acc = running_correct / total_samples if total_samples else 0.0
    return train_epoch_loss, train_epoch_acc


def evaluate(epoch_id, model, data_loader, loss_fn, device, max_bad_cases=20):
    """Evaluate ``model`` and collect a few misclassified images for W&B.

    Args:
        epoch_id: Index of the current epoch. Not used by this function; kept
            so the signature stays compatible with existing callers.
        model: Module to evaluate; switched to eval mode.
        data_loader: Iterable yielding ``(X, y)`` batches.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar loss tensor.
        device: Target device for batches that are not already on CUDA.
        max_bad_cases: Maximum number of misclassified samples to collect.

    Returns:
        ``(mean_batch_loss, accuracy, bad_cases)`` where ``bad_cases`` is a
        list of at most ``max_bad_cases`` ``wandb.Image`` objects captioned
        with the predicted and true class ids. Loss and accuracy are ``0.0``
        when the loader yields no batches.
    """
    model.eval()
    testing_loss = 0.0
    testing_correct = 0
    total_samples = 0
    num_batches = 0
    bad_cases = []

    # Built once per call rather than once per misclassified sample.
    mean = torch.tensor(_NORM_MEAN).view(3, 1, 1)
    std = torch.tensor(_NORM_STD).view(3, 1, 1)

    with torch.no_grad():
        for X, y in data_loader:
            if not X.is_cuda:
                X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)

            pred = model(X)
            loss = loss_fn(pred, y)
            testing_loss += loss.item()
            num_batches += 1

            pred_ids = pred.argmax(1)
            testing_correct += (pred_ids == y).sum().item()
            total_samples += y.size(0)

            remaining = max_bad_cases - len(bad_cases)
            if remaining <= 0:
                continue

            # Take only as many misclassified samples as are still needed and
            # move them to the CPU in one transfer per batch.
            wrong_idx = (pred_ids != y).nonzero(as_tuple=True)[0][:remaining]
            if wrong_idx.numel() == 0:
                continue
            wrong_preds = pred_ids[wrong_idx].tolist()
            wrong_labels = y[wrong_idx].tolist()
            wrong_imgs = X[wrong_idx].cpu()

            for raw_img, pred_id, true_id in zip(wrong_imgs, wrong_preds, wrong_labels):
                # Undo normalization and clamp to [0, 1] for display.
                img = torch.clamp(raw_img * std + mean, 0, 1)
                bad_cases.append(
                    wandb.Image(img, caption=f"Pred: {pred_id} | True: {true_id}")
                )

    val_epoch_loss = testing_loss / num_batches if num_batches else 0.0
    val_epoch_acc = testing_correct / total_samples if total_samples else 0.0
    return val_epoch_loss, val_epoch_acc, bad_cases