#!/usr/bin/env python
#
# file: $ISIP_EXP/SOGMP/scripts/train.py
#
# revision history: xzt
#  20220824 (TE): first version
#
# usage:
#  python train.py mdir train_data val_data
#
# arguments:
#  mdir: the path where the output model checkpoint is stored
#  train_data: the directory of training data
#  val_data: the directory of validation data
#
# This script trains an S3-Net model.
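#
# example invocation (placeholder paths, for illustration only):
#  python train.py ./model/s3_net_model.pth ./data/train ./data/dev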
#------------------------------------------------------------------------------

# import pytorch modules
#
import torch
import torch.nn as nn
from torch.optim import Adam
from tqdm import tqdm
import torch.nn.functional as F

# visualize:
from tensorboardX import SummaryWriter
import numpy as np

# import the model and all of its variables/functions
#
from model import *
import lovasz_losses as L

# import modules
#
import sys
import os


#-----------------------------------------------------------------------------
#
# global variables are listed here
#
#-----------------------------------------------------------------------------

# general global values
#
model_dir = './model/s3_net_model.pth'  # default path for model storage
NUM_ARGS = 3
NUM_EPOCHS = 20000
BATCH_SIZE = 1024
LEARNING_RATE = "lr"
BETAS = "betas"
EPS = "eps"
WEIGHT_DECAY = "weight_decay"

# Constants
NUM_INPUT_CHANNELS = 3
NUM_OUTPUT_CHANNELS = 10  # 9 classes of semantic labels + 1 background
BETA = 0.01               # weight of the KL term in the beta-VAE objective

# for reproducibility, we seed the rng
#
set_seed(SEED1)       

# adjust_learning_rate: step-decay schedule applied once per epoch.
# note: with NUM_EPOCHS = 20000 only the base rate of 1e-4 is ever used;
# the later decay steps take effect only for much longer runs.
#
def adjust_learning_rate(optimizer, epoch):
    lr = 1e-4
    if epoch > 50000:
        lr = 2e-5
    if epoch > 480000:
        lr = lr * (0.1 ** (epoch // 110000))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
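
# illustrative values from the schedule above (worked by hand, not computed
# at runtime):
#   epoch 100     -> lr = 1e-4
#   epoch 60000   -> lr = 2e-5
#   epoch 500000  -> lr = 2e-5 * 0.1 ** (500000 // 110000) = 2e-9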


# train function:
def train(model, dataloader, dataset, device, optimizer, ce_criterion, lovasz_criterion, class_weights, epoch, epochs):
    # set model to training mode:
    model.train()
    # running sums for per-epoch averages:
    running_loss = 0.0
    kl_avg_loss = 0.0   # note: currently accumulates the Lovasz loss, not the KL term
    ce_avg_loss = 0.0

    counter = 0
    # number of batches per epoch (the DataLoader drops the last partial batch):
    num_batches = len(dataloader)
    for i, batch in tqdm(enumerate(dataloader), total=num_batches):
        counter += 1
        # collect the samples as a batch:
        scans = batch['scan']
        scans = scans.to(device)
        intensities = batch['intensity']
        intensities = intensities.to(device)
        angle_incidence = batch['angle_incidence']
        angle_incidence = angle_incidence.to(device)
        labels = batch['label']
        labels = labels.to(device)

        batch_size = scans.size(0)

        # set all gradients to 0:
        optimizer.zero_grad()

        # feed the batch to the network:
        semantic_scan, semantic_channels, kl_loss = model(scans, intensities, angle_incidence)
        # semantic CE loss, normalized by the batch size:
        ce_loss = ce_criterion(semantic_channels, labels.to(torch.long)).div(batch_size)
        # per-class Lovasz loss, weighted per class and summed:
        lovasz_loss, _ = lovasz_criterion(semantic_channels, labels.to(torch.long))
        lovasz_loss = lovasz_loss.mul(class_weights.to(device)).sum()
        # beta-VAE objective: reconstruction terms plus BETA-weighted KL term:
        loss = ce_loss + BETA * kl_loss + lovasz_loss
        # back-propagate; the ones_like gradient covers the case where
        # DataParallel returns one loss value per GPU:
        loss.backward(torch.ones_like(loss))
        optimizer.step()
        # with multiple GPUs, reduce the per-GPU losses to scalars:
        if torch.cuda.device_count() > 1:
            loss = loss.mean()
            ce_loss = ce_loss.mean()
            kl_loss = lovasz_loss.mean()

        running_loss += loss.item()
        # note: the 'kl' accumulator currently tracks the Lovasz loss:
        kl_avg_loss += lovasz_loss.item()
        # CE loss:
        ce_avg_loss += ce_loss.item()

        # display an informational message every 512 batches:
        if i % 512 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, CE_Loss: {:.4f}, Lovasz_Loss: {:.4f}'
                  .format(epoch, epochs, i + 1, num_batches, loss.item(), ce_loss.item(), lovasz_loss.item()))
    
    train_loss = running_loss / counter  
    train_kl_loss = kl_avg_loss / counter
    train_ce_loss = ce_avg_loss / counter 

    return train_loss, train_kl_loss, train_ce_loss

# validate function:
def validate(model, dataloader, dataset, device, ce_criterion, lovasz_criterion, class_weights):
    # set model to evaluation mode:
    model.eval()
    # running sums for per-epoch averages:
    running_loss = 0.0
    kl_avg_loss = 0.0   # note: currently accumulates the Lovasz loss, not the KL term
    ce_avg_loss = 0.0

    counter = 0
    # number of batches per epoch (the DataLoader drops the last partial batch):
    num_batches = len(dataloader)
    with torch.no_grad():
        for i, batch in tqdm(enumerate(dataloader), total=num_batches):
            counter += 1
            # collect the samples as a batch:
            scans = batch['scan']
            scans = scans.to(device)
            intensities = batch['intensity']
            intensities = intensities.to(device)
            angle_incidence = batch['angle_incidence']
            angle_incidence = angle_incidence.to(device)
            labels = batch['label']
            labels = labels.to(device)

            batch_size = scans.size(0)

            # feed the batch to the network:
            semantic_scan, semantic_channels, kl_loss = model(scans, intensities, angle_incidence)
            # semantic CE loss, normalized by the batch size:
            ce_loss = ce_criterion(semantic_channels, labels.to(torch.long)).div(batch_size)
            # per-class Lovasz loss, weighted per class and summed:
            lovasz_loss, _ = lovasz_criterion(semantic_channels, labels.to(torch.long))
            lovasz_loss = lovasz_loss.mul(class_weights.to(device)).sum()
            # beta-VAE objective: reconstruction terms plus BETA-weighted KL term:
            loss = ce_loss + BETA * kl_loss + lovasz_loss
            # with multiple GPUs, reduce the per-GPU losses to scalars:
            if torch.cuda.device_count() > 1:
                loss = loss.mean()
                ce_loss = ce_loss.mean()
                kl_loss = lovasz_loss.mean()

            running_loss += loss.item()
            # note: the 'kl' accumulator currently tracks the Lovasz loss:
            kl_avg_loss += lovasz_loss.item()
            # CE loss:
            ce_avg_loss += ce_loss.item()

    val_loss = running_loss / counter
    val_kl_loss = kl_avg_loss / counter 
    val_ce_loss = ce_avg_loss / counter

    return val_loss, val_kl_loss, val_ce_loss
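
#------------------------------------------------------------------------------
#
# helper sketch: a hypothetical derivation of the class weights used in main
#
#------------------------------------------------------------------------------

# the class_weights array in main() is labeled "median frequency balance";
# the function below is a minimal, illustrative sketch of that computation.
# it is not called anywhere and assumes you have, for each class, the pixel
# count of that class and the total pixel count of the scans containing it:
#
def median_frequency_weights(class_pixel_counts, class_total_counts):
    # freq_c = (pixels of class c) / (total pixels of scans containing c):
    freq = np.asarray(class_pixel_counts, dtype=np.float64) / \
           np.asarray(class_total_counts, dtype=np.float64)
    # weight_c = median(freq) / freq_c, so rare classes get larger weights:
    return np.median(freq) / freq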

#------------------------------------------------------------------------------
#
# the main program starts here
#
#------------------------------------------------------------------------------

# function: main
#
# arguments: none
#
# return: none
#
# This method is the main function.
#
def main(argv):
    # ensure we have the correct number of arguments:
    if len(argv) != NUM_ARGS:
        print("usage: python train.py [MDL_PATH] [TRAIN_PATH] [DEV_PATH]")
        sys.exit(-1)

    # define local variables:
    mdl_path = argv[0]
    pTrain = argv[1]
    pDev = argv[2]

    # get the output directory name:
    odir = os.path.dirname(mdl_path)

    # if odir doesn't exist, make it (guard against an empty dirname):
    if odir and not os.path.exists(odir):
        os.makedirs(odir)

    # set the device to use GPU if available:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print('...Start reading data...')
    ### training data ###
    # training set and training data loader
    train_dataset = VaeTestDataset(pTrain, 'train')
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=4, \
                                                   shuffle=True, drop_last=True, pin_memory=True)

    ### validation data ###
    # validation set and validation data loader
    dev_dataset = VaeTestDataset(pDev, 'dev')
    dev_dataloader = torch.utils.data.DataLoader(dev_dataset, batch_size=BATCH_SIZE, num_workers=2, \
                                                 shuffle=True, drop_last=True, pin_memory=True)

    # class weights from median frequency balancing (see the
    # median_frequency_weights() sketch above for a hypothetical derivation):
    class_weights = np.array([2.514399, 1.4917144, 0.51608694, 0.659483, 1.0900991, 1.6461798, 0.32852992, 1.5633508, 0.9236576, 0.10251398])
    class_weights = torch.Tensor(class_weights)
    print("class weights: ", class_weights)
    # Tensor.to() is not in-place, so keep the returned tensor:
    class_weights = class_weights.to(device)
    print('...Finish reading data...')

    # instantiate a model:
    model = S3Net(input_channels=NUM_INPUT_CHANNELS,
                 output_channels=NUM_OUTPUT_CHANNELS)
    # move the model to the selected device (GPU if available):
    model.to(device)

    # set the adam optimizer parameters:
    opt_params = { LEARNING_RATE: 0.001,
                   BETAS: (0.9, 0.999),
                   EPS: 1e-08,
                   WEIGHT_DECAY: 0.001 }
    # set the loss criterion and optimizer:
    ce_criterion = nn.CrossEntropyLoss(reduction='sum', weight=class_weights)
    ce_criterion.to(device)
    lovasz_criterion = L.LovaszSoftmax(reduction='sum', ignore_index=0)
    lovasz_criterion.to(device)
    # create an optimizer, and pass the model params to it:
    optimizer = Adam(model.parameters(), **opt_params)

    # get the number of epochs to train on:
    epochs = NUM_EPOCHS

    # if a trained model exists, continue training from its checkpoint:
    if os.path.exists(mdl_path):
        checkpoint = torch.load(mdl_path, map_location=device)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        print('Load epoch {} success'.format(start_epoch))
    else:
        start_epoch = 0
        # optionally warm-start from a pretrained SegNet checkpoint:
        #pre_path = "./model/model_segnet_weight.pth"
        #pretrained_model = torch.load(pre_path)
        #model.load_state_dict(pretrained_model['model'])
        print('No trained model found, starting training from scratch')

    # multiple GPUs:
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # DataParallel splits each input batch along dim 0 across the GPUs:
        model = nn.DataParallel(model)
    # move the (possibly wrapped) model to the selected device:
    model.to(device)

    # tensorboard writer:
    writer = SummaryWriter('runs')
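    # (inspect the logs with: tensorboard --logdir runs)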

    epoch_num = 0
    for epoch in range(start_epoch+1, epochs):
        # adjust learning rate:
        adjust_learning_rate(optimizer, epoch)
        ################################## Train #####################################
        # for each batch in increments of batch size
        #
        train_epoch_loss, train_kl_epoch_loss, train_ce_epoch_loss = train(
            model, train_dataloader, train_dataset, device, optimizer, ce_criterion, lovasz_criterion, class_weights, epoch, epochs
        )
        valid_epoch_loss, valid_kl_epoch_loss, valid_ce_epoch_loss = validate(
            model, dev_dataloader, dev_dataset, device, ce_criterion, lovasz_criterion, class_weights
        )
        
        # log the epoch losses:
        writer.add_scalar('training loss', train_epoch_loss, epoch)
        writer.add_scalar('training kl loss', train_kl_epoch_loss, epoch)
        writer.add_scalar('training ce loss', train_ce_epoch_loss, epoch)

        writer.add_scalar('validation loss', valid_epoch_loss, epoch)
        writer.add_scalar('validation kl loss', valid_kl_epoch_loss, epoch)
        writer.add_scalar('validation ce loss', valid_ce_epoch_loss, epoch)

        print('Train set: Average loss: {:.4f}'.format(train_epoch_loss))
        print('Validation set: Average loss: {:.4f}'.format(valid_epoch_loss))
        
        # save a periodic checkpoint every 2000 epochs:
        if epoch % 2000 == 0:
            if torch.cuda.device_count() > 1:  # multiple GPUs: unwrap DataParallel
                state = {'model': model.module.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
            else:
                state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
            # write the checkpoint into the output directory derived from mdl_path:
            path = os.path.join(odir, 'model' + str(epoch) + '.pth')
            torch.save(state, path)

        epoch_num = epoch

    # save the final model:
    if torch.cuda.device_count() > 1:  # multiple GPUs: unwrap DataParallel
        state = {'model': model.module.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch_num}
    else:
        state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch_num}
    torch.save(state, mdl_path)

    # exit gracefully
    #

    return True
#
# end of function


# begin gracefully
#
if __name__ == '__main__':
    main(sys.argv[1:])
#
# end of file