|
|
import argparse |
|
|
import os |
|
|
import torch |
|
|
import sys |
|
|
import time |
|
|
import wandb |
|
|
import torch.nn as nn |
|
|
from tqdm import tqdm |
|
|
from dataloader import SaliconDataset |
|
|
from loss import * |
|
|
from utils import AverageMeter |
|
|
from utils import img_save |
|
|
from torchvision import utils |
|
|
import torch.nn.functional as nnf |
|
|
from os.path import join |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
parser = argparse.ArgumentParser() |
|
|
parser.add_argument('--no_epochs',default=30, type=int) |
|
|
parser.add_argument('--lr',default=1e-5, type=float) |
|
|
parser.add_argument('--kldiv',default=True, type=bool) |
|
|
parser.add_argument('--cc',default=True, type=bool) |
|
|
parser.add_argument('--nss',default=False, type=bool) |
|
|
parser.add_argument('--sim',default=False, type=bool) |
|
|
parser.add_argument('--nss_emlnet',default=False, type=bool) |
|
|
parser.add_argument('--nss_norm',default=False, type=bool) |
|
|
parser.add_argument('--l1',default=False, type=bool) |
|
|
parser.add_argument('--lr_sched',default=False, type=bool) |
|
|
parser.add_argument('--dilation',default=False, type=bool) |
|
|
parser.add_argument('--enc_model',default="pnas", type=str) |
|
|
parser.add_argument('--optim',default="Adam", type=str) |
|
|
|
|
|
parser.add_argument('--load_weight',default=1, type=int) |
|
|
parser.add_argument('--kldiv_coeff',default=1.0, type=float) |
|
|
parser.add_argument('--step_size',default=5, type=int) |
|
|
parser.add_argument('--cc_coeff',default=-1.0, type=float) |
|
|
parser.add_argument('--sim_coeff',default=-1.0, type=float) |
|
|
parser.add_argument('--nss_coeff',default=-1.0, type=float) |
|
|
parser.add_argument('--nss_emlnet_coeff',default=1.0, type=float) |
|
|
parser.add_argument('--nss_norm_coeff',default=1.0, type=float) |
|
|
parser.add_argument('--l1_coeff',default=1.0, type=float) |
|
|
parser.add_argument('--train_enc',default=1, type=int) |
|
|
|
|
|
parser.add_argument('--dataset_dir',default="../data/", type=str) |
|
|
parser.add_argument('--batch_size',default=32, type=int) |
|
|
parser.add_argument('--log_interval',default=60, type=int) |
|
|
parser.add_argument('--no_workers',default=4, type=int) |
|
|
parser.add_argument('--train_model',default=False, type=bool) |
|
|
parser.add_argument('--time_slices',default=5, type=int) |
|
|
parser.add_argument('--selected_slices',default="", type=str) |
|
|
parser.add_argument('--results_dir',default="", type=str ) |
|
|
|
|
|
|
|
|
parser.add_argument('--model_val_path',default="model.pt", type=str) |
|
|
|
|
|
parser.add_argument('--model_path',default="", type=str) |
|
|
|
|
|
parser.add_argument('--model_vol_path',default="", type=str) |
|
|
|
|
|
|
|
|
args = parser.parse_args() |
|
|
|
|
|
train_img_dir = args.dataset_dir + "images/train/" |
|
|
train_gt_dir = args.dataset_dir + "maps/train/" |
|
|
train_fix_dir = args.dataset_dir + "fixation_maps/train/" |
|
|
|
|
|
val_img_dir = args.dataset_dir + "images/val/" |
|
|
val_gt_dir = args.dataset_dir + "maps/val/" |
|
|
val_fix_dir = args.dataset_dir + "fixation_maps/val/" |
|
|
|
|
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
|
|
|
|
|
if args.enc_model == "pnas": |
|
|
print("PNAS Model") |
|
|
from model import PNASModel |
|
|
model = PNASModel(train_enc=bool(args.train_enc), load_weight=args.load_weight) |
|
|
|
|
|
elif args.enc_model == "pnas_boosted_multi": |
|
|
print("PNAS Boosted Model PNASBoostedModelMultilevel") |
|
|
from model import PNASBoostedModelMultilevel |
|
|
model = PNASBoostedModelMultilevel(device, args.model_path, args.model_vol_path, args.time_slices, train_model=args.train_model,selected_slices = args.selected_slices ) |
|
|
|
|
|
|
|
|
if torch.cuda.device_count() > 1: |
|
|
print("Let's use", torch.cuda.device_count(), "GPUs!") |
|
|
model = nn.DataParallel(model) |
|
|
model.to(device) |
|
|
|
|
|
|
|
|
train_img_ids = [nm.split(".")[0] for nm in os.listdir(train_img_dir)] |
|
|
val_img_ids = [nm.split(".")[0] for nm in os.listdir(val_img_dir)] |
|
|
|
|
|
train_dataset = SaliconDataset(train_img_dir, train_gt_dir, train_fix_dir, train_img_ids) |
|
|
val_dataset = SaliconDataset(val_img_dir, val_gt_dir, val_fix_dir, val_img_ids) |
|
|
|
|
|
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.no_workers) |
|
|
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.no_workers) |
|
|
|
|
|
|
|
|
|
|
|
def loss_func(pred_map, gt, fixations, args): |
|
|
loss = torch.FloatTensor([0.0]).cuda() |
|
|
criterion = nn.L1Loss() |
|
|
if args.kldiv: |
|
|
loss += args.kldiv_coeff * kldiv(pred_map, gt) |
|
|
if args.cc: |
|
|
loss += args.cc_coeff * cc(pred_map, gt) |
|
|
if args.nss: |
|
|
loss += args.nss_coeff * nss(pred_map, fixations) |
|
|
if args.l1: |
|
|
loss += args.l1_coeff * criterion(pred_map, gt) |
|
|
if args.sim: |
|
|
loss += args.sim_coeff * similarity(pred_map, gt) |
|
|
|
|
|
return loss |
|
|
|
|
|
def train(model, optimizer, loader, epoch, device, args): |
|
|
model.train() |
|
|
|
|
|
tic = time.time() |
|
|
|
|
|
total_loss = 0.0 |
|
|
cur_loss = 0.0 |
|
|
|
|
|
for idx, (img, gt, fixations) in enumerate(loader): |
|
|
img = img.to(device) |
|
|
gt = gt.to(device) |
|
|
fixations = fixations.to(device) |
|
|
|
|
|
optimizer.zero_grad() |
|
|
pred_map, vol_pred = model(img) |
|
|
|
|
|
assert pred_map.size() == gt.size() |
|
|
|
|
|
loss = loss_func(pred_map, gt, fixations, args) |
|
|
loss.backward() |
|
|
|
|
|
total_loss += loss.item() |
|
|
cur_loss += loss.item() |
|
|
|
|
|
optimizer.step() |
|
|
if idx%args.log_interval==(args.log_interval-1): |
|
|
print('[{:2d}, {:5d}] avg_loss : {:.5f}, time:{:3f} minutes'.format(epoch, idx, cur_loss/args.log_interval, (time.time()-tic)/60)) |
|
|
wandb.log({"loss": cur_loss/args.log_interval}) |
|
|
cur_loss = 0.0 |
|
|
sys.stdout.flush() |
|
|
|
|
|
print('[{:2d}, train] avg_loss : {:.5f}'.format(epoch, total_loss/len(loader))) |
|
|
sys.stdout.flush() |
|
|
|
|
|
return total_loss/len(loader) |
|
|
|
|
|
def validate(model, loader, epoch, device, args): |
|
|
model.eval() |
|
|
tic = time.time() |
|
|
cc_loss = AverageMeter() |
|
|
kldiv_loss = AverageMeter() |
|
|
nss_loss = AverageMeter() |
|
|
sim_loss = AverageMeter() |
|
|
|
|
|
for (img, gt, fixations) in tqdm(loader): |
|
|
img = img.to(device) |
|
|
gt = gt.to(device) |
|
|
fixations = fixations.to(device) |
|
|
|
|
|
pred_map , vol_pred = model(img) |
|
|
|
|
|
cc_loss.update(cc(pred_map, gt)) |
|
|
kldiv_loss.update(kldiv(pred_map, gt)) |
|
|
nss_loss.update(nss(pred_map, fixations)) |
|
|
sim_loss.update(similarity(pred_map, gt)) |
|
|
|
|
|
print('[{:2d}, val] CC : {:.5f}, KLDIV : {:.5f}, NSS : {:.5f}, SIM : {:.5f} time:{:3f} minutes'.format(epoch, cc_loss.avg, kldiv_loss.avg, nss_loss.avg, sim_loss.avg, (time.time()-tic)/60)) |
|
|
wandb.log({"CC": cc_loss.avg, 'KLDIV': kldiv_loss.avg, 'NSS': nss_loss.avg, 'SIM': sim_loss.avg}) |
|
|
sys.stdout.flush() |
|
|
|
|
|
return cc_loss.avg,cc_loss,kldiv_loss,nss_loss,sim_loss |
|
|
|
|
|
params = list(filter(lambda p: p.requires_grad, model.parameters())) |
|
|
|
|
|
if args.optim=="Adam": |
|
|
optimizer = torch.optim.Adam(params, lr=args.lr) |
|
|
if args.optim=="Adagrad": |
|
|
optimizer = torch.optim.Adagrad(params, lr=args.lr) |
|
|
if args.optim=="SGD": |
|
|
optimizer = torch.optim.SGD(params, lr=args.lr, momentum=0.9) |
|
|
if args.lr_sched: |
|
|
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.1) |
|
|
|
|
|
print(device) |
|
|
best_loss = 0 |
|
|
for epoch in range(0, args.no_epochs): |
|
|
loss = train(model, optimizer, train_loader, epoch, device, args) |
|
|
|
|
|
with torch.no_grad(): |
|
|
cc_loss,cc_loss_obj,kldiv_loss,nss_loss,sim_loss = validate(model, val_loader, epoch, device, args) |
|
|
cc_loss -=kldiv_loss.avg |
|
|
if epoch == 0 : |
|
|
best_loss = cc_loss |
|
|
if best_loss <= cc_loss: |
|
|
best_loss = cc_loss |
|
|
print('[{:2d}, save, {}]'.format(epoch, args.model_val_path)) |
|
|
wandb.log({"Best/CC mean": cc_loss,"Best/CC median": cc_loss_obj.get_median(), "Best/CC std": cc_loss_obj.get_std(), |
|
|
"Best/KLD mean": kldiv_loss.avg,"Best/KLD median": kldiv_loss.get_median(), "Best/KLD std": kldiv_loss.get_std(), |
|
|
"Best/NSS mean": nss_loss.avg,"Best/NSS median": nss_loss.get_median(), "Best/NSS std": nss_loss.get_std(), |
|
|
"Best/SIM mean": sim_loss.avg,"Best/SIM median": sim_loss.get_median(), "Best/SIM std": sim_loss.get_std()}) |
|
|
if torch.cuda.device_count() > 1: |
|
|
torch.save(model.module.state_dict(), args.model_val_path) |
|
|
else: |
|
|
torch.save(model.state_dict(), args.model_val_path) |
|
|
print() |
|
|
|
|
|
if args.lr_sched: |
|
|
scheduler.step() |
|
|
|