Spaces:
Runtime error
Runtime error
| from __future__ import absolute_import | |
| import numpy as np | |
| import torch | |
| from collections import OrderedDict | |
| from torch.autograd import Variable | |
| from scipy.ndimage import zoom | |
| from tqdm import tqdm | |
| import os | |
| from .lpips import LPIPS, L2, DSSIM, BCERankingLoss | |
| from .utils import tensor2im, voc_ap | |
class Trainer:
    """Build, evaluate, and optionally train a perceptual image-distance model.

    ``initialize`` must be called before any other method: it creates
    ``self.net`` and, in training mode, the ranking loss and Adam optimizer.

    NOTE: ``load_network`` and ``save_done`` read ``self.save_dir`` and
    ``get_image_paths`` reads ``self.image_paths``. Neither attribute is
    assigned anywhere in this class, so the caller (or a subclass) has to set
    them before using those methods.
    """

    def name(self):
        """Return the human-readable model name set by ``initialize``."""
        return self.model_name

    def initialize(
        self,
        model="lpips",
        net="alex",
        colorspace="Lab",
        pnet_rand=False,
        pnet_tune=False,
        model_path=None,
        use_gpu=True,
        printNet=False,
        spatial=False,
        is_train=False,
        lr=0.0001,
        beta1=0.5,
        version="0.1",
        gpu_ids=None,
    ):
        """
        INPUTS
            model - ['lpips'] for linearly calibrated network
                    ['baseline'] for off-the-shelf network
                    ['L2'] for L2 distance in Lab colorspace
                    ['SSIM'] for ssim in RGB colorspace
            net - ['squeeze','alex','vgg']
            model_path - if None, will look in weights/[NET_NAME].pth
            colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
            use_gpu - bool - whether or not to use a GPU
            printNet - bool - whether or not to print network architecture out
            spatial - bool - whether to output an array containing varying distances across spatial dimensions
            is_train - bool - [True] for training mode
            lr - float - initial learning rate
            beta1 - float - initial momentum term for adam
            version - 0.1 for latest, 0.0 was original (with a bug)
            gpu_ids - int array - gpus to use; None (the default) means [0]
        RAISES
            ValueError - if `model` is not one of the recognized names
        """
        # A list literal as the default would be one object shared by every
        # call (and stored on every instance), so build it per call instead.
        if gpu_ids is None:
            gpu_ids = [0]

        self.use_gpu = use_gpu
        self.gpu_ids = gpu_ids
        self.model = model
        self.net = net  # backbone name for now; replaced by the module below
        self.is_train = is_train
        self.spatial = spatial
        self.model_name = "%s [%s]" % (model, net)

        if self.model == "lpips":  # pretrained net + linear layer
            self.net = LPIPS(
                pretrained=not is_train,
                net=net,
                version=version,
                lpips=True,
                spatial=spatial,
                pnet_rand=pnet_rand,
                pnet_tune=pnet_tune,
                use_dropout=True,
                model_path=model_path,
                eval_mode=False,
            )
        elif self.model == "baseline":  # pretrained network
            self.net = LPIPS(pnet_rand=pnet_rand, net=net, lpips=False)
        elif self.model in ["L2", "l2"]:
            self.net = L2(
                use_gpu=use_gpu, colorspace=colorspace
            )  # not really a network, only for testing
            self.model_name = "L2"
        elif self.model in ["DSSIM", "dssim", "SSIM", "ssim"]:
            self.net = DSSIM(use_gpu=use_gpu, colorspace=colorspace)
            self.model_name = "SSIM"
        else:
            raise ValueError("Model [%s] not recognized." % self.model)

        self.parameters = list(self.net.parameters())

        if self.is_train:  # training mode
            # extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
            self.rankLoss = BCERankingLoss()
            self.parameters += list(self.rankLoss.net.parameters())
            self.lr = lr
            self.old_lr = lr
            self.optimizer_net = torch.optim.Adam(
                self.parameters, lr=lr, betas=(beta1, 0.999)
            )
        else:  # test mode
            self.net.eval()

        if use_gpu:
            self.net.to(gpu_ids[0])
            self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
            if self.is_train:
                self.rankLoss = self.rankLoss.to(
                    device=gpu_ids[0]
                )  # just put this on GPU0

        if printNet:
            # The flag is documented as printing the architecture; honor it.
            num_params = sum(param.numel() for param in self.net.parameters())
            print("---------- Networks initialized -------------")
            print(self.net)
            print("Total number of parameters: %d" % num_params)
            print("-----------------------------------------------")

    def forward(self, in0, in1, retPerLayer=False):
        """Function computes the distance between image patches in0 and in1
        INPUTS
            in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
            retPerLayer - forwarded unchanged to the underlying network
        OUTPUT
            computed distances between in0 and in1
        """
        return self.net.forward(in0, in1, retPerLayer=retPerLayer)

    # ***** TRAINING FUNCTIONS *****
    def optimize_parameters(self):
        """Run one training step on the batch stored by ``set_input``."""
        self.forward_train()
        self.optimizer_net.zero_grad()
        self.backward_train()
        self.optimizer_net.step()
        self.clamp_weights()

    def clamp_weights(self):
        """Clamp the weights of every 1x1-kernel layer of ``self.net`` to be >= 0."""
        for module in self.net.modules():
            # Some modules own a `weight` but no `kernel_size` (e.g. Linear or
            # normalization layers); look the attribute up defensively so they
            # are skipped instead of raising AttributeError.
            if (
                hasattr(module, "weight")
                and getattr(module, "kernel_size", None) == (1, 1)
            ):
                module.weight.data = torch.clamp(module.weight.data, min=0)

    def set_input(self, data):
        """Store one training batch.

        INPUTS
            data - mapping with keys 'ref', 'p0', 'p1' and 'judge'
        """
        self.input_ref = data["ref"]
        self.input_p0 = data["p0"]
        self.input_p1 = data["p1"]
        self.input_judge = data["judge"]

        if self.use_gpu:
            self.input_ref = self.input_ref.to(device=self.gpu_ids[0])
            self.input_p0 = self.input_p0.to(device=self.gpu_ids[0])
            self.input_p1 = self.input_p1.to(device=self.gpu_ids[0])
            self.input_judge = self.input_judge.to(device=self.gpu_ids[0])

        self.var_ref = Variable(self.input_ref, requires_grad=True)
        self.var_p0 = Variable(self.input_p0, requires_grad=True)
        self.var_p1 = Variable(self.input_p1, requires_grad=True)

    def forward_train(self):  # run forward pass
        """Compute both distances, the per-sample accuracy and the ranking loss.

        OUTPUT
            the ranking loss, also stored as ``self.loss_total``
        """
        self.d0 = self.forward(self.var_ref, self.var_p0)
        self.d1 = self.forward(self.var_ref, self.var_p1)
        self.acc_r = self.compute_accuracy(self.d0, self.d1, self.input_judge)

        self.var_judge = Variable(1.0 * self.input_judge).view(self.d0.size())

        # The judgment is rescaled linearly (0 -> -1, 1 -> +1) before it is
        # handed to the ranking loss.
        self.loss_total = self.rankLoss.forward(
            self.d0, self.d1, self.var_judge * 2.0 - 1.0
        )
        return self.loss_total

    def backward_train(self):
        """Back-propagate the mean of ``self.loss_total``."""
        torch.mean(self.loss_total).backward()

    def compute_accuracy(self, d0, d1, judge):
        """d0, d1 are Variables, judge is a Tensor

        OUTPUT
            flat numpy array; each entry is `judge` where d1 < d0 and
            `1 - judge` otherwise
        """
        d1_lt_d0 = (d1 < d0).cpu().data.numpy().flatten()
        judge_per = judge.cpu().numpy().flatten()
        return d1_lt_d0 * judge_per + (1 - d1_lt_d0) * (1 - judge_per)

    def get_current_errors(self):
        """Return an OrderedDict with the batch means of 'loss_total' and 'acc_r'."""
        return OrderedDict(
            [
                ("loss_total", np.mean(self.loss_total.data.cpu().numpy())),
                ("acc_r", np.mean(self.acc_r)),
            ]
        )

    def get_current_visuals(self):
        """Return 'ref', 'p0' and 'p1' images of the current batch, rescaled for display."""
        # Scale factor that maps dimension 2 of the reference tensor to 256.
        zoom_factor = 256 / self.var_ref.data.size()[2]

        ref_img = tensor2im(self.var_ref.data)
        p0_img = tensor2im(self.var_p0.data)
        p1_img = tensor2im(self.var_p1.data)

        # order=0 selects nearest-neighbour interpolation; the last axis is
        # left unscaled.
        ref_img_vis = zoom(ref_img, [zoom_factor, zoom_factor, 1], order=0)
        p0_img_vis = zoom(p0_img, [zoom_factor, zoom_factor, 1], order=0)
        p1_img_vis = zoom(p1_img, [zoom_factor, zoom_factor, 1], order=0)

        return OrderedDict(
            [("ref", ref_img_vis), ("p0", p0_img_vis), ("p1", p1_img_vis)]
        )

    def save(self, path, label):
        """Save the distance network and the ranking network under `path`."""
        if self.use_gpu:
            # On GPU the network is wrapped in DataParallel; save the inner module.
            self.save_network(self.net.module, path, "", label)
        else:
            self.save_network(self.net, path, "", label)
        self.save_network(self.rankLoss.net, path, "rank", label)

    # helper saving function that can be used by subclasses
    def save_network(self, network, path, network_label, epoch_label):
        """Write `network.state_dict()` to `<path>/<epoch_label>_net_<network_label>.pth`."""
        save_filename = "%s_net_%s.pth" % (epoch_label, network_label)
        save_path = os.path.join(path, save_filename)
        torch.save(network.state_dict(), save_path)

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label):
        """Load `<save_dir>/<epoch_label>_net_<network_label>.pth` into `network`.

        NOTE: relies on ``self.save_dir``, which this class never assigns.
        """
        save_filename = "%s_net_%s.pth" % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        print("Loading network from %s" % save_path)
        network.load_state_dict(torch.load(save_path))

    def update_learning_rate(self, nepoch_decay):
        """Lower the optimizer's learning rate by `self.lr / nepoch_decay`."""
        lrd = self.lr / nepoch_decay
        lr = self.old_lr - lrd

        for param_group in self.optimizer_net.param_groups:
            param_group["lr"] = lr

        # Report the model name; the builtin `type` used to be formatted here,
        # which printed "<class 'type'>".
        print(
            "update lr [%s] decay: %f -> %f" % (self.model_name, self.old_lr, lr)
        )
        self.old_lr = lr

    def get_image_paths(self):
        """Return ``self.image_paths`` (must have been set by the caller)."""
        return self.image_paths

    def save_done(self, flag=False):
        """Record `flag` in ``self.save_dir`` in both binary and text form.

        NOTE: relies on ``self.save_dir``, which this class never assigns.
        """
        # np.save appends ".npy" to the name, so the two writes below go to
        # different files ("done_flag.npy" and "done_flag").
        np.save(os.path.join(self.save_dir, "done_flag"), flag)
        np.savetxt(
            os.path.join(self.save_dir, "done_flag"),
            [
                flag,
            ],
            fmt="%i",
        )
