RustCoSeg / train.py
HirraA's picture
Upload 30 files
168ec29 verified
import torch
import torch.nn as nn
import torch.optim as optim
from util import Logger, AverageMeter, save_checkpoint, save_tensor_img, set_seed
import os
import numpy as np
from matplotlib import pyplot as plt
import time
import argparse
from tqdm import tqdm
from dataset import get_loader
from loss import *
from config import Config
from evaluation.dataloader import EvalDataset
from evaluation.evaluator import Eval_thread
from models.main import *
import torch.nn.functional as F
import pytorch_toolbelt.losses as PTL
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Parameter from command line
parser = argparse.ArgumentParser(description='')
parser.add_argument('--loss',
default='Scale_IoU',
type=str,
help="Options: '', ''")
parser.add_argument('--bs', '--batch_size', default=1, type=int)
parser.add_argument('--lr',
'--learning_rate',
default=1e-4,
type=float,
help='Initial learning rate')
parser.add_argument('--resume',
default=None,
type=str,
help='path to latest checkpoint')
parser.add_argument('--epochs', default=200, type=int)
parser.add_argument('--start_epoch',
default=0,
type=int,
help='manual epoch number (useful on restarts)')
parser.add_argument('--trainset',
default='CoCo',
type=str,
help="Options: 'CoCo'")
parser.add_argument('--testsets',
default='CoCA',
type=str,
help="Options: 'CoCA','CoSal2015','CoSOD3k','iCoseg','MSRC'")
parser.add_argument('--size',
default=224,
type=int,
help='input size')
parser.add_argument('--tmp', default='/data1/dcfm/temp', help='Temporary folder')
parser.add_argument('--save_root', default='./CoSODmaps/pred', type=str, help='Output folder')
args = parser.parse_args()
config = Config()
# Prepare dataset
if args.trainset == 'CoCo':
train_img_path = './data/CoCo/img/'
train_gt_path = './data/CoCo/gt/'
train_loader = get_loader(train_img_path,
train_gt_path,
args.size,
args.bs,
max_num=16, #20,
istrain=True,
shuffle=False,
num_workers=8, #4,
pin=True)
else:
print('Unkonwn train dataset')
print(args.dataset)
for testset in ['CoCA']:
if testset == 'CoCA':
test_img_path = './data/images/CoCA/'
test_gt_path = './data/gts/CoCA/'
saved_root = os.path.join(args.save_root, 'CoCA')
elif testset == 'CoSOD3k':
test_img_path = './data/images/CoSOD3k/'
test_gt_path = './data/gts/CoSOD3k/'
saved_root = os.path.join(args.save_root, 'CoSOD3k')
elif testset == 'CoSal2015':
test_img_path = './data/images/CoSal2015/'
test_gt_path = './data/gts/CoSal2015/'
saved_root = os.path.join(args.save_root, 'CoSal2015')
elif testset == 'CoCo':
test_img_path = './data/images/CoCo/'
test_gt_path = './data/gts/CoCo/'
saved_root = os.path.join(args.save_root, 'CoCo')
else:
print('Unkonwn test dataset')
print(args.dataset)
test_loader = get_loader(
test_img_path, test_gt_path, args.size, 1, istrain=False, shuffle=False, num_workers=8, pin=True)
# make dir for tmp
os.makedirs(args.tmp, exist_ok=True)
# Init log file
logger = Logger(os.path.join(args.tmp, "log.txt"))
set_seed(123)
# Init model
device = torch.device("cuda")
model = DCFM()
model = model.to(device)
model.apply(weights_init)
model.dcfmnet.backbone._initialize_weights(torch.load('./models/vgg16-397923af.pth'))
backbone_params = list(map(id, model.dcfmnet.backbone.parameters()))
base_params = filter(lambda p: id(p) not in backbone_params,
model.dcfmnet.parameters())
all_params = [{'params': base_params}, {'params': model.dcfmnet.backbone.parameters(), 'lr': args.lr*0.1}]
# Setting optimizer
optimizer = optim.Adam(params=all_params,lr=args.lr, weight_decay=1e-4, betas=[0.9, 0.99])
for key, value in model.named_parameters():
if 'dcfmnet.backbone' in key and 'dcfmnet.backbone.conv5.conv5_3' not in key:
value.requires_grad = False
for key, value in model.named_parameters():
print(key, value.requires_grad)
# log model and optimizer pars
logger.info("Model details:")
logger.info(model)
logger.info("Optimizer details:")
logger.info(optimizer)
logger.info("Scheduler details:")
# logger.info(scheduler)
logger.info("Other hyperparameters:")
logger.info(args)
# Setting Loss
exec('from loss import ' + args.loss)
IOUloss = eval(args.loss+'()')
def main():
val_measures = []
# Optionally resume from a checkpoint
if args.resume:
if os.path.isfile(args.resume):
logger.info("=> loading checkpoint '{}'".format(args.resume))
checkpoint = torch.load(args.resume)
args.start_epoch = checkpoint['epoch']
model.dcfmnet.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
logger.info("=> loaded checkpoint '{}' (epoch {})".format(
args.resume, checkpoint['epoch']))
else:
logger.info("=> no checkpoint found at '{}'".format(args.resume))
print(args.epochs)
for epoch in range(args.start_epoch, args.epochs):
train_loss = train(epoch)
if config.validation:
measures = validate(model, test_loader, args.testsets)
val_measures.append(measures)
print(
'Validation: S_measure on CoCA for epoch-{} is {:.4f}. Best epoch is epoch-{} with S_measure {:.4f}'.format(
epoch, measures[0], np.argmax(np.array(val_measures)[:, 0].squeeze()),
np.max(np.array(val_measures)[:, 0]))
)
# Save checkpoint
save_checkpoint(
{
'epoch': epoch + 1,
'state_dict': model.dcfmnet.state_dict(),
#'scheduler': scheduler.state_dict(),
},
path=args.tmp)
if config.validation:
if np.max(np.array(val_measures)[:, 0].squeeze()) == measures[0]:
best_weights_before = [os.path.join(args.tmp, weight_file) for weight_file in
os.listdir(args.tmp) if 'best_' in weight_file]
for best_weight_before in best_weights_before:
os.remove(best_weight_before)
torch.save(model.dcfmnet.state_dict(),
os.path.join(args.tmp, 'best_ep{}_Smeasure{:.4f}.pth'.format(epoch, measures[0])))
if (epoch + 1) % 10 == 0 or epoch == 0:
torch.save(model.dcfmnet.state_dict(), args.tmp + '/model-' + str(epoch + 1) + '.pt')
if epoch > 188:
torch.save(model.dcfmnet.state_dict(), args.tmp+'/model-' + str(epoch + 1) + '.pt')
#dcfmnet_dict = model.dcfmnet.state_dict()
#torch.save(dcfmnet_dict, os.path.join(args.tmp, 'final.pth'))
def sclloss(x, xt, xb):
cosc = (1+compute_cos_dis(x, xt))*0.5
cosb = (1+compute_cos_dis(x, xb))*0.5
loss = -torch.log(cosc+1e-5)-torch.log(1-cosb+1e-5)
return loss.sum()
def train(epoch):
# Switch to train mode
model.train()
model.set_mode('train')
loss_sum = 0.0
loss_sumkl = 0.0
for batch_idx, batch in enumerate(train_loader):
inputs = batch[0].to(device).squeeze(0)
gts = batch[1].to(device).squeeze(0)
pred, proto, protogt, protobg = model(inputs, gts)
loss_iou = IOUloss(pred, gts)
loss_scl = sclloss(proto, protogt, protobg)
loss = loss_iou+0.1*loss_scl
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_sum = loss_sum + loss_iou.detach().item()
if batch_idx % 20 == 0:
logger.info('Epoch[{0}/{1}] Iter[{2}/{3}] '
'Train Loss: loss_iou: {4:.3f}, loss_scl: {5:.3f} '.format(
epoch,
args.epochs,
batch_idx,
len(train_loader),
loss_iou,
loss_scl,
))
loss_mean = loss_sum / len(train_loader)
return loss_sum
def validate(model, test_loaders, testsets):
model.eval()
testsets = testsets.split('+')
measures = []
for testset in testsets[:1]:
print('Validating {}...'.format(testset))
#test_loader = test_loaders[testset]
saved_root = os.path.join(args.save_root, testset)
for batch in test_loader:
inputs = batch[0].to(device).squeeze(0)
gts = batch[1].to(device).squeeze(0)
subpaths = batch[2]
ori_sizes = batch[3]
with torch.no_grad():
scaled_preds = model(inputs, gts)[-1].sigmoid()
os.makedirs(os.path.join(saved_root, subpaths[0][0].split('/')[0]), exist_ok=True)
num = len(scaled_preds)
for inum in range(num):
subpath = subpaths[inum][0]
ori_size = (ori_sizes[inum][0].item(), ori_sizes[inum][1].item())
res = nn.functional.interpolate(scaled_preds[inum].unsqueeze(0), size=ori_size, mode='bilinear',
align_corners=True)
save_tensor_img(res, os.path.join(saved_root, subpath))
eval_loader = EvalDataset(
saved_root, # preds
os.path.join('./data/gts', testset) # GT
)
evaler = Eval_thread(eval_loader, cuda=True)
# Use S_measure for validation
s_measure = evaler.Eval_Smeasure()
if s_measure > config.val_measures['Smeasure']['CoCA'] and 0:
# TODO: evluate others measures if s_measure is very high.
e_max = evaler.Eval_Emeasure().max().item()
f_max = evaler.Eval_fmeasure().max().item()
print('Emax: {:4.f}, Fmax: {:4.f}'.format(e_max, f_max))
measures.append(s_measure)
model.train()
return measures
if __name__ == '__main__':
main()