# Image-quality metrics: computes FID, P-IDS, and U-IDS between two folders of images.
import glob
import math
import os
import pickle
import sys

sys.path.insert(0, '../')

import cv2
import numpy as np
import PIL.Image
import pyspng
import scipy.linalg
import sklearn.svm
import torch

import dnnlib
| _feature_detector_cache = dict() | |
| def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False): | |
| assert 0 <= rank < num_gpus | |
| key = (url, device) | |
| if key not in _feature_detector_cache: | |
| is_leader = (rank == 0) | |
| if not is_leader and num_gpus > 1: | |
| torch.distributed.barrier() # leader goes first | |
| with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f: | |
| _feature_detector_cache[key] = torch.jit.load(f).eval().to(device) | |
| if is_leader and num_gpus > 1: | |
| torch.distributed.barrier() # others follow | |
| return _feature_detector_cache[key] | |
| def read_image(image_path): | |
| with open(image_path, 'rb') as f: | |
| if pyspng is not None and image_path.endswith('.png'): | |
| image = pyspng.load(f.read()) | |
| else: | |
| image = np.array(PIL.Image.open(f)) | |
| if image.ndim == 2: | |
| image = image[:, :, np.newaxis] # HW => HWC | |
| if image.shape[2] == 1: | |
| image = np.repeat(image, 3, axis=2) | |
| image = image.transpose(2, 0, 1) # HWC => CHW | |
| image = torch.from_numpy(image).unsqueeze(0).to(torch.uint8) | |
| return image | |
class FeatureStats:
    """Accumulates batches of 2D feature arrays and tracks the raw rows,
    their running mean/covariance, or both.

    Used to collect detector activations for FID-style metrics.
    """

    def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None):
        # capture_all:      keep every appended row (required by get_all()).
        # capture_mean_cov: accumulate running sums for get_mean_cov().
        # max_items:        optional hard cap on the number of rows kept.
        self.capture_all = capture_all
        self.capture_mean_cov = capture_mean_cov
        self.max_items = max_items
        self.num_items = 0
        self.num_features = None
        self.all_features = None
        self.raw_mean = None
        self.raw_cov = None

    def set_num_features(self, num_features):
        """Fix the feature dimensionality on first use; later calls must match."""
        if self.num_features is not None:
            assert num_features == self.num_features
        else:
            self.num_features = num_features
            self.all_features = []
            self.raw_mean = np.zeros([num_features], dtype=np.float64)
            self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)

    def is_full(self):
        """True once max_items rows are stored (never true when max_items is None)."""
        return (self.max_items is not None) and (self.num_items >= self.max_items)

    def append(self, x):
        """Append a (rows, num_features) array, truncating at max_items."""
        x = np.asarray(x, dtype=np.float32)
        assert x.ndim == 2
        if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
            if self.num_items >= self.max_items:
                return
            x = x[:self.max_items - self.num_items]
        self.set_num_features(x.shape[1])
        self.num_items += x.shape[0]
        if self.capture_all:
            self.all_features.append(x)
        if self.capture_mean_cov:
            # Accumulate in float64 for numerical stability of mean/covariance.
            x64 = x.astype(np.float64)
            self.raw_mean += x64.sum(axis=0)
            self.raw_cov += x64.T @ x64

    def append_torch(self, x, num_gpus=1, rank=0):
        """Append a 2D torch tensor; with num_gpus > 1, gather every rank's
        batch via broadcast and interleave the samples so all ranks
        accumulate identical statistics."""
        assert isinstance(x, torch.Tensor) and x.ndim == 2
        assert 0 <= rank < num_gpus
        if num_gpus > 1:
            ys = []
            for src in range(num_gpus):
                y = x.clone()
                torch.distributed.broadcast(y, src=src)
                ys.append(y)
            x = torch.stack(ys, dim=1).flatten(0, 1)  # interleave samples
        self.append(x.cpu().numpy())

    def get_all(self):
        """Return all captured rows as one (num_items, num_features) array."""
        assert self.capture_all
        return np.concatenate(self.all_features, axis=0)

    def get_all_torch(self):
        """Same as get_all(), as a torch tensor."""
        return torch.from_numpy(self.get_all())

    def get_mean_cov(self):
        """Return the per-feature mean vector and full covariance matrix."""
        assert self.capture_mean_cov
        mean = self.raw_mean / self.num_items
        cov = self.raw_cov / self.num_items
        cov = cov - np.outer(mean, mean)  # E[xxT] - E[x]E[x]T
        return mean, cov

    def save(self, pkl_file):
        """Pickle the full internal state to pkl_file."""
        with open(pkl_file, 'wb') as f:
            pickle.dump(self.__dict__, f)

    @staticmethod
    def load(pkl_file):
        """Restore a FeatureStats previously written by save().

        Fix: declared @staticmethod — it takes no self, and as a plain
        function it would break if ever called through an instance.
        """
        with open(pkl_file, 'rb') as f:
            s = dnnlib.EasyDict(pickle.load(f))
        obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items)
        obj.__dict__.update(s)
        return obj