def score_2afc_dataset(data_loader, func, name=""):
    """Score distance function 'func' against human Two Alternative Forced
    Choice (2AFC) judgments collected in dataset 'data_loader'
    INPUTS
        data_loader - object whose load_data() yields batches (mappings) with
            keys 'ref', 'p0', 'p1' and 'judge'
        func - callable distance function - d=func(in0,in1) takes 2 pytorch
            tensors and returns a pytorch tensor of distances (it is converted
            with .data.cpu().numpy() and flattened here)
        name - label shown on the progress bar
    OUTPUTS
        [0] - 2AFC score, mean of 'scores' below
        [1] - dictionary with following elements
            d0s,d1s - N arrays containing distances between reference patch to perturbed patches
            gts - N array of human judgments
                (closer to "0" for left patch p0, "1" for right patch p1,
                "0.6" means 60pct people preferred right patch, 40pct preferred left)
            scores - N array, per-triplet agreement between func and humans;
                a tie (d0 == d1) counts as 0.5
    CONSTS
        N - number of test triplets in data_loader
    """
    dists_p0 = []
    dists_p1 = []
    judgments = []

    for batch in tqdm(data_loader.load_data(), desc=name):
        out_p0 = func(batch["ref"], batch["p0"])
        dists_p0.extend(out_p0.data.cpu().numpy().flatten().tolist())

        out_p1 = func(batch["ref"], batch["p1"])
        dists_p1.extend(out_p1.data.cpu().numpy().flatten().tolist())

        judgments.extend(batch["judge"].cpu().numpy().flatten().tolist())

    d0s = np.array(dists_p0)
    d1s = np.array(dists_p1)
    gts = np.array(judgments)

    # Credit the share of humans that picked the patch func considers closer.
    p0_closer = d0s < d1s
    p1_closer = d1s < d0s
    tied = d1s == d0s
    scores = p0_closer * (1.0 - gts) + p1_closer * gts + tied * 0.5

    details = dict(d0s=d0s, d1s=d1s, gts=gts, scores=scores)
    return (np.mean(scores), details)
