timm / src /engine.py
YiMeng-SYSU's picture
Initial commit of timm project files
8bc22ab verified
import torch
from torch import nn
import wandb
from torch.amp import autocast, GradScaler
def train_one_epoch(epoch_id, model, data_loader, loss_fn, optimizer, device, scaler):
    """Run one mixed-precision training pass over ``data_loader``.

    Args:
        epoch_id: Index of the current epoch. Not used inside this function;
            kept so the call signature stays stable for callers.
        model: Module to train; switched to training mode before iterating.
        data_loader: Iterable yielding ``(X, y)`` batches. It does not need to
            support ``len()``; batches are counted as they are consumed.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar loss tensor.
        optimizer: Optimizer stepped once per batch through ``scaler``.
        device: Target device for batches that are not already on CUDA.
        scaler: Gradient scaler exposing ``scale``, ``step`` and ``update``.

    Returns:
        Tuple ``(mean_batch_loss, accuracy)``. The loss is the mean of the
        per-batch loss values; the accuracy is correct predictions divided by
        the number of samples seen. Both are ``0.0`` when the loader yields
        nothing, instead of raising ``ZeroDivisionError``.
    """
    model.train()
    training_loss = 0.0
    running_correct = 0
    total_samples = 0
    num_batches = 0

    for X, y in data_loader:
        # Batches that already live on CUDA are left where they are.
        if not X.is_cuda:
            X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)
        # channels_last requires a 4-D input; the conversion is idempotent.
        X = X.to(memory_format=torch.channels_last)

        optimizer.zero_grad(set_to_none=True)
        with autocast('cuda', dtype=torch.float16):
            pred = model(X)
            loss = loss_fn(pred, y)

        # Scale the loss before backward; the scaler drives the optimizer
        # step and then refreshes its scale factor.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        pred_ids = pred.argmax(1)
        running_correct += (pred_ids == y).sum().item()
        total_samples += y.size(0)
        training_loss += loss.item()
        num_batches += 1

    # Guard the divisions so an empty loader yields zeros rather than a crash.
    train_epoch_loss = training_loss / num_batches if num_batches else 0.0
    train_epoch_acc = running_correct / total_samples if total_samples else 0.0
    return train_epoch_loss, train_epoch_acc
def evaluate(epoch_id, model, data_loader, loss_fn, device, *,
             max_bad_cases=20,
             mean=(0.485, 0.456, 0.406),
             std=(0.229, 0.224, 0.225)):
    """Evaluate ``model`` on ``data_loader`` and collect misclassified samples.

    Args:
        epoch_id: Index of the current epoch. Not used inside this function;
            kept so the call signature stays stable for callers.
        model: Module to evaluate; switched to eval mode before iterating.
        data_loader: Iterable yielding ``(X, y)`` batches. It does not need to
            support ``len()``; batches are counted as they are consumed.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar loss tensor.
        device: Target device for batches that are not already on CUDA.
        max_bad_cases: Upper bound on the number of misclassified samples
            turned into ``wandb.Image`` objects (default 20).
        mean: Per-channel mean used to undo input normalisation before
            logging (defaults to the values previously hard-coded here).
        std: Per-channel standard deviation used to undo normalisation.

    Returns:
        Tuple ``(mean_batch_loss, accuracy, bad_cases)``. ``bad_cases`` is a
        list of at most ``max_bad_cases`` ``wandb.Image`` objects captioned
        with the predicted and true labels. Loss and accuracy are ``0.0``
        when the loader yields nothing, instead of raising
        ``ZeroDivisionError``.
    """
    model.eval()
    testing_loss = 0.0
    testing_correct = 0
    total_samples = 0
    num_batches = 0
    bad_cases = []

    # Build the de-normalisation constants once instead of per bad sample.
    mean_t = torch.tensor(mean).view(-1, 1, 1)
    std_t = torch.tensor(std).view(-1, 1, 1)

    with torch.no_grad():
        for X, y in data_loader:
            # Batches that already live on CUDA are left where they are.
            if not X.is_cuda:
                X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)

            pred = model(X)
            loss = loss_fn(pred, y)
            testing_loss += loss.item()

            pred_ids = pred.argmax(1)
            testing_correct += (pred_ids == y).sum().item()
            total_samples += y.size(0)
            num_batches += 1

            remaining = max_bad_cases - len(bad_cases)
            if remaining > 0:
                # Only the first `remaining` misclassified indices are needed,
                # so the rest of the batch is not walked once the cap is hit.
                wrong_idx = (pred_ids != y).nonzero().flatten()[:remaining].tolist()
                for i in wrong_idx:
                    # Undo normalisation on CPU and clamp to the [0, 1] range.
                    img = torch.clamp(X[i].cpu() * std_t + mean_t, 0, 1)
                    bad_cases.append(
                        wandb.Image(
                            img,
                            caption=f"Pred: {pred_ids[i].item()} | True: {y[i].item()}",
                        )
                    )

    # Guard the divisions so an empty loader yields zeros rather than a crash.
    val_epoch_loss = testing_loss / num_batches if num_batches else 0.0
    val_epoch_acc = testing_correct / total_samples if total_samples else 0.0
    return val_epoch_loss, val_epoch_acc, bad_cases