File size: 2,519 Bytes
e3469ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch 
from torch import nn
import wandb
from torch.amp import autocast, GradScaler

def train_one_epoch(epoch_id,model,data_loader,loss_fn,optimizer,device,scaler):
    model.train()
    training_loss = 0.0
    running_correct = 0
    total_samples = 0

    for batch,(X,y) in enumerate(data_loader):
        if not X.is_cuda:
            X,y = X.to(device,non_blocking=True),y.to(device,non_blocking=True)
        X = X.to(memory_format=torch.channels_last)
        optimizer.zero_grad(set_to_none=True)

        with autocast('cuda',dtype=torch.float16):
            pred = model(X)
            loss = loss_fn(pred,y)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        pred_ids = pred.argmax(1)
        running_correct += (pred_ids == y).type(torch.int).sum().item()
        total_samples += y.size(0)

        training_loss += loss.item()
        
    train_epoch_loss = training_loss / len(data_loader)
    train_epoch_acc = running_correct / total_samples

    return train_epoch_loss,train_epoch_acc

def evaluate(epoch_id, model, data_loader, loss_fn, device, max_bad_cases=20):
    """Evaluate the model and collect up to ``max_bad_cases`` misclassified images.

    Args:
        epoch_id: Epoch index (unused here; kept for caller compatibility).
        model: Model to evaluate; set to eval mode inside.
        data_loader: Iterable yielding ``(X, y)`` batches. Assumed non-empty.
        loss_fn: Criterion mapping ``(pred, y)`` to a scalar loss.
        device: Target device for batches that arrive on CPU.
        max_bad_cases: Cap on collected misclassification examples (default 20,
            matching the previous hard-coded limit).

    Returns:
        Tuple ``(mean_batch_loss, accuracy, bad_cases)`` where ``bad_cases`` is
        a list of ``wandb.Image`` objects captioned with predicted/true labels.
    """
    model.eval()
    testing_loss = 0.0
    testing_correct = 0
    total_samples = 0
    bad_cases = []

    # De-normalization constants hoisted out of the loop (were rebuilt per sample).
    # These are the ImageNet mean/std — assumes the loader normalized with the
    # same values; TODO confirm against the dataset transforms.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    with torch.no_grad():
        for X, y in data_loader:
            if not X.is_cuda:
                X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)

            pred = model(X)
            loss = loss_fn(pred, y)
            testing_loss += loss.item()

            pred_ids = pred.argmax(1)
            testing_correct += (pred_ids == y).sum().item()
            total_samples += y.size(0)

            # Log a handful of failures for qualitative inspection in W&B.
            if len(bad_cases) < max_bad_cases:
                wrong_idx = (pred_ids != y).nonzero()
                for idx in wrong_idx:
                    if len(bad_cases) >= max_bad_cases:
                        break
                    raw_img = X[idx.item()].cpu()

                    # Undo normalization so the logged image is human-viewable.
                    img = torch.clamp(raw_img * std + mean, 0, 1)

                    bad_cases.append(
                        wandb.Image(img, caption=f"Pred: {pred_ids[idx].item()} | True: {y[idx].item()}")
                    )

    val_epoch_loss = testing_loss / len(data_loader)
    val_epoch_acc = testing_correct / total_samples

    return val_epoch_loss, val_epoch_acc, bad_cases