def calculate_metrics(folder1, folder2):
    """Compute (fid, pids, uids) between two folders of images.

    folder1: generated/inpainted images; folder2: ground-truth images.
    Files are paired by sorted order and must match by basename.
    NOTE(review): requires a CUDA device ('cuda:0' is hard-coded below) and
    downloads the Inception detector from NVIDIA's CDN on first use.
    """
    # Collect and pair the two folders' images by sorted filename.
    l1 = sorted(glob.glob(folder1 + '/*.png') + glob.glob(folder1 + '/*.jpg'))
    l2 = sorted(glob.glob(folder2 + '/*.png') + glob.glob(folder2 + '/*.jpg'))
    assert(len(l1) == len(l2))
    print('length:', len(l1))
    # l1 = l1[:3]; l2 = l2[:3];
    # build detector
    detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
    detector_kwargs = dict(return_features=True)  # Return raw features before the softmax layer.
    device = torch.device('cuda:0')
    detector = get_feature_detector(url=detector_url, device=device, num_gpus=1, rank=0, verbose=False)
    detector.eval()
    # One stats accumulator per folder; capture_all is needed for the SVM
    # metrics below, capture_mean_cov for FID.
    stat1 = FeatureStats(capture_all=True, capture_mean_cov=True, max_items=len(l1))
    stat2 = FeatureStats(capture_all=True, capture_mean_cov=True, max_items=len(l1))
    with torch.no_grad():
        for i, (fpath1, fpath2) in enumerate(zip(l1, l2)):
            print(i)
            _, name1 = os.path.split(fpath1)
            _, name2 = os.path.split(fpath2)
            name1 = name1.split('.')[0]
            name2 = name2.split('.')[0]
            # Guard against misaligned pairs: basenames must match exactly.
            assert name1 == name2, 'Illegal mapping: %s, %s' % (name1, name2)
            img1 = read_image(fpath1).to(device)
            img2 = read_image(fpath2).to(device)
            assert img1.shape == img2.shape, 'Illegal shape'
            # Extract detector features one image at a time (batch size 1).
            fea1 = detector(img1, **detector_kwargs)
            stat1.append_torch(fea1, num_gpus=1, rank=0)
            fea2 = detector(img2, **detector_kwargs)
            stat2.append_torch(fea2, num_gpus=1, rank=0)
    # calculate fid
    # Frechet distance: ||mu1-mu2||^2 + Tr(s1 + s2 - 2*sqrt(s1*s2)).
    mu1, sigma1 = stat1.get_mean_cov()
    mu2, sigma2 = stat2.get_mean_cov()
    m = np.square(mu1 - mu2).sum()
    s, _ = scipy.linalg.sqrtm(np.dot(sigma1, sigma2), disp=False)  # pylint: disable=no-member
    # np.real drops tiny imaginary parts introduced by sqrtm's numerics.
    fid = np.real(m + np.trace(sigma1 + sigma2 - s * 2))
    # calculate pids and uids
    # NOTE(review): P-IDS/U-IDS appear to follow the paired/unpaired
    # inception discriminative scores from the CM-GAN/co-mod-GAN line of
    # work (a linear SVM separating real vs. fake features) — confirm.
    fake_activations = stat1.get_all()
    real_activations = stat2.get_all()
    svm = sklearn.svm.LinearSVC(dual=False)
    svm_inputs = np.concatenate([real_activations, fake_activations])
    # Targets: 1 = real (folder2), 0 = fake (folder1).
    svm_targets = np.array([1] * real_activations.shape[0] + [0] * fake_activations.shape[0])
    print('SVM fitting ...')
    svm.fit(svm_inputs, svm_targets)
    # U-IDS: the SVM's training error rate (1 - accuracy).
    uids = 1 - svm.score(svm_inputs, svm_targets)
    real_outputs = svm.decision_function(real_activations)
    fake_outputs = svm.decision_function(fake_activations)
    # P-IDS: fraction of pairs where the fake scores more "real" than its
    # paired ground truth (element-wise comparison relies on paired order).
    pids = np.mean(fake_outputs > real_outputs)
    return fid, pids, uids
if __name__ == '__main__':
    # Replace these placeholders with the actual result / ground-truth folders.
    folder1 = 'path to the inpainted result'
    folder2 = 'path to the gt'
    fid, pids, uids = calculate_metrics(folder1, folder2)
    # Format once, then both display and persist the same summary line.
    summary = 'fid: %.4f, pids: %.4f, uids: %.4f' % (fid, pids, uids)
    print(summary)
    with open('fid_pids_uids.txt', 'w') as f:
        f.write(summary)