import torch
import torch.nn as nn
from monai.metrics import Cumulative, CumulativeAverage
from sklearn.metrics import confusion_matrix, roc_auc_score
def train_epoch(cspca_model, loader, optimizer, epoch, args):
    """Run one training epoch of the csPCa classifier.

    Args:
        cspca_model: binary classifier whose forward output is a sigmoid
            probability of shape (batch, 1) — required by ``nn.BCELoss``.
        loader: dataloader yielding dicts with "image" and "label" entries
            (MONAI MetaTensors; downcast to plain tensors before .to()).
        optimizer: optimizer stepping ``cspca_model``'s parameters.
        epoch: current epoch index (unused here; kept for caller
            compatibility with ``val_epoch``).
        args: namespace providing ``args.device``.

    Returns:
        tuple: (mean training loss over the epoch, training ROC-AUC).
    """
    cspca_model.train()
    criterion = nn.BCELoss()
    run_loss = CumulativeAverage()
    targets_cumulative = Cumulative()
    preds_cumulative = Cumulative()
    for batch_data in loader:
        # as_subclass strips MONAI metadata so plain-tensor ops apply.
        data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
        target = batch_data["label"].as_subclass(torch.Tensor).to(args.device)
        optimizer.zero_grad()
        output = cspca_model(data).squeeze(1)  # (batch, 1) -> (batch,)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        # Accumulate on CPU so GPU memory is not held across the epoch.
        targets_cumulative.extend(target.detach().cpu())
        preds_cumulative.extend(output.detach().cpu())
        run_loss.append(loss.item())
    loss_epoch = run_loss.aggregate()
    target_list = targets_cumulative.get_buffer().cpu().numpy()
    pred_list = preds_cumulative.get_buffer().cpu().numpy()
    # Note: raises ValueError if only one class is present in targets.
    auc_epoch = roc_auc_score(target_list, pred_list)
    return loss_epoch, auc_epoch
def val_epoch(cspca_model, loader, epoch, args):
    """Evaluate the csPCa classifier for one validation epoch.

    Args:
        cspca_model: binary classifier whose forward output is a sigmoid
            probability of shape (batch, 1) — required by ``nn.BCELoss``.
        loader: dataloader yielding dicts with "image" and "label" entries.
        epoch: current epoch index, echoed back in the metrics dict.
        args: namespace providing ``args.device``.

    Returns:
        dict: {"epoch", "loss", "auc", "sensitivity", "specificity"} where
        sensitivity/specificity use a fixed 0.5 probability threshold.
    """
    cspca_model.eval()
    criterion = nn.BCELoss()
    run_loss = CumulativeAverage()
    targets_cumulative = Cumulative()
    preds_cumulative = Cumulative()
    with torch.no_grad():
        for batch_data in loader:
            data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
            target = batch_data["label"].as_subclass(torch.Tensor).to(args.device)
            output = cspca_model(data).squeeze(1)  # (batch, 1) -> (batch,)
            loss = criterion(output, target)
            targets_cumulative.extend(target.detach().cpu())
            preds_cumulative.extend(output.detach().cpu())
            run_loss.append(loss.item())
    loss_epoch = run_loss.aggregate()
    target_list = targets_cumulative.get_buffer().cpu().numpy()
    pred_list = preds_cumulative.get_buffer().cpu().numpy()
    # Note: raises ValueError if only one class is present in targets.
    auc_epoch = roc_auc_score(target_list, pred_list)
    y_pred_categoric = pred_list >= 0.5
    # labels=[0, 1] forces a full 2x2 matrix so the 4-way unpack is safe
    # even when one class is absent from the thresholded predictions.
    tn, fp, fn, tp = confusion_matrix(
        target_list, y_pred_categoric, labels=[0, 1]
    ).ravel()
    sens_epoch = tp / (tp + fn)
    spec_epoch = tn / (tn + fp)
    val_epoch_metric = {
        "epoch": epoch,
        "loss": loss_epoch,
        "auc": auc_epoch,
        "sensitivity": sens_epoch,
        "specificity": spec_epoch,
    }
    return val_epoch_metric