File size: 2,541 Bytes
906fcb9
 
 
caf6ee7
1baebae
906fcb9
 
 
1baebae
906fcb9
 
caf6ee7
 
1baebae
caf6ee7
906fcb9
 
 
 
 
 
 
 
 
 
caf6ee7
 
906fcb9
 
 
caf6ee7
 
906fcb9
 
 
 
1baebae
906fcb9
 
1baebae
906fcb9
 
caf6ee7
 
906fcb9
caf6ee7
906fcb9
 
 
 
 
 
 
caf6ee7
 
906fcb9
 
 
caf6ee7
 
906fcb9
1baebae
906fcb9
 
 
1baebae
 
 
 
 
 
 
906fcb9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch
import torch.nn as nn
from monai.metrics import Cumulative, CumulativeAverage
from sklearn.metrics import confusion_matrix, roc_auc_score


def train_epoch(cspca_model, loader, optimizer, epoch, args):
    """Run one training epoch and return the mean loss and ROC AUC.

    Args:
        cspca_model: Model to train. Its output is squeezed along dim 1 before
            the loss (presumably shape ``(B, 1)`` -> ``(B,)``; confirm against
            the model). Because ``nn.BCELoss`` is used, outputs must already be
            probabilities in ``[0, 1]``.
        loader: Iterable of batch dicts with ``"image"`` and ``"label"`` keys.
        optimizer: Optimizer that updates ``cspca_model``'s parameters.
        epoch: Current epoch index. Not used here; kept so the signature
            mirrors ``val_epoch``.
        args: Object providing ``device``, the device batches are moved to.

    Returns:
        Tuple ``(loss_epoch, auc_epoch)``:

        - ``loss_epoch``: the average of the per-batch losses.
        - ``auc_epoch``: the ROC AUC over all samples seen in the epoch. It is
          ``nan`` when the epoch's targets contain a single class, because
          ROC AUC is undefined in that case.
    """
    cspca_model.train()
    criterion = nn.BCELoss()
    run_loss = CumulativeAverage()
    targets_cumulative = Cumulative()
    preds_cumulative = Cumulative()

    for batch_data in loader:
        data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
        target = batch_data["label"].as_subclass(torch.Tensor).to(args.device)

        optimizer.zero_grad()
        output = cspca_model(data).squeeze(1)
        # BCELoss requires a floating target with the same dtype as the
        # prediction; integer labels would otherwise raise a dtype error.
        target = target.to(output.dtype)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        targets_cumulative.extend(target.detach().cpu())
        preds_cumulative.extend(output.detach().cpu())
        run_loss.append(loss.item())

    loss_epoch = run_loss.aggregate()
    target_list = targets_cumulative.get_buffer().cpu().numpy()
    pred_list = preds_cumulative.get_buffer().cpu().numpy()
    # roc_auc_score raises ValueError when y_true holds only one class.
    if len(set(target_list.tolist())) < 2:
        auc_epoch = float("nan")
    else:
        auc_epoch = roc_auc_score(target_list, pred_list)

    return loss_epoch, auc_epoch


def val_epoch(cspca_model, loader, epoch, args, threshold=0.5):
    """Evaluate the model on ``loader`` and return a dict of epoch metrics.

    Args:
        cspca_model: Model to evaluate. Its output is squeezed along dim 1
            before the loss (presumably shape ``(B, 1)`` -> ``(B,)``; confirm
            against the model). Because ``nn.BCELoss`` is used, outputs must
            already be probabilities in ``[0, 1]``.
        loader: Iterable of batch dicts with ``"image"`` and ``"label"`` keys.
            Labels are assumed to be binary 0/1. Other values are ignored by
            the confusion matrix.
        epoch: Epoch index, echoed back under the ``"epoch"`` key.
        args: Object providing ``device``, the device batches are moved to.
        threshold: Probability at or above which a prediction counts as
            positive when computing sensitivity and specificity.

    Returns:
        Dict with the following keys:

        - ``"epoch"``: the ``epoch`` argument, unchanged.
        - ``"loss"``: the average of the per-batch losses.
        - ``"auc"``: the ROC AUC. It is ``nan`` when the targets contain a
          single class.
        - ``"sensitivity"``: ``tp / (tp + fn)``. It is ``nan`` when there
          are no positive targets.
        - ``"specificity"``: ``tn / (tn + fp)``. It is ``nan`` when there
          are no negative targets.
    """
    cspca_model.eval()
    criterion = nn.BCELoss()
    run_loss = CumulativeAverage()
    targets_cumulative = Cumulative()
    preds_cumulative = Cumulative()
    with torch.no_grad():
        for batch_data in loader:
            data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
            target = batch_data["label"].as_subclass(torch.Tensor).to(args.device)

            output = cspca_model(data).squeeze(1)
            # BCELoss requires a floating target with the same dtype as the
            # prediction; integer labels would otherwise raise a dtype error.
            target = target.to(output.dtype)
            loss = criterion(output, target)

            # No detach() needed: gradients are disabled by no_grad().
            targets_cumulative.extend(target.cpu())
            preds_cumulative.extend(output.cpu())
            run_loss.append(loss.item())

    loss_epoch = run_loss.aggregate()
    target_list = targets_cumulative.get_buffer().cpu().numpy()
    pred_list = preds_cumulative.get_buffer().cpu().numpy()

    # roc_auc_score raises ValueError when y_true holds only one class.
    if len(set(target_list.tolist())) < 2:
        auc_epoch = float("nan")
    else:
        auc_epoch = roc_auc_score(target_list, pred_list)

    y_pred_categoric = (pred_list >= threshold).astype(int)
    # labels=[0, 1] pins the matrix to 2x2, so ravel() always yields four
    # counts even when a class is absent from both targets and predictions.
    tn, fp, fn, tp = confusion_matrix(
        target_list, y_pred_categoric, labels=[0, 1]
    ).ravel()
    sens_epoch = tp / (tp + fn) if (tp + fn) > 0 else float("nan")
    spec_epoch = tn / (tn + fp) if (tn + fp) > 0 else float("nan")

    val_epoch_metric = {
        "epoch": epoch,
        "loss": loss_epoch,
        "auc": auc_epoch,
        "sensitivity": sens_epoch,
        "specificity": spec_epoch,
    }
    return val_epoch_metric