#!/usr/bin/env python
#
# file: $ISIP_EXP/SOGMP/scripts/train.py
#
# revision history: xzt
#  20220824 (TE): first version
#
# usage:
#  python train.py mdir train_data val_data
#
# arguments:
#  mdir: the directory where the output model is stored
#  train_data: the directory of training data
#  val_data: the directory of validation data
#
# This script trains an S3-Net model
#------------------------------------------------------------------------------

# import pytorch modules
#
import torch
import torch.nn as nn
from torch.optim import Adam
from tqdm import tqdm
import torch.nn.functional as F

# visualize:
from tensorboardX import SummaryWriter
import numpy as np

# import the model and all of its variables/functions
#
from model import *
import lovasz_losses as L

# import modules
#
import sys
import os

#-----------------------------------------------------------------------------
#
# global variables are listed here
#
#-----------------------------------------------------------------------------

# general global values
#
model_dir = './model/s3_net_model.pth'   # the default path of model storage
NUM_ARGS = 3
NUM_EPOCHS = 20000
BATCH_SIZE = 1024

# keys of the Adam optimizer parameter dictionary:
LEARNING_RATE = "lr"
BETAS = "betas"
EPS = "eps"
WEIGHT_DECAY = "weight_decay"

# constants:
NUM_INPUT_CHANNELS = 3
NUM_OUTPUT_CHANNELS = 10   # 9 classes of semantic labels + 1 background
BETA = 0.01                # weight of the KL term (beta-VAE)

# for reproducibility, we seed the rng
#
set_seed(SEED1)

# adjust_learning_rate: step the learning rate down as training progresses
#
def adjust_learning_rate(optimizer, epoch):
    lr = 1e-4
    if epoch > 50000:
        lr = 2e-5
    if epoch > 480000:
        # lr = 5e-8
        lr = lr * (0.1 ** (epoch // 110000))
    # if epoch > 8300:
    #     lr = 1e-9
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

# train function: run one epoch of training and return the average losses
#
def train(model, dataloader, dataset, device, optimizer, ce_criterion,
          lovasz_criterion, class_weights, epoch, epochs):

    # set model to training mode:
    model.train()

    # accumulators for the batch losses:
    running_loss = 0.0
    lovasz_avg_loss = 0.0
    ce_avg_loss = 0.0
    counter = 0

    # get the number of batches (floor of train_data/batch_size,
    # since the loader drops the last incomplete batch):
    num_batches = int(len(dataset) / dataloader.batch_size)

    for i, batch in tqdm(enumerate(dataloader), total=num_batches):
        counter += 1

        # collect the samples as a batch:
        scans = batch['scan'].to(device)
        intensities = batch['intensity'].to(device)
        angle_incidence = batch['angle_incidence'].to(device)
        labels = batch['label'].to(device)
        batch_size = scans.size(0)

        # set all gradients to 0:
        optimizer.zero_grad()

        # feed the batch to the network:
        semantic_scan, semantic_channels, kl_loss = model(scans, intensities, angle_incidence)

        # calculate the semantic ce loss:
        ce_loss = ce_criterion(semantic_channels, labels.to(torch.long)).div(batch_size)

        # calculate the class-weighted lovasz loss:
        lovasz_loss, _ = lovasz_criterion(semantic_channels, labels.to(torch.long))
        lovasz_loss = lovasz_loss.mul(class_weights.to(device)).sum()

        # beta-vae objective:
        loss = ce_loss + BETA * kl_loss + lovasz_loss

        # perform back propagation:
        loss.backward()
        optimizer.step()

        # multiple GPUs: reduce the per-device losses to scalars:
        if torch.cuda.device_count() > 1:
            loss = loss.mean()
            ce_loss = ce_loss.mean()
            lovasz_loss = lovasz_loss.mean()

        # accumulate the losses:
        running_loss += loss.item()
        lovasz_avg_loss += lovasz_loss.item()
        ce_avg_loss += ce_loss.item()

        # display an informational message:
        if i % 512 == 0:
            print('Epoch [{}/{}], Step[{}/{}], Loss: {:.4f}, CE_Loss: {:.4f}, Lovasz_Loss: {:.4f}'
                  .format(epoch, epochs, i + 1, num_batches, loss.item(), ce_loss.item(), lovasz_loss.item()))

    train_loss = running_loss / counter
    train_lovasz_loss = lovasz_avg_loss / counter
    train_ce_loss = ce_avg_loss / counter
    return train_loss, train_lovasz_loss, train_ce_loss
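# loss composition sketch:
#
# train() optimizes a beta-VAE style objective:
#
#   loss = CE(logits, labels)/batch_size + BETA * KL + sum_c(w_c * lovasz_c)
#
# the helper below is a minimal, self-contained sketch of how those three
# terms combine on random tensors. it is illustrative only and never called;
# the shapes (B, C, H, W), the constant KL placeholder, and the stand-in for
# the per-class lovasz term are assumptions, not values produced by S3Net.
#
def _loss_composition_sketch():
    B, C, H, W = 4, NUM_OUTPUT_CHANNELS, 32, 32
    logits = torch.randn(B, C, H, W)           # fake network output
    labels = torch.randint(0, C, (B, H, W))    # fake semantic labels
    weights = torch.ones(C)                    # uniform class weights

    ce = nn.CrossEntropyLoss(reduction='sum', weight=weights)(logits, labels).div(B)
    kl = torch.tensor(0.5)                     # placeholder for the model's KL divergence
    # stand-in for the per-class lovasz vector returned by lovasz_losses:
    per_class = F.softmax(logits, dim=1).mean(dim=(0, 2, 3))
    lovasz = per_class.mul(weights).sum()
    return ce + BETA * kl + lovasz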
# validate function: run one epoch of validation and return the average losses
#
def validate(model, dataloader, dataset, device, ce_criterion,
             lovasz_criterion, class_weights):

    # set model to evaluation mode:
    model.eval()

    # accumulators for the batch losses:
    running_loss = 0.0
    lovasz_avg_loss = 0.0
    ce_avg_loss = 0.0
    counter = 0

    # get the number of batches (floor of dev_data/batch_size,
    # since the loader drops the last incomplete batch):
    num_batches = int(len(dataset) / dataloader.batch_size)

    with torch.no_grad():
        for i, batch in tqdm(enumerate(dataloader), total=num_batches):
            counter += 1

            # collect the samples as a batch:
            scans = batch['scan'].to(device)
            intensities = batch['intensity'].to(device)
            angle_incidence = batch['angle_incidence'].to(device)
            labels = batch['label'].to(device)
            batch_size = scans.size(0)

            # feed the batch to the network:
            semantic_scan, semantic_channels, kl_loss = model(scans, intensities, angle_incidence)

            # calculate the semantic ce loss:
            ce_loss = ce_criterion(semantic_channels, labels.to(torch.long)).div(batch_size)

            # calculate the class-weighted lovasz loss:
            lovasz_loss, _ = lovasz_criterion(semantic_channels, labels.to(torch.long))
            lovasz_loss = lovasz_loss.mul(class_weights.to(device)).sum()

            # beta-vae objective:
            loss = ce_loss + BETA * kl_loss + lovasz_loss

            # multiple GPUs: reduce the per-device losses to scalars:
            if torch.cuda.device_count() > 1:
                loss = loss.mean()
                ce_loss = ce_loss.mean()
                lovasz_loss = lovasz_loss.mean()

            # accumulate the losses:
            running_loss += loss.item()
            lovasz_avg_loss += lovasz_loss.item()
            ce_avg_loss += ce_loss.item()

    val_loss = running_loss / counter
    val_lovasz_loss = lovasz_avg_loss / counter
    val_ce_loss = ce_avg_loss / counter
    return val_loss, val_lovasz_loss, val_ce_loss

#------------------------------------------------------------------------------
#
# the main program starts here
#
#------------------------------------------------------------------------------

# function: main
#
# arguments: none
#
# return: none
#
# This method is the main function.
#
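# class weight sketch:
#
# the class weights hard-coded in main() are labeled "median frequency
# balance". a common formulation of that scheme sets w_c = median(f) / f_c,
# where f_c is the fraction of labeled pixels belonging to class c. the
# helper below is an illustrative sketch of that computation; it is an
# assumption about how the hard-coded numbers were derived, and it is never
# called anywhere in this script.
#
def _median_frequency_weights(label_counts):
    # label_counts: 1-D numpy array of per-class pixel counts
    freq = label_counts / label_counts.sum()
    return np.median(freq) / freq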
def main(argv):

    # ensure we have the correct amount of arguments:
    if len(argv) != NUM_ARGS:
        print("usage: python train.py [MDL_PATH] [TRAIN_PATH] [DEV_PATH]")
        sys.exit(-1)

    # define local variables:
    mdl_path = argv[0]
    pTrain = argv[1]
    pDev = argv[2]

    # get the output directory name:
    odir = os.path.dirname(mdl_path)

    # if the odir doesn't exist, we make it:
    if odir and not os.path.exists(odir):
        os.makedirs(odir)

    # set the device to use GPU if available:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print('...Start reading data...')

    ### training data ###
    # training set and training data loader
    train_dataset = VaeTestDataset(pTrain, 'train')
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE,
                                                   num_workers=4, shuffle=True,
                                                   drop_last=True, pin_memory=True)

    ### validation data ###
    # validation set and validation data loader
    dev_dataset = VaeTestDataset(pDev, 'dev')
    dev_dataloader = torch.utils.data.DataLoader(dev_dataset, batch_size=BATCH_SIZE,
                                                 num_workers=2, shuffle=True,
                                                 drop_last=True, pin_memory=True)

    # set the class weights (median frequency balance):
    class_weights = np.array([2.514399, 1.4917144, 0.51608694, 0.659483, 1.0900991,
                              1.6461798, 0.32852992, 1.5633508, 0.9236576, 0.10251398])
    # class_weights = np.array([1.4222778, 2.1834621, 40.17538])  # inverse log class_probability
    class_weights = torch.Tensor(class_weights)
    print("class weights: ", class_weights)
    class_weights = class_weights.to(device)

    print('...Finish reading data...')

    # instantiate a model:
    model = S3Net(input_channels=NUM_INPUT_CHANNELS, output_channels=NUM_OUTPUT_CHANNELS)

    # move the model to the selected device:
    model.to(device)

    # set the adam optimizer parameters:
    opt_params = {
        LEARNING_RATE: 0.001,
        BETAS: (0.9, 0.999),
        EPS: 1e-08,
        WEIGHT_DECAY: 0.001
    }

    # set the loss criteria:
    ce_criterion = nn.CrossEntropyLoss(reduction='sum', weight=class_weights)
    ce_criterion.to(device)
    lovasz_criterion = L.LovaszSoftmax(reduction='sum', ignore_index=0)
    lovasz_criterion.to(device)

    # create an optimizer, and pass the model params to it:
    optimizer = Adam(model.parameters(), **opt_params)

    # get the number of epochs to train on:
    epochs = NUM_EPOCHS

    # if there is a trained model, continue training from its checkpoint:
    if os.path.exists(mdl_path):
        checkpoint = torch.load(mdl_path)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        print('Load epoch {} success'.format(start_epoch))
    else:
        start_epoch = 0
        # pre_path = "./model/model_segnet_weight.pth"
        # pretrained_model = torch.load(pre_path)
        # model.load_state_dict(pretrained_model['model'])
        print('No trained models, restart training')
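    # note: when wrapped in nn.DataParallel, the original network is exposed
    # as model.module; the checkpoint code below therefore saves
    # model.module.state_dict() in the multi-GPU case, so the weights can be
    # reloaded later without the DataParallel wrapper.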
    # multiple GPUs:
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # dim = 0: a batch of [30, ...] is split into [10, ...] x 3 on 3 GPUs
        model = nn.DataParallel(model)  # , device_ids=[0, 1])

    # move the model to the selected device:
    model.to(device)

    # tensorboard writer:
    writer = SummaryWriter('runs')

    epoch_num = 0
    for epoch in range(start_epoch + 1, epochs):

        # adjust learning rate:
        adjust_learning_rate(optimizer, epoch)

        ################################## Train #####################################
        # for each batch in increments of batch size
        #
        train_epoch_loss, train_lovasz_epoch_loss, train_ce_epoch_loss = train(
            model, train_dataloader, train_dataset, device, optimizer,
            ce_criterion, lovasz_criterion, class_weights, epoch, epochs
        )
        valid_epoch_loss, valid_lovasz_epoch_loss, valid_ce_epoch_loss = validate(
            model, dev_dataloader, dev_dataset, device,
            ce_criterion, lovasz_criterion, class_weights
        )

        # log the epoch losses:
        writer.add_scalar('training loss', train_epoch_loss, epoch)
        writer.add_scalar('training lovasz loss', train_lovasz_epoch_loss, epoch)
        writer.add_scalar('training ce loss', train_ce_epoch_loss, epoch)
        writer.add_scalar('validation loss', valid_epoch_loss, epoch)
        writer.add_scalar('validation lovasz loss', valid_lovasz_epoch_loss, epoch)
        writer.add_scalar('validation ce loss', valid_ce_epoch_loss, epoch)

        print('Train set: Average loss: {:.4f}'.format(train_epoch_loss))
        print('Validation set: Average loss: {:.4f}'.format(valid_epoch_loss))

        # save an intermediate checkpoint every 2000 epochs:
        if epoch % 2000 == 0:
            if torch.cuda.device_count() > 1:
                # multiple GPUs: unwrap the DataParallel module before saving:
                state = {'model': model.module.state_dict(),
                         'optimizer': optimizer.state_dict(), 'epoch': epoch}
            else:
                state = {'model': model.state_dict(),
                         'optimizer': optimizer.state_dict(), 'epoch': epoch}
            path = os.path.join(odir, 'model' + str(epoch) + '.pth')
            torch.save(state, path)
        epoch_num = epoch

    # save the final model
    if torch.cuda.device_count() > 1:
        # multiple GPUs: unwrap the DataParallel module before saving:
        state = {'model': model.module.state_dict(),
                 'optimizer': optimizer.state_dict(), 'epoch': epoch_num}
    else:
        state = {'model': model.state_dict(),
                 'optimizer': optimizer.state_dict(), 'epoch': epoch_num}
    torch.save(state, mdl_path)

    # exit gracefully
    #
    return True
#
# end of function

# begin gracefully
#
if __name__ == '__main__':
    main(sys.argv[1:])

#
# end of file