| import torch | |
| from torch import nn | |
| import wandb | |
| from torch.amp import autocast, GradScaler | |
| def train_one_epoch(epoch_id,model,data_loader,loss_fn,optimizer,device,scaler): | |
| model.train() | |
| training_loss = 0.0 | |
| running_correct = 0 | |
| total_samples = 0 | |
| for batch,(X,y) in enumerate(data_loader): | |
| if not X.is_cuda: | |
| X,y = X.to(device,non_blocking=True),y.to(device,non_blocking=True) | |
| X = X.to(memory_format=torch.channels_last) | |
| optimizer.zero_grad(set_to_none=True) | |
| with autocast('cuda',dtype=torch.float16): | |
| pred = model(X) | |
| loss = loss_fn(pred,y) | |
| scaler.scale(loss).backward() | |
| scaler.step(optimizer) | |
| scaler.update() | |
| pred_ids = pred.argmax(1) | |
| running_correct += (pred_ids == y).type(torch.int).sum().item() | |
| total_samples += y.size(0) | |
| training_loss += loss.item() | |
| train_epoch_loss = training_loss / len(data_loader) | |
| train_epoch_acc = running_correct / total_samples | |
| return train_epoch_loss,train_epoch_acc | |
def evaluate(
    epoch_id,
    model,
    data_loader,
    loss_fn,
    device,
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
    max_bad_cases=20,
):
    """Evaluate ``model`` on ``data_loader`` and collect misclassified samples.

    Runs in eval mode under ``torch.no_grad()``, in full precision. Unlike
    ``train_one_epoch``, this uses no autocast and no channels_last
    conversion.

    Args:
        epoch_id: Epoch index. Accepted for interface compatibility; not used
            by this function.
        model: Module to evaluate. It is switched to eval mode.
        data_loader: Iterable of ``(X, y)`` batches. It does not need to
            define ``__len__``; batches are counted while iterating.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar tensor.
        device: Target device for both inputs and targets.
        mean: Per-channel means used to undo input normalization before
            logging an image (``img * std + mean``). The defaults are the
            previously hard-coded values (presumably ImageNet statistics).
            They should match the data pipeline's normalization, which is
            not visible here.
        std: Per-channel standard deviations; see ``mean``.
        max_bad_cases: Maximum number of misclassified samples to return.

    Returns:
        ``(mean_batch_loss, accuracy, bad_cases)``. ``bad_cases`` is a list
        of at most ``max_bad_cases`` ``wandb.Image`` objects, taken in loader
        order and captioned ``"Pred: <p> | True: <t>"``.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    model.eval()
    testing_loss = 0.0
    testing_correct = 0
    total_samples = 0
    num_batches = 0
    bad_cases = []
    # Built once rather than per logged sample; shaped (C, 1, 1) so they
    # broadcast over a single (C, H, W) image.
    mean_t = torch.tensor(mean).view(-1, 1, 1)
    std_t = torch.tensor(std).view(-1, 1, 1)
    with torch.no_grad():
        for X, y in data_loader:
            # .to() is a no-op for tensors already on `device`; move both
            # independently so targets can't be left on the wrong device.
            X = X.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            pred = model(X)
            loss = loss_fn(pred, y)
            testing_loss += loss.item()
            num_batches += 1
            pred_ids = pred.argmax(1)
            testing_correct += (pred_ids == y).sum().item()
            total_samples += y.size(0)
            remaining = max_bad_cases - len(bad_cases)
            if remaining > 0:
                wrong_idx = (pred_ids != y).nonzero(as_tuple=True)[0][:remaining]
                for i in wrong_idx.tolist():
                    img = torch.clamp(X[i].cpu() * std_t + mean_t, 0, 1)
                    bad_cases.append(
                        wandb.Image(img, caption=f"Pred: {pred_ids[i].item()} | True: {y[i].item()}")
                    )
    if num_batches == 0:
        raise ValueError("data_loader yielded no batches")
    val_epoch_loss = testing_loss / num_batches
    val_epoch_acc = testing_correct / total_samples
    return val_epoch_loss, val_epoch_acc, bad_cases