def score_jnd_dataset(data_loader, func, name=""):
    """Function computes JND score using distance function 'func' in dataset 'data_loader'
    INPUTS
        data_loader - object whose load_data() yields batches (mappings) with
            keys 'p0', 'p1' and 'same'
        func - callable distance function - calling d=func(in0,in1) should take 2
            pytorch tensors with shape Nx3xXxY, and return a pytorch tensor
            holding N distances (any shape; it is flattened here)
        name - label shown on the progress bar
    OUTPUTS
        [0] - JND score in [0,1], mAP score (area under precision-recall curve)
        [1] - dictionary with following elements
            ds - N array containing distances between two patches shown to human evaluator
            sames - N array containing fraction of people who thought the two patches were identical
    CONSTS
        N - number of test pairs in data_loader
    """
    ds = []
    gts = []

    for data in tqdm(data_loader.load_data(), desc=name):
        # Flatten the distances exactly like the labels below: an output shaped
        # e.g. Nx1x1x1 would otherwise make `ds` multi-dimensional and break
        # the argsort-based ranking.
        ds += func(data["p0"], data["p1"]).data.cpu().numpy().flatten().tolist()
        gts += data["same"].cpu().numpy().flatten().tolist()

    sames = np.array(gts)
    ds = np.array(ds)

    # Rank pairs by ascending distance; at each cut-off the pairs seen so far
    # are the ones predicted "same".
    sorted_inds = np.argsort(ds)
    sames_sorted = sames[sorted_inds]

    TPs = np.cumsum(sames_sorted)
    FPs = np.cumsum(1 - sames_sorted)
    FNs = np.sum(sames_sorted) - TPs

    precs = TPs / (TPs + FPs)
    recs = TPs / (TPs + FNs)
    score = voc_ap(recs, precs)

    return (score, dict(ds=ds, sames=sames))