File size: 2,624 Bytes
2ca2f68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#import generators
import monai
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import sys

from pathlib import Path

# Repository root: two directory levels above the *current working directory*
# (not this file), so the result depends on where the script is launched from.
ROOT_DIR = str(Path.cwd().parent.parent.absolute())
# Prepend the project helper directories to the module search path so the
# `utils` and `losses` imports below resolve. The slice assignment leaves
# loss_function first and utils second, exactly as inserting utils and then
# loss_function at index 0 would.
sys.path[0:0] = [
    os.path.join(ROOT_DIR, 'deepatlas/loss_function'),
    os.path.join(ROOT_DIR, 'deepatlas/utils'),
]
from utils import (
    preview_image, preview_3D_vector_field, preview_3D_deformation,
    jacobian_determinant
)
from losses import (
    warp_func, warp_nearest_func, lncc_loss_func, dice_loss_func, reg_losses, dice_loss_func2
)

def _seg_epoch_mean_loss(seg_net, dataloader, device, loss_fn, optimizer=None):
    """Run one pass over ``dataloader`` and return the mean per-batch loss.

    With an ``optimizer`` the network is put in training mode and updated after
    every batch. Without one it is put in eval mode and run with gradients
    disabled. Returns NaN when the loader yields no batches. Batch tensors are
    locals of this helper, so they are released when it returns.
    """
    training = optimizer is not None
    seg_net.train(training)  # train(False) is equivalent to eval()
    losses = []
    with torch.set_grad_enabled(training):
        for batch in dataloader:
            imgs = batch['img'].to(device)
            true_segs = batch['seg'].to(device)

            if training:
                optimizer.zero_grad()
            predicted_segs = seg_net(imgs)
            loss = loss_fn(predicted_segs, true_segs)
            if training:
                loss.backward()
                optimizer.step()

            losses.append(loss.item())
    # Avoid numpy's "mean of empty slice" warning for an empty loader.
    return float(np.mean(losses)) if losses else float('nan')


def train_seg(dataloader_train_seg,
              dataloader_valid_seg,
              device,
              seg_net,
              lr_seg,
              max_epoch,
              val_step,
              result_seg_path,
              loss_fn=None,
              ):
    """Train a segmentation network with Adam and save its weights.

    Args:
        dataloader_train_seg: iterable of training batches; each batch is a
            mapping whose ``'img'`` and ``'seg'`` entries are moved to
            ``device`` with ``.to(device)``.
        dataloader_valid_seg: iterable of validation batches with the same keys.
        device: device the model and batches are moved to.
        seg_net: the ``torch.nn.Module`` to train (updated in place).
        lr_seg: Adam learning rate.
        max_epoch: number of training epochs.
        val_step: validate on epochs 0, val_step, 2*val_step, ...; a falsy
            value (0 or None) disables validation.
        result_seg_path: directory that receives ``seg_net_best.pth``; it is
            created if missing.
        loss_fn: optional callable ``loss_fn(predicted, target)`` returning a
            scalar tensor. Defaults to ``dice_loss_func2()``.

    Returns:
        ``(training_losses, validation_losses)``: lists of
        ``[epoch_number, mean_loss]`` pairs.
    """
    seg_net.to(device)
    optimizer = torch.optim.Adam(seg_net.parameters(), lr_seg)
    if loss_fn is None:
        loss_fn = dice_loss_func2()

    training_losses = []
    validation_losses = []
    for epoch_number in range(max_epoch):
        print(f"Epoch {epoch_number+1}/{max_epoch}:")

        training_loss = _seg_epoch_mean_loss(
            seg_net, dataloader_train_seg, device, loss_fn, optimizer)
        print(f"\ttraining loss: {training_loss}")
        training_losses.append([epoch_number, training_loss])

        # Guard against val_step == 0, which would raise ZeroDivisionError.
        if val_step and epoch_number % val_step == 0:
            validation_loss = _seg_epoch_mean_loss(
                seg_net, dataloader_valid_seg, device, loss_fn)
            print(f"\tvalidation loss: {validation_loss}")
            validation_losses.append([epoch_number, validation_loss])

    # Return cached, now-unused allocator blocks to the GPU (no-op without CUDA).
    torch.cuda.empty_cache()
    os.makedirs(result_seg_path, exist_ok=True)
    # NOTE: despite the filename, these are the weights after the final epoch,
    # not the epoch with the lowest validation loss.
    torch.save(seg_net.state_dict(), os.path.join(result_seg_path, 'seg_net_best.pth'))
    return training_losses, validation_losses