|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.optim import Adam |
|
|
from tqdm import tqdm |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
from tensorboardX import SummaryWriter |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
|
|
|
from model import * |
|
|
import lovasz_losses as L |
|
|
|
|
|
|
|
|
|
|
|
import sys |
|
|
import os |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Default path for the final trained-model checkpoint.
model_dir = './model/s3_net_model.pth'

# Expected number of CLI arguments: [MDL_PATH] [TRAIN_PATH] [DEV_PATH].
NUM_ARGS = 3

# Training-loop configuration.
NUM_EPOCHS = 20000
BATCH_SIZE = 1024

# Keyword names for the Adam optimizer's constructor arguments
# (used as keys of the opt_params dict in main()).
LEARNING_RATE = "lr"
BETAS = "betas"
EPS = "eps"
WEIGHT_DECAY = "weight_decay"

# Network shape: 3 input channels (presumably scan + intensity +
# incidence angle -- TODO confirm against model.py) and 10 output
# semantic classes (matches the 10-entry class_weights array in main()).
NUM_INPUT_CHANNELS = 3
NUM_OUTPUT_CHANNELS = 10

# Weight of the KL-divergence term in the total training loss.
BETA = 0.01

# Seed the RNGs for reproducibility; set_seed and SEED1 presumably come
# from `from model import *` -- TODO confirm.
set_seed(SEED1)
|
|
|
|
|
|
|
|
|
|
|
def adjust_learning_rate(optimizer, epoch):
    """Set the learning rate of every param group for the given epoch.

    Schedule: 1e-4 up to epoch 50000, then 2e-5; once past epoch 480000
    an additional geometric decay of one decade per 110000 epochs is
    applied on top of the 2e-5 base.
    """
    new_lr = 2e-5 if epoch > 50000 else 1e-4
    if epoch > 480000:
        # Extra decay: one order of magnitude per 110000 elapsed epochs.
        new_lr *= 0.1 ** (epoch // 110000)

    for group in optimizer.param_groups:
        group['lr'] = new_lr
|
|
|
|
|
|
|
|
|
|
|
def train(model, dataloader, dataset, device, optimizer, ce_criterion, lovasz_criterion, class_weights, epoch, epochs):
    """Run one training epoch and return the per-batch-averaged losses.

    Args:
        model: network returning (semantic_scan, semantic_channels, kl_loss).
        dataloader: yields dicts with 'scan', 'intensity', 'angle_incidence',
            'label' tensors.
        dataset: the underlying dataset (used only to estimate batch count).
        device: torch.device to run on.
        optimizer: optimizer stepped once per batch.
        ce_criterion: CrossEntropyLoss with reduction='sum'.
        lovasz_criterion: Lovasz-Softmax loss returning (per-class loss, _).
        class_weights: per-class weight tensor applied to the Lovasz loss.
        epoch, epochs: current / total epoch count (progress printing only).

    Returns:
        (train_loss, train_kl_loss, train_ce_loss) averaged over batches.
        NOTE(review): the second value actually accumulates the *Lovasz*
        loss, not the KL term -- kept as-is to preserve the return contract.
    """
    model.train()

    running_loss = 0.0
    kl_avg_loss = 0.0   # NOTE(review): accumulates lovasz_loss (see above)
    ce_avg_loss = 0.0
    counter = 0

    num_batches = int(len(dataset) / dataloader.batch_size)
    for i, batch in tqdm(enumerate(dataloader), total=num_batches):
        counter += 1

        # Move the batch tensors onto the training device.
        scans = batch['scan'].to(device)
        intensities = batch['intensity'].to(device)
        angle_incidence = batch['angle_incidence'].to(device)
        labels = batch['label'].to(device)

        batch_size = scans.size(0)

        optimizer.zero_grad()

        semantic_scan, semantic_channels, kl_loss = model(scans, intensities, angle_incidence)

        # CE uses reduction='sum', so normalize by the batch size.
        ce_loss = ce_criterion(semantic_channels, labels.to(torch.long)).div(batch_size)
        lovasz_loss, _ = lovasz_criterion(semantic_channels, labels.to(torch.long))
        # BUGFIX: use the actual training device instead of the hard-coded
        # "cuda" string, which crashed on CPU-only machines.
        lovasz_loss = lovasz_loss.mul(class_weights.to(device)).sum()

        loss = ce_loss + BETA * kl_loss + lovasz_loss

        # loss may be a per-replica vector under DataParallel, hence the
        # explicit ones_like gradient.
        loss.backward(torch.ones_like(loss))
        optimizer.step()

        if torch.cuda.device_count() > 1:
            # Reduce per-GPU loss vectors to scalars for logging.
            loss = loss.mean()
            ce_loss = ce_loss.mean()
            # BUGFIX: previously reduced lovasz_loss into kl_loss by mistake.
            kl_loss = kl_loss.mean()

        running_loss += loss.item()
        kl_avg_loss += lovasz_loss.item()
        ce_avg_loss += ce_loss.item()

        if (i % 512 == 0):
            print('Epoch [{}/{}], Step[{}/{}], Loss: {:.4f}, CE_Loss: {:.4f}, Lovasz_Loss: {:.4f}'
                  .format(epoch, epochs, i + 1, num_batches, loss.item(), ce_loss.item(), lovasz_loss.item()))

    train_loss = running_loss / counter
    train_kl_loss = kl_avg_loss / counter
    train_ce_loss = ce_avg_loss / counter

    return train_loss, train_kl_loss, train_ce_loss
|
|
|
|
|
|
|
|
def validate(model, dataloader, dataset, device, ce_criterion, lovasz_criterion, class_weights):
    """Evaluate the model on the dev set and return averaged losses.

    Same loss computation as train(), but with model.eval() and no
    gradient tracking / optimizer steps.

    Returns:
        (val_loss, val_kl_loss, val_ce_loss) averaged over batches.
        NOTE(review): the second value actually accumulates the *Lovasz*
        loss, not the KL term -- kept as-is to preserve the return contract.
    """
    model.eval()

    running_loss = 0.0
    kl_avg_loss = 0.0   # NOTE(review): accumulates lovasz_loss (see above)
    ce_avg_loss = 0.0
    counter = 0

    num_batches = int(len(dataset) / dataloader.batch_size)
    with torch.no_grad():
        for i, batch in tqdm(enumerate(dataloader), total=num_batches):
            counter += 1

            # Move the batch tensors onto the evaluation device.
            scans = batch['scan'].to(device)
            intensities = batch['intensity'].to(device)
            angle_incidence = batch['angle_incidence'].to(device)
            labels = batch['label'].to(device)

            batch_size = scans.size(0)

            semantic_scan, semantic_channels, kl_loss = model(scans, intensities, angle_incidence)

            # CE uses reduction='sum', so normalize by the batch size.
            ce_loss = ce_criterion(semantic_channels, labels.to(torch.long)).div(batch_size)
            lovasz_loss, _ = lovasz_criterion(semantic_channels, labels.to(torch.long))
            # BUGFIX: use the actual device instead of the hard-coded
            # "cuda" string, which crashed on CPU-only machines.
            lovasz_loss = lovasz_loss.mul(class_weights.to(device)).sum()

            loss = ce_loss + BETA * kl_loss + lovasz_loss

            if torch.cuda.device_count() > 1:
                # Reduce per-GPU loss vectors to scalars for logging.
                loss = loss.mean()
                ce_loss = ce_loss.mean()
                # BUGFIX: previously reduced lovasz_loss into kl_loss by mistake.
                kl_loss = kl_loss.mean()

            running_loss += loss.item()
            kl_avg_loss += lovasz_loss.item()
            ce_avg_loss += ce_loss.item()

    val_loss = running_loss / counter
    val_kl_loss = kl_avg_loss / counter
    val_ce_loss = ce_avg_loss / counter

    return val_loss, val_kl_loss, val_ce_loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main(argv):
    """Train S3Net end-to-end: parse CLI args, build datasets, model and
    optimizer, resume from a checkpoint when one exists, then run the
    train/validate loop with TensorBoard logging and periodic checkpoints.

    Args:
        argv: CLI arguments without the program name:
            [MDL_PATH, TRAIN_PATH, DEV_PATH].

    Returns:
        True on completion.
    """
    if (len(argv) != NUM_ARGS):
        # BUGFIX: the usage text previously listed five arguments while
        # NUM_ARGS is 3 -- keep the message consistent with the check.
        print("usage: python train.py [MDL_PATH] [TRAIN_PATH] [DEV_PATH]")
        exit(-1)

    mdl_path = argv[0]
    pTrain = argv[1]
    pDev = argv[2]

    # Ensure the checkpoint directory exists. BUGFIX: dirname is '' when
    # mdl_path has no directory component; os.makedirs('') raises, so guard.
    odir = os.path.dirname(mdl_path)
    if odir and not os.path.exists(odir):
        os.makedirs(odir)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print('...Start reading data...')

    train_dataset = VaeTestDataset(pTrain, 'train')
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=4,
                                                   shuffle=True, drop_last=True, pin_memory=True)

    dev_dataset = VaeTestDataset(pDev, 'dev')
    dev_dataloader = torch.utils.data.DataLoader(dev_dataset, batch_size=BATCH_SIZE, num_workers=2,
                                                 shuffle=True, drop_last=True, pin_memory=True)

    # Per-class loss weights, one per output class (presumably derived from
    # inverse class frequency -- TODO confirm).
    class_weights = np.array([2.514399, 1.4917144, 0.51608694, 0.659483, 1.0900991, 1.6461798, 0.32852992, 1.5633508, 0.9236576, 0.10251398])

    class_weights = torch.Tensor(class_weights)
    print("class weights: ", class_weights)
    # BUGFIX: Tensor.to() is not in-place; the result must be re-assigned,
    # otherwise the weights silently stayed on the CPU.
    class_weights = class_weights.to(device)
    print('...Finish reading data...')

    model = S3Net(input_channels=NUM_INPUT_CHANNELS,
                  output_channels=NUM_OUTPUT_CHANNELS)
    model.to(device)

    # Adam hyper-parameters, keyed by the module-level constant names.
    opt_params = {LEARNING_RATE: 0.001,
                  BETAS: (.9, 0.999),
                  EPS: 1e-08,
                  WEIGHT_DECAY: .001}

    ce_criterion = nn.CrossEntropyLoss(reduction='sum', weight=class_weights)
    ce_criterion.to(device)
    lovasz_criterion = L.LovaszSoftmax(reduction='sum', ignore_index=0)
    lovasz_criterion.to(device)

    optimizer = Adam(model.parameters(), **opt_params)

    epochs = NUM_EPOCHS

    # Resume from an existing checkpoint when one is present at mdl_path.
    if os.path.exists(mdl_path):
        # BUGFIX: map_location lets CUDA-saved checkpoints load on CPU hosts.
        checkpoint = torch.load(mdl_path, map_location=device)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        print('Load epoch {} success'.format(start_epoch))
    else:
        start_epoch = 0
        print('No trained models, restart training')

    if torch.cuda.device_count() > 1:
        print("Let's use 2 of total", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
        model.to(device)

    writer = SummaryWriter('runs')

    epoch_num = 0
    for epoch in range(start_epoch + 1, epochs):
        adjust_learning_rate(optimizer, epoch)

        train_epoch_loss, train_kl_epoch_loss, train_ce_epoch_loss = train(
            model, train_dataloader, train_dataset, device, optimizer, ce_criterion, lovasz_criterion, class_weights, epoch, epochs
        )
        valid_epoch_loss, valid_kl_epoch_loss, valid_ce_epoch_loss = validate(
            model, dev_dataloader, dev_dataset, device, ce_criterion, lovasz_criterion, class_weights
        )

        # TensorBoard scalars, one point per epoch.
        writer.add_scalar('training loss', train_epoch_loss, epoch)
        writer.add_scalar('training kl loss', train_kl_epoch_loss, epoch)
        writer.add_scalar('training ce loss', train_ce_epoch_loss, epoch)
        writer.add_scalar('validation loss', valid_epoch_loss, epoch)
        writer.add_scalar('validation kl loss', valid_kl_epoch_loss, epoch)
        writer.add_scalar('validation ce loss', valid_ce_epoch_loss, epoch)

        print('Train set: Average loss: {:.4f}'.format(train_epoch_loss))
        print('Validation set: Average loss: {:.4f}'.format(valid_epoch_loss))

        # Periodic checkpoint every 2000 epochs.
        if (epoch % 2000 == 0):
            if torch.cuda.device_count() > 1:
                # DataParallel wraps the model; save the underlying module.
                state = {'model': model.module.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
            else:
                state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
            path = './model/model' + str(epoch) + '.pth'
            torch.save(state, path)

        epoch_num = epoch

    # Final checkpoint written to the user-supplied path.
    if torch.cuda.device_count() > 1:
        state = {'model': model.module.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch_num}
    else:
        state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch_num}
    torch.save(state, mdl_path)

    return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: forward the CLI arguments (minus the program
    # name) to main().
    main(sys.argv[1:])
|
|
|
|
|
|
|
|
|