tempsal / src /train.py
baharay's picture
Upload 31 files
a3fc39c verified
import argparse
import os
import torch
import sys
import time
import wandb
import torch.nn as nn
from tqdm import tqdm
from dataloader import SaliconDataset
from loss import *
from utils import AverageMeter
from utils import img_save
from torchvision import utils
import torch.nn.functional as nnf
from os.path import join
from PIL import Image
parser = argparse.ArgumentParser()
parser.add_argument('--no_epochs',default=30, type=int)
parser.add_argument('--lr',default=1e-5, type=float)
parser.add_argument('--kldiv',default=True, type=bool)
parser.add_argument('--cc',default=True, type=bool)
parser.add_argument('--nss',default=False, type=bool)
parser.add_argument('--sim',default=False, type=bool)
parser.add_argument('--nss_emlnet',default=False, type=bool)
parser.add_argument('--nss_norm',default=False, type=bool)
parser.add_argument('--l1',default=False, type=bool)
parser.add_argument('--lr_sched',default=False, type=bool)
parser.add_argument('--dilation',default=False, type=bool)
parser.add_argument('--enc_model',default="pnas", type=str)
parser.add_argument('--optim',default="Adam", type=str)
parser.add_argument('--load_weight',default=1, type=int)
parser.add_argument('--kldiv_coeff',default=1.0, type=float)
parser.add_argument('--step_size',default=5, type=int)
parser.add_argument('--cc_coeff',default=-1.0, type=float)
parser.add_argument('--sim_coeff',default=-1.0, type=float)
parser.add_argument('--nss_coeff',default=-1.0, type=float)
parser.add_argument('--nss_emlnet_coeff',default=1.0, type=float)
parser.add_argument('--nss_norm_coeff',default=1.0, type=float)
parser.add_argument('--l1_coeff',default=1.0, type=float)
parser.add_argument('--train_enc',default=1, type=int)
parser.add_argument('--dataset_dir',default="../data/", type=str)
parser.add_argument('--batch_size',default=32, type=int)
parser.add_argument('--log_interval',default=60, type=int)
parser.add_argument('--no_workers',default=4, type=int)
parser.add_argument('--train_model',default=False, type=bool)
parser.add_argument('--time_slices',default=5, type=int)
parser.add_argument('--selected_slices',default="", type=str)
parser.add_argument('--results_dir',default="", type=str )
# Path to save the model weights
parser.add_argument('--model_val_path',default="model.pt", type=str)
# If the model type is pnas_boosted, specify the path of the pre-trained pnas model here
parser.add_argument('--model_path',default="", type=str)
# If the model type is pnas_boosted, specify the path of the pre-trained pnasvol model here
parser.add_argument('--model_vol_path',default="", type=str)
args = parser.parse_args()
train_img_dir = args.dataset_dir + "images/train/"
train_gt_dir = args.dataset_dir + "maps/train/"
train_fix_dir = args.dataset_dir + "fixation_maps/train/"
val_img_dir = args.dataset_dir + "images/val/"
val_gt_dir = args.dataset_dir + "maps/val/"
val_fix_dir = args.dataset_dir + "fixation_maps/val/"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if args.enc_model == "pnas":
print("PNAS Model")
from model import PNASModel
model = PNASModel(train_enc=bool(args.train_enc), load_weight=args.load_weight)
elif args.enc_model == "pnas_boosted_multi":
print("PNAS Boosted Model PNASBoostedModelMultilevel")
from model import PNASBoostedModelMultilevel
model = PNASBoostedModelMultilevel(device, args.model_path, args.model_vol_path, args.time_slices, train_model=args.train_model,selected_slices = args.selected_slices )
if torch.cuda.device_count() > 1:
print("Let's use", torch.cuda.device_count(), "GPUs!")
model = nn.DataParallel(model)
model.to(device)
train_img_ids = [nm.split(".")[0] for nm in os.listdir(train_img_dir)]
val_img_ids = [nm.split(".")[0] for nm in os.listdir(val_img_dir)]
train_dataset = SaliconDataset(train_img_dir, train_gt_dir, train_fix_dir, train_img_ids)
val_dataset = SaliconDataset(val_img_dir, val_gt_dir, val_fix_dir, val_img_ids)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.no_workers)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.no_workers)
def loss_func(pred_map, gt, fixations, args):
loss = torch.FloatTensor([0.0]).cuda()
criterion = nn.L1Loss()
if args.kldiv:
loss += args.kldiv_coeff * kldiv(pred_map, gt)
if args.cc:
loss += args.cc_coeff * cc(pred_map, gt)
if args.nss:
loss += args.nss_coeff * nss(pred_map, fixations)
if args.l1:
loss += args.l1_coeff * criterion(pred_map, gt)
if args.sim:
loss += args.sim_coeff * similarity(pred_map, gt)
#print("Loss: ", loss)
return loss
def train(model, optimizer, loader, epoch, device, args):
model.train()
tic = time.time()
total_loss = 0.0
cur_loss = 0.0
for idx, (img, gt, fixations) in enumerate(loader):
img = img.to(device)
gt = gt.to(device)
fixations = fixations.to(device)
optimizer.zero_grad()
pred_map, vol_pred = model(img)
assert pred_map.size() == gt.size()
loss = loss_func(pred_map, gt, fixations, args)
loss.backward()
total_loss += loss.item()
cur_loss += loss.item()
optimizer.step()
if idx%args.log_interval==(args.log_interval-1):
print('[{:2d}, {:5d}] avg_loss : {:.5f}, time:{:3f} minutes'.format(epoch, idx, cur_loss/args.log_interval, (time.time()-tic)/60))
wandb.log({"loss": cur_loss/args.log_interval})
cur_loss = 0.0
sys.stdout.flush()
print('[{:2d}, train] avg_loss : {:.5f}'.format(epoch, total_loss/len(loader)))
sys.stdout.flush()
return total_loss/len(loader)
def validate(model, loader, epoch, device, args):
model.eval()
tic = time.time()
cc_loss = AverageMeter()
kldiv_loss = AverageMeter()
nss_loss = AverageMeter()
sim_loss = AverageMeter()
for (img, gt, fixations) in tqdm(loader):
img = img.to(device)
gt = gt.to(device)
fixations = fixations.to(device)
pred_map , vol_pred = model(img)
cc_loss.update(cc(pred_map, gt))
kldiv_loss.update(kldiv(pred_map, gt))
nss_loss.update(nss(pred_map, fixations))
sim_loss.update(similarity(pred_map, gt))
print('[{:2d}, val] CC : {:.5f}, KLDIV : {:.5f}, NSS : {:.5f}, SIM : {:.5f} time:{:3f} minutes'.format(epoch, cc_loss.avg, kldiv_loss.avg, nss_loss.avg, sim_loss.avg, (time.time()-tic)/60))
wandb.log({"CC": cc_loss.avg, 'KLDIV': kldiv_loss.avg, 'NSS': nss_loss.avg, 'SIM': sim_loss.avg})
sys.stdout.flush()
return cc_loss.avg,cc_loss,kldiv_loss,nss_loss,sim_loss
params = list(filter(lambda p: p.requires_grad, model.parameters()))
if args.optim=="Adam":
optimizer = torch.optim.Adam(params, lr=args.lr)
if args.optim=="Adagrad":
optimizer = torch.optim.Adagrad(params, lr=args.lr)
if args.optim=="SGD":
optimizer = torch.optim.SGD(params, lr=args.lr, momentum=0.9)
if args.lr_sched:
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.1)
print(device)
best_loss = 0
for epoch in range(0, args.no_epochs):
loss = train(model, optimizer, train_loader, epoch, device, args)
with torch.no_grad():
cc_loss,cc_loss_obj,kldiv_loss,nss_loss,sim_loss = validate(model, val_loader, epoch, device, args)
cc_loss -=kldiv_loss.avg
if epoch == 0 :
best_loss = cc_loss
if best_loss <= cc_loss:
best_loss = cc_loss
print('[{:2d}, save, {}]'.format(epoch, args.model_val_path))
wandb.log({"Best/CC mean": cc_loss,"Best/CC median": cc_loss_obj.get_median(), "Best/CC std": cc_loss_obj.get_std(),
"Best/KLD mean": kldiv_loss.avg,"Best/KLD median": kldiv_loss.get_median(), "Best/KLD std": kldiv_loss.get_std(),
"Best/NSS mean": nss_loss.avg,"Best/NSS median": nss_loss.get_median(), "Best/NSS std": nss_loss.get_std(),
"Best/SIM mean": sim_loss.avg,"Best/SIM median": sim_loss.get_median(), "Best/SIM std": sim_loss.get_std()})
if torch.cuda.device_count() > 1:
torch.save(model.module.state_dict(), args.model_val_path)
else:
torch.save(model.state_dict(), args.model_val_path)
print()
if args.lr_sched:
scheduler.step()