#!/usr/bin/env python
#
# file: $ISIP_EXP/SOGMP/scripts/train.py
#
# revision history: xzt
# 20220824 (TE): first version
#
# usage:
# python train.py mdir train_data val_data
#
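# example (paths are illustrative):
#  python train.py ./model/s3_net_model.pth ./data/train ./data/dev
#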
# arguments:
# mdir: the path where the output model file is stored
# train_data: the directory of training data
# val_data: the directory of validation data
#
# This script trains an S3-Net model
#------------------------------------------------------------------------------
# import pytorch modules
#
import torch
import torch.nn as nn
from torch.optim import Adam
from tqdm import tqdm
import torch.nn.functional as F
# visualize:
from tensorboardX import SummaryWriter
import numpy as np
# import the model and all of its variables/functions
#
from model import *
import lovasz_losses as L
# import modules
#
import sys
import os
#-----------------------------------------------------------------------------
#
# global variables are listed here
#
#-----------------------------------------------------------------------------
# general global values
#
model_dir = './model/s3_net_model.pth' # default path of model storage (unused here; the path comes from argv)
NUM_ARGS = 3
NUM_EPOCHS = 20000
BATCH_SIZE = 1024
LEARNING_RATE = "lr"
BETAS = "betas"
EPS = "eps"
WEIGHT_DECAY = "weight_decay"
# Constants
NUM_INPUT_CHANNELS = 3
NUM_OUTPUT_CHANNELS = 10 # 9 classes of semantic labels + 1 background
BETA = 0.01
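# note: BETA weights the KL-divergence term of the beta-VAE objective in the
# combined loss, i.e. loss = ce_loss + BETA*kl_loss + lovasz_loss (see train()/validate())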
# for reproducibility, we seed the rng
#
set_seed(SEED1)
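# (set_seed and SEED1 come in via the wildcard import from model above)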
# adjust_learning_rate
#
def adjust_learning_rate(optimizer, epoch):
    lr = 1e-4
    if epoch > 50000:
        lr = 2e-5
    if epoch > 480000:
        # lr = 5e-8
        lr = lr * (0.1 ** (epoch // 110000))
    # if epoch > 8300:
    #     lr = 1e-9
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
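# note: with NUM_EPOCHS = 20000, the epoch thresholds above (50000, 480000) are
# never reached, so the learning rate stays at 1e-4 for the entire run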
# train function:
def train(model, dataloader, dataset, device, optimizer, ce_criterion, lovasz_criterion, class_weights, epoch, epochs):
    # set model to training mode:
    model.train()
    # running totals over the epoch:
    running_loss = 0.0
    # kl_divergence:
    kl_avg_loss = 0.0
    # CE loss:
    ce_avg_loss = 0.0
    counter = 0
    # get the number of batches (floor of train_data/batch_size; drop_last=True discards the remainder):
    num_batches = int(len(dataset) / dataloader.batch_size)
    for i, batch in tqdm(enumerate(dataloader), total=num_batches):
        counter += 1
        # collect the samples as a batch and move them to the device:
        scans = batch['scan'].to(device)
        intensities = batch['intensity'].to(device)
        angle_incidence = batch['angle_incidence'].to(device)
        labels = batch['label'].to(device)
        batch_size = scans.size(0)
        # set all gradients to 0:
        optimizer.zero_grad()
        # feed the batch to the network:
        semantic_scan, semantic_channels, kl_loss = model(scans, intensities, angle_incidence)
        # calculate the semantic CE loss, normalized by batch size:
        ce_loss = ce_criterion(semantic_channels, labels.to(torch.long)).div(batch_size)
        # class-weighted Lovasz-Softmax loss:
        lovasz_loss, _ = lovasz_criterion(semantic_channels, labels.to(torch.long))
        lovasz_loss = lovasz_loss.mul(class_weights.to(device)).sum()
        # beta-VAE objective: CE + BETA * KL + Lovasz:
        loss = ce_loss + BETA * kl_loss + lovasz_loss
        # perform back propagation:
        loss.backward(torch.ones_like(loss))
        optimizer.step()
        # multiple GPUs: reduce the per-replica losses to scalars:
        if torch.cuda.device_count() > 1:
            loss = loss.mean()
            ce_loss = ce_loss.mean()
            kl_loss = lovasz_loss.mean()  # kl_loss.mean()
        running_loss += loss.item()
        # kl_divergence slot:
        kl_avg_loss += lovasz_loss.item()  # kl_loss.item()
        # CE loss:
        ce_avg_loss += ce_loss.item()
        # display an informational message:
        if i % 512 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, CE_Loss: {:.4f}, Lovasz_Loss: {:.4f}'
                  .format(epoch, epochs, i + 1, num_batches, loss.item(), ce_loss.item(), lovasz_loss.item()))
    # average the accumulated losses over the number of batches:
    train_loss = running_loss / counter
    train_kl_loss = kl_avg_loss / counter
    train_ce_loss = ce_avg_loss / counter
    return train_loss, train_kl_loss, train_ce_loss
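# note: train() returns per-batch epoch averages; the second return value (the
# 'kl' slot) currently tracks the Lovasz term, as the commented-out kl_loss
# tracking above was swapped out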
# validate function:
def validate(model, dataloader, dataset, device, ce_criterion, lovasz_criterion, class_weights):
    # set model to evaluation mode:
    model.eval()
    # running totals over the epoch:
    running_loss = 0.0
    # kl_divergence:
    kl_avg_loss = 0.0
    # CE loss:
    ce_avg_loss = 0.0
    counter = 0
    # get the number of batches (floor of dev_data/batch_size; drop_last=True discards the remainder):
    num_batches = int(len(dataset) / dataloader.batch_size)
    with torch.no_grad():
        for i, batch in tqdm(enumerate(dataloader), total=num_batches):
            counter += 1
            # collect the samples as a batch and move them to the device:
            scans = batch['scan'].to(device)
            intensities = batch['intensity'].to(device)
            angle_incidence = batch['angle_incidence'].to(device)
            labels = batch['label'].to(device)
            batch_size = scans.size(0)
            # feed the batch to the network:
            semantic_scan, semantic_channels, kl_loss = model(scans, intensities, angle_incidence)
            # calculate the semantic CE loss, normalized by batch size:
            ce_loss = ce_criterion(semantic_channels, labels.to(torch.long)).div(batch_size)
            # class-weighted Lovasz-Softmax loss:
            lovasz_loss, _ = lovasz_criterion(semantic_channels, labels.to(torch.long))
            lovasz_loss = lovasz_loss.mul(class_weights.to(device)).sum()
            # beta-VAE objective: CE + BETA * KL + Lovasz:
            loss = ce_loss + BETA * kl_loss + lovasz_loss
            # multiple GPUs: reduce the per-replica losses to scalars:
            if torch.cuda.device_count() > 1:
                loss = loss.mean()
                ce_loss = ce_loss.mean()
                kl_loss = lovasz_loss.mean()  # kl_loss.mean()
            running_loss += loss.item()
            # kl_divergence slot:
            kl_avg_loss += lovasz_loss.item()  # kl_loss.item()
            # CE loss:
            ce_avg_loss += ce_loss.item()
    # average the accumulated losses over the number of batches:
    val_loss = running_loss / counter
    val_kl_loss = kl_avg_loss / counter
    val_ce_loss = ce_avg_loss / counter
    return val_loss, val_kl_loss, val_ce_loss
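# note: validate() mirrors train() but runs under torch.no_grad() and performs
# no optimizer updates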
#------------------------------------------------------------------------------
#
# the main program starts here
#
#------------------------------------------------------------------------------
# function: main
#
# arguments: none
#
# return: none
#
# This method is the main function.
#
def main(argv):
    # ensure we have the correct number of arguments:
    if len(argv) != NUM_ARGS:
        print("usage: python train.py [MDL_PATH] [TRAIN_PATH] [DEV_PATH]")
        exit(-1)
    # define local variables:
    mdl_path = argv[0]
    pTrain = argv[1]
    pDev = argv[2]
    # get the output directory name:
    odir = os.path.dirname(mdl_path)
    # if the odir doesn't exist, we make it:
    if not os.path.exists(odir):
        os.makedirs(odir)
    # set the device to use GPU if available:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print('...Start reading data...')
    ### training data ###
    # training set and training data loader
    train_dataset = VaeTestDataset(pTrain, 'train')
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=4,
                                                   shuffle=True, drop_last=True, pin_memory=True)
    ### validation data ###
    # validation set and validation data loader
    dev_dataset = VaeTestDataset(pDev, 'dev')
    dev_dataloader = torch.utils.data.DataLoader(dev_dataset, batch_size=BATCH_SIZE, num_workers=2,
                                                 shuffle=True, drop_last=True, pin_memory=True)
    # set the class weights:
    class_weights = np.array([2.514399, 1.4917144, 0.51608694, 0.659483, 1.0900991, 1.6461798, 0.32852992, 1.5633508, 0.9236576, 0.10251398]) # median frequency balance
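    # (median frequency balancing: each class is weighted by the ratio of the
    #  median class frequency to its own frequency, so rare classes count more)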
    #class_weights = np.array([1.4222778, 2.1834621, 40.17538]) # inverse log class_probability
    class_weights = torch.Tensor(class_weights)
    print("class weights: ", class_weights)
    class_weights = class_weights.to(device)
    print('...Finish reading data...')
    # instantiate a model:
    model = S3Net(input_channels=NUM_INPUT_CHANNELS,
                  output_channels=NUM_OUTPUT_CHANNELS)
    # move the model to the selected device:
    model.to(device)
    # set the Adam optimizer parameters:
    opt_params = { LEARNING_RATE: 0.001,
                   BETAS: (0.9, 0.999),
                   EPS: 1e-08,
                   WEIGHT_DECAY: 0.001 }
    # set the loss criteria:
    ce_criterion = nn.CrossEntropyLoss(reduction='sum', weight=class_weights)
    ce_criterion.to(device)
    lovasz_criterion = L.LovaszSoftmax(reduction='sum', ignore_index=0)
    lovasz_criterion.to(device)
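    # (the Lovasz-Softmax loss is a surrogate for the mean IoU; pixels with
    #  label 0 are excluded from it via ignore_index=0)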
    # create an optimizer, and pass the model params to it:
    optimizer = Adam(model.parameters(), **opt_params)
    # get the number of epochs to train on:
    epochs = NUM_EPOCHS
    # if there is a trained model, continue training from it:
    if os.path.exists(mdl_path):
        checkpoint = torch.load(mdl_path)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        print('Loaded epoch {} successfully'.format(start_epoch))
    else:
        start_epoch = 0
        #pre_path = "./model/model_segnet_weight.pth"
        #pretrained_model = torch.load(pre_path)
        #model.load_state_dict(pretrained_model['model'])
        print('No trained model found, restarting training')
    # multiple GPUs:
    if torch.cuda.device_count() > 1:
        print("Let's use all", torch.cuda.device_count(), "GPUs!")
        # DataParallel splits each batch along dim 0 across the available GPUs:
        model = nn.DataParallel(model)  #, device_ids=[0, 1])
        # move the wrapped model to the device:
        model.to(device)
    # tensorboard writer:
    writer = SummaryWriter('runs')
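    # (the logged curves can be inspected with: tensorboard --logdir runs)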
    epoch_num = 0
    for epoch in range(start_epoch + 1, epochs):
        # adjust learning rate:
        adjust_learning_rate(optimizer, epoch)
        ################################## Train #####################################
        train_epoch_loss, train_kl_epoch_loss, train_ce_epoch_loss = train(
            model, train_dataloader, train_dataset, device, optimizer, ce_criterion, lovasz_criterion, class_weights, epoch, epochs
        )
        valid_epoch_loss, valid_kl_epoch_loss, valid_ce_epoch_loss = validate(
            model, dev_dataloader, dev_dataset, device, ce_criterion, lovasz_criterion, class_weights
        )
        # log the epoch losses:
        writer.add_scalar('training loss', train_epoch_loss, epoch)
        writer.add_scalar('training kl loss', train_kl_epoch_loss, epoch)
        writer.add_scalar('training ce loss', train_ce_epoch_loss, epoch)
        writer.add_scalar('validation loss', valid_epoch_loss, epoch)
        writer.add_scalar('validation kl loss', valid_kl_epoch_loss, epoch)
        writer.add_scalar('validation ce loss', valid_ce_epoch_loss, epoch)
        print('Train set: Average loss: {:.4f}'.format(train_epoch_loss))
        print('Validation set: Average loss: {:.4f}'.format(valid_epoch_loss))
        # save a periodic checkpoint (into the model's output directory):
        if epoch % 2000 == 0:
            if torch.cuda.device_count() > 1:  # multiple GPUs:
                state = {'model': model.module.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
            else:
                state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
            path = os.path.join(odir, 'model' + str(epoch) + '.pth')
            torch.save(state, path)
        epoch_num = epoch
    # save the final model:
    if torch.cuda.device_count() > 1:  # multiple GPUs:
        state = {'model': model.module.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch_num}
    else:
        state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch_num}
    torch.save(state, mdl_path)
    # exit gracefully
    #
    return True
#
# end of function
# begin gracefully
#
if __name__ == '__main__':
    main(sys.argv[1:])
#
# end of file