Spaces:
Runtime error
Runtime error
| # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # NVIDIA CORPORATION and its licensors retain all intellectual property | |
| # and proprietary rights in and to this software, related documentation | |
| # and any modifications thereto. Any use, reproduction, disclosure or | |
| # distribution of this software and related documentation without an express | |
| # license agreement from NVIDIA CORPORATION is strictly prohibited. | |
| import os | |
| import time | |
| import hashlib | |
| import pickle | |
| import copy | |
| import uuid | |
| import numpy as np | |
| import torch | |
| import dnnlib | |
| import math | |
| import cv2 | |
| #---------------------------------------------------------------------------- | |
class MetricOptions:
    """Bundle of settings shared by the metric computations in this module.

    Args:
        G: Generator network to evaluate (may be None when only dataset
            statistics are needed).
        G_kwargs: Extra keyword arguments forwarded to ``G`` on every call.
        dataset_kwargs: Keyword arguments used to construct the dataset.
        num_gpus: Total number of participating processes/GPUs.
        rank: Index of the current process; must satisfy ``0 <= rank < num_gpus``.
        device: Device to run on; defaults to ``torch.device('cuda', rank)``.
        progress: Optional parent ``ProgressMonitor``. Only rank 0 derives a
            sub-monitor from it; every other rank gets a fresh default
            ``ProgressMonitor()``.
        cache: Whether dataset statistics may be read from / written to the
            on-disk cache.
    """

    # NOTE: the `{}` defaults are kept for interface compatibility. Within this
    # block they are only passed to the EasyDict constructor and never mutated.
    def __init__(self, G=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1, rank=0, device=None, progress=None, cache=True):
        assert 0 <= rank < num_gpus
        self.G = G
        self.G_kwargs = dnnlib.EasyDict(G_kwargs)
        self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs)
        self.num_gpus = num_gpus
        self.rank = rank
        self.device = device if device is not None else torch.device('cuda', rank)
        if progress is not None and rank == 0:
            self.progress = progress.sub()
        else:
            self.progress = ProgressMonitor()
        self.cache = cache
| #---------------------------------------------------------------------------- | |
# Process-wide cache of loaded feature detectors, keyed by (url, device).
_feature_detector_cache = dict()


def get_feature_detector_name(url):
    """Return the file name at the end of *url* with its extension removed.

    Example: ``'https://host/dir/inception-2015-12-05.pt'`` gives
    ``'inception-2015-12-05'``.
    """
    file_name = url.split('/')[-1]
    return os.path.splitext(file_name)[0]
def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False):
    """Load the TorchScript feature detector at *url*, caching it per (url, device).

    In multi-process runs (``num_gpus > 1``) the calls are ordered with two
    barriers so that rank 0 (the leader) opens the URL first and the other
    ranks open it only after the leader has finished.

    Args:
        url: Location of the TorchScript module, opened via ``dnnlib.util.open_url``.
        device: Device the loaded module is moved to.
        num_gpus: Total number of participating processes.
        rank: Index of the current process; must satisfy ``0 <= rank < num_gpus``.
        verbose: Forwarded to ``open_url`` on the leader only.

    Returns:
        The loaded module in eval mode, placed on *device*.
    """
    assert 0 <= rank < num_gpus
    key = (url, device)
    if key not in _feature_detector_cache:
        is_leader = (rank == 0)
        if not is_leader and num_gpus > 1:
            torch.distributed.barrier()  # leader goes first
        with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f:
            _feature_detector_cache[key] = torch.jit.load(f).eval().to(device)
        if is_leader and num_gpus > 1:
            torch.distributed.barrier()  # others follow
    return _feature_detector_cache[key]
| #---------------------------------------------------------------------------- | |
class FeatureStats:
    """Accumulator for 2-D batches of feature vectors.

    Depending on the flags it keeps every appended row (``capture_all``)
    and/or running sums from which the mean and covariance can be derived
    (``capture_mean_cov``). When ``max_items`` is set, rows beyond that count
    are silently dropped.
    """

    def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None):
        self.capture_all = capture_all
        self.capture_mean_cov = capture_mean_cov
        self.max_items = max_items
        self.num_items = 0
        # The fields below are allocated lazily by set_num_features().
        self.num_features = None
        self.all_features = None
        self.raw_mean = None
        self.raw_cov = None

    def set_num_features(self, num_features):
        """Fix the feature dimensionality on first use; later calls must agree."""
        if self.num_features is not None:
            assert num_features == self.num_features
        else:
            self.num_features = num_features
            self.all_features = []
            self.raw_mean = np.zeros([num_features], dtype=np.float64)
            self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)

    def is_full(self):
        """Return True once ``max_items`` rows have been accumulated."""
        return (self.max_items is not None) and (self.num_items >= self.max_items)

    def append(self, x):
        """Append a 2-D array of shape (num_rows, num_features).

        Rows that would exceed ``max_items`` are truncated; if the accumulator
        is already full the call is a no-op.
        """
        x = np.asarray(x, dtype=np.float32)
        assert x.ndim == 2
        if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
            if self.num_items >= self.max_items:
                return
            x = x[:self.max_items - self.num_items]

        self.set_num_features(x.shape[1])
        self.num_items += x.shape[0]
        if self.capture_all:
            self.all_features.append(x)
        if self.capture_mean_cov:
            # Accumulate in float64 to limit rounding error over many batches.
            x64 = x.astype(np.float64)
            self.raw_mean += x64.sum(axis=0)
            self.raw_cov += x64.T @ x64

    def append_torch(self, x, num_gpus=1, rank=0):
        """Append a 2-D tensor, first gathering it from all ranks if ``num_gpus > 1``.

        Each rank broadcasts its own batch in turn; the gathered batches are
        interleaved row by row so every rank appends the same combined data.
        """
        assert isinstance(x, torch.Tensor) and x.ndim == 2
        assert 0 <= rank < num_gpus
        if num_gpus > 1:
            ys = []
            for src in range(num_gpus):
                y = x.clone()
                torch.distributed.broadcast(y, src=src)
                ys.append(y)
            x = torch.stack(ys, dim=1).flatten(0, 1)  # interleave samples
        self.append(x.cpu().numpy())

    def get_all(self):
        """Return all captured rows as one array (requires ``capture_all``)."""
        assert self.capture_all
        return np.concatenate(self.all_features, axis=0)

    def get_all_torch(self):
        """Return all captured rows as a torch tensor."""
        return torch.from_numpy(self.get_all())

    def get_mean_cov(self):
        """Return ``(mean, cov)`` of the appended rows (requires ``capture_mean_cov``).

        The covariance is the population form: E[x x^T] - mean mean^T.
        """
        assert self.capture_mean_cov
        mean = self.raw_mean / self.num_items
        cov = self.raw_cov / self.num_items
        cov = cov - np.outer(mean, mean)
        return mean, cov

    def save(self, pkl_file):
        """Pickle the complete instance state to *pkl_file*."""
        with open(pkl_file, 'wb') as f:
            pickle.dump(self.__dict__, f)

    @staticmethod
    def load(pkl_file):
        """Rebuild a ``FeatureStats`` from a file previously written by ``save``.

        Declared as a static method so it works both as ``FeatureStats.load(p)``
        and when invoked through an instance.
        """
        # NOTE(security): pickle.load can execute arbitrary code. Only load
        # files produced by save() from a trusted location.
        with open(pkl_file, 'rb') as f:
            state = pickle.load(f)
        obj = FeatureStats(capture_all=state['capture_all'], max_items=state['max_items'])
        obj.__dict__.update(state)
        return obj
| #---------------------------------------------------------------------------- | |
class ProgressMonitor:
    """Reports progress of a long-running loop via printing and/or a callback.

    Args:
        tag: Label printed in front of each verbose status line.
        num_items: Total number of items expected, or None if unknown.
        flush_interval: Minimum number of new items between two reports.
        verbose: Whether to print status lines (only when ``tag`` is set).
        progress_fn: Optional callback invoked as ``progress_fn(current, total)``.
        pfn_lo: Callback value that corresponds to 0% of this monitor.
        pfn_hi: Callback value that corresponds to 100% of this monitor.
        pfn_total: ``total`` value passed to the callback unchanged.
    """

    def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000):
        self.tag = tag
        self.num_items = num_items
        self.verbose = verbose
        self.flush_interval = flush_interval
        self.progress_fn = progress_fn
        self.pfn_lo = pfn_lo
        self.pfn_hi = pfn_hi
        self.pfn_total = pfn_total
        self.start_time = time.time()
        self.batch_time = self.start_time
        self.batch_items = 0
        # Report the starting position right away.
        if self.progress_fn is not None:
            self.progress_fn(self.pfn_lo, self.pfn_total)

    def update(self, cur_items):
        """Record that *cur_items* items are done and report if it is time to.

        A report is emitted when at least ``flush_interval`` items were added
        since the last report, or when ``cur_items`` reaches ``num_items``.
        """
        assert (self.num_items is None) or (cur_items <= self.num_items)
        if (cur_items < self.batch_items + self.flush_interval) and (self.num_items is None or cur_items < self.num_items):
            return

        cur_time = time.time()
        total_time = cur_time - self.start_time
        time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1)
        if (self.verbose) and (self.tag is not None):
            print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}')
        self.batch_time = cur_time
        self.batch_items = cur_items

        if (self.progress_fn is not None) and (self.num_items is not None):
            # With zero expected items there is nothing left to do, so report
            # completion instead of dividing by zero.
            fraction = (cur_items / self.num_items) if self.num_items else 1.0
            self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * fraction, self.pfn_total)

    def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1):
        """Create a child monitor covering the ``[rel_lo, rel_hi]`` slice of this one.

        The child shares this monitor's ``verbose`` flag, callback and total.
        """
        span = self.pfn_hi - self.pfn_lo
        return ProgressMonitor(
            tag=tag,
            num_items=num_items,
            flush_interval=flush_interval,
            verbose=self.verbose,
            progress_fn=self.progress_fn,
            pfn_lo=self.pfn_lo + span * rel_lo,
            pfn_hi=self.pfn_lo + span * rel_hi,
            pfn_total=self.pfn_total,
        )
| #---------------------------------------------------------------------------- | |
def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, data_loader_kwargs=None, max_items=None, **stats_kwargs):
    """Run the feature detector over the real dataset and accumulate statistics.

    Results are cached on disk when ``opts.cache`` is set. In multi-process
    runs rank 0 decides whether a cache file exists and broadcasts the
    decision so that all ranks take the same branch.

    Args:
        opts: ``MetricOptions`` describing dataset, device and process layout.
        detector_url: Location of the TorchScript feature detector.
        detector_kwargs: Keyword arguments passed to the detector on each call.
        rel_lo, rel_hi: Slice of the parent progress range used by this step.
        batch_size: DataLoader batch size.
        data_loader_kwargs: Extra DataLoader arguments; a default is used if None.
        max_items: Optional cap on the number of dataset items processed.
        **stats_kwargs: Forwarded to ``FeatureStats``.

    Returns:
        A ``FeatureStats`` instance (freshly computed or loaded from cache).
    """
    dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
    if data_loader_kwargs is None:
        data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)

    # Try to lookup from cache.
    cache_file = None
    if opts.cache:
        # Choose cache file name.
        args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs, stats_kwargs=stats_kwargs)
        # The item cap changes the result, so it must be part of the cache
        # key. It is only added when set, which keeps cache names for
        # uncapped runs unchanged.
        if max_items is not None:
            args['max_items'] = max_items
        md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8'))
        cache_tag = f'{dataset.name}-{get_feature_detector_name(detector_url)}-{md5.hexdigest()}'
        cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl')

        # Check if the file exists (all processes must agree).
        flag = os.path.isfile(cache_file) if opts.rank == 0 else False
        if opts.num_gpus > 1:
            flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device)
            torch.distributed.broadcast(tensor=flag, src=0)
            flag = (float(flag.cpu()) != 0)

        # Load.
        if flag:
            return FeatureStats.load(cache_file)

    # Initialize.
    num_items = len(dataset)
    if max_items is not None:
        num_items = min(num_items, max_items)
    stats = FeatureStats(max_items=num_items, **stats_kwargs)
    progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi)
    detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)

    # Main loop. Each rank takes every num_gpus-th item; the modulo wraps the
    # tail so that all ranks process the same number of items.
    item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)]
    loader = torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs)
    # Adaptation to inpainting: the dataset yields (image, mask, label)
    # triples; only the images are needed here.
    for images, _masks, _labels in loader:
        if images.shape[1] == 1:
            images = images.repeat([1, 3, 1, 1])  # expand single-channel images to 3 channels
        features = detector(images.to(opts.device), **detector_kwargs)
        stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
        progress.update(stats.num_items)

    # Save to cache.
    if cache_file is not None and opts.rank == 0:
        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        temp_file = cache_file + '.' + uuid.uuid4().hex
        stats.save(temp_file)
        os.replace(temp_file, cache_file)  # atomic
    return stats
| #---------------------------------------------------------------------------- | |
def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, jit=False, data_loader_kwargs=None, **stats_kwargs):
    """Run the feature detector over generator outputs and accumulate statistics.

    The generator is conditioned on (image, mask) pairs drawn from the
    dataset, together with a random latent and a randomly sampled label.

    Args:
        opts: ``MetricOptions`` holding the generator, dataset settings,
            device and process layout.
        detector_url: Location of the TorchScript feature detector.
        detector_kwargs: Keyword arguments passed to the detector on each call.
        rel_lo, rel_hi: Slice of the parent progress range used by this step.
        batch_size: DataLoader batch size; must be a multiple of ``batch_gen``.
        batch_gen: Generator micro-batch size; defaults to ``min(batch_size, 4)``.
        jit: Accepted for interface compatibility; currently unused because
            tracing is disabled for the image/mask-conditioned generator.
        data_loader_kwargs: Extra DataLoader arguments; a default is used if None.
        **stats_kwargs: Forwarded to ``FeatureStats``; must include ``max_items``.

    Returns:
        A ``FeatureStats`` instance with the detector features.
    """
    if data_loader_kwargs is None:
        data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)
    if batch_gen is None:
        batch_gen = min(batch_size, 4)
    assert batch_size % batch_gen == 0

    # Setup generator and load labels.
    G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
    dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)

    # Image generation func.
    def run_generator(img_in, mask_in, z, c):
        img = G(img_in, mask_in, z, c, **opts.G_kwargs)
        # Map the generator output from [-1, 1] to uint8 [0, 255].
        return ((img + 1.0) * 127.5).clamp(0, 255).round().to(torch.uint8)

    # Initialize.
    stats = FeatureStats(**stats_kwargs)
    assert stats.max_items is not None
    progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)
    detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)

    # Main loop. Indices are additionally wrapped by the dataset length so
    # that requesting more items than the dataset holds reuses conditioning
    # pairs instead of indexing out of range.
    num_real = len(dataset)
    item_subset = [((i * opts.num_gpus + opts.rank) % stats.max_items) % num_real for i in range((stats.max_items - 1) // opts.num_gpus + 1)]
    loader = torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs)
    for imgs_batch, masks_batch, _labels_batch in loader:
        images = []
        # presumably the dataset yields images in [0, 255]; rescale to [-1, 1] -- TODO confirm
        imgs_gen = (imgs_batch.to(opts.device).to(torch.float32) / 127.5 - 1).split(batch_gen)
        masks_gen = masks_batch.to(opts.device).to(torch.float32).split(batch_gen)
        for img_in, mask_in in zip(imgs_gen, masks_gen):
            z = torch.randn([img_in.shape[0], G.z_dim], device=opts.device)
            c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(img_in.shape[0])]
            c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
            images.append(run_generator(img_in, mask_in, z, c))
        images = torch.cat(images)
        if images.shape[1] == 1:
            images = images.repeat([1, 3, 1, 1])  # expand single-channel images to 3 channels
        features = detector(images, **detector_kwargs)
        stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
        progress.update(stats.num_items)
    return stats
| #---------------------------------------------------------------------------- | |
def compute_image_stats_for_generator(opts, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, jit=False, data_loader_kwargs=None, **stats_kwargs):
    """Compute per-image PSNR, SSIM and L1 between generator outputs and reals.

    For every dataset item the generator is run on its (image, mask) pair and
    the result is compared with the original image. Each appended row of the
    returned statistics is ``[psnr, ssim, l1]``.

    Args:
        opts: ``MetricOptions`` holding the generator, dataset settings,
            device and process layout.
        rel_lo, rel_hi: Slice of the parent progress range used by this step.
        batch_size: DataLoader batch size; must be a multiple of ``batch_gen``.
        batch_gen: Generator micro-batch size; defaults to ``min(batch_size, 4)``.
        jit: Accepted for interface compatibility; unused.
        data_loader_kwargs: Extra DataLoader arguments; a default is used if None.
        **stats_kwargs: Forwarded to ``FeatureStats``; must include ``max_items``.

    Returns:
        A ``FeatureStats`` instance whose rows are ``[psnr, ssim, l1]``.
    """
    if data_loader_kwargs is None:
        data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)
    if batch_gen is None:
        batch_gen = min(batch_size, 4)
    assert batch_size % batch_gen == 0

    # Setup generator and load labels.
    G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
    dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)

    # Image generation func.
    def run_generator(img_in, mask_in, z, c):
        img = G(img_in, mask_in, z, c, **opts.G_kwargs)
        # Map the generator output from [-1, 1] to uint8 [0, 255].
        return ((img + 1.0) * 127.5).clamp(0, 255).round().to(torch.uint8)

    # Initialize.
    stats = FeatureStats(**stats_kwargs)
    assert stats.max_items is not None
    progress = opts.progress.sub(tag='generator images', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)

    # Main loop. Indices are additionally wrapped by the dataset length so
    # that requesting more items than the dataset holds reuses conditioning
    # pairs instead of indexing out of range.
    num_real = len(dataset)
    item_subset = [((i * opts.num_gpus + opts.rank) % stats.max_items) % num_real for i in range((stats.max_items - 1) // opts.num_gpus + 1)]
    loader = torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs)
    for imgs_batch, masks_batch, _labels_batch in loader:
        images = []
        # presumably the dataset yields images in [0, 255]; rescale to [-1, 1] -- TODO confirm
        imgs_gen = (imgs_batch.to(opts.device).to(torch.float32) / 127.5 - 1).split(batch_gen)
        masks_gen = masks_batch.to(opts.device).to(torch.float32).split(batch_gen)
        for img_in, mask_in in zip(imgs_gen, masks_gen):
            z = torch.randn([img_in.shape[0], G.z_dim], device=opts.device)
            c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(img_in.shape[0])]
            c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
            images.append(run_generator(img_in, mask_in, z, c))
        images = torch.cat(images)

        # Expand single-channel outputs to 3 channels, and do the same to the
        # reals so the shape check below cannot fail for grayscale data.
        reals = imgs_batch
        if images.shape[1] == 1:
            images = images.repeat([1, 3, 1, 1])
            if reals.shape[1] == 1:
                reals = reals.repeat([1, 3, 1, 1])
        assert reals.shape == images.shape

        # Move each batch to host memory once, as NHWC, instead of per item.
        reals_np = reals.cpu().numpy().transpose(0, 2, 3, 1)
        gens_np = images.cpu().numpy().transpose(0, 2, 3, 1)
        metrics = [
            [calculate_psnr(img_gen, img_real), calculate_ssim(img_gen, img_real), calculate_l1(img_gen, img_real)]
            for img_gen, img_real in zip(gens_np, reals_np)
        ]
        metrics = torch.from_numpy(np.array(metrics)).to(torch.float32).to(opts.device)
        stats.append_torch(metrics, num_gpus=opts.num_gpus, rank=opts.rank)
        progress.update(stats.num_items)
    return stats
def calculate_psnr(img1, img2):
    """Return the peak signal-to-noise ratio in dB between two images.

    Both images are expected in the range [0, 255]. Identical images give
    ``float('inf')``.
    """
    # Work in float64 so differences of unsigned integer inputs cannot wrap.
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))
def calculate_ssim(img1, img2):
    """Return the mean structural similarity (SSIM) between two images.

    Uses an 11x11 Gaussian window with sigma 1.5 and stabilising constants
    derived from a dynamic range of 255. Local statistics are taken on the
    filtered images with a 5-pixel border cropped away, and the SSIM map is
    averaged over everything that remains.
    """
    C1 = (0.01 * 255) ** 2
    C2 = (0.03 * 255) ** 2

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())  # separable 1-D kernel -> 2-D window

    def local_mean(img):
        # Gaussian-weighted local average; drop the 5-pixel border.
        return cv2.filter2D(img, -1, window)[5:-5, 5:-5]

    mu1 = local_mean(img1)
    mu2 = local_mean(img2)
    mu1_sq = mu1 ** 2
    mu2_sq = mu2 ** 2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = local_mean(img1 ** 2) - mu1_sq
    sigma2_sq = local_mean(img2 ** 2) - mu2_sq
    sigma12 = local_mean(img1 * img2) - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()
def calculate_l1(img1, img2):
    """Return the mean absolute difference of two images after scaling by 1/255."""
    scaled1 = img1.astype(np.float64) / 255.0
    scaled2 = img2.astype(np.float64) / 255.0
    return np.mean(np.abs(scaled1 - scaled2))
| # def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, jit=False, **stats_kwargs): | |
| # if batch_gen is None: | |
| # batch_gen = min(batch_size, 4) | |
| # assert batch_size % batch_gen == 0 | |
| # | |
| # # Setup generator and load labels. | |
| # G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device) | |
| # dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs) | |
| # | |
| # # Image generation func. | |
| # def run_generator(z, c): | |
| # img = G(z=z, c=c, **opts.G_kwargs) | |
| # img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8) | |
| # return img | |
| # | |
| # # JIT. | |
| # if jit: | |
| # z = torch.zeros([batch_gen, G.z_dim], device=opts.device) | |
| # c = torch.zeros([batch_gen, G.c_dim], device=opts.device) | |
| # run_generator = torch.jit.trace(run_generator, [z, c], check_trace=False) | |
| # | |
| # # Initialize. | |
| # stats = FeatureStats(**stats_kwargs) | |
| # assert stats.max_items is not None | |
| # progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi) | |
| # detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose) | |
| # | |
| # # Main loop. | |
| # while not stats.is_full(): | |
| # images = [] | |
| # for _i in range(batch_size // batch_gen): | |
| # z = torch.randn([batch_gen, G.z_dim], device=opts.device) | |
| # c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_gen)] | |
| # c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device) | |
| # images.append(run_generator(z, c)) | |
| # images = torch.cat(images) | |
| # if images.shape[1] == 1: | |
| # images = images.repeat([1, 3, 1, 1]) | |
| # features = detector(images, **detector_kwargs) | |
| # stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank) | |
| # progress.update(stats.num_items) | |
| # return stats | |
| # | |
| # #---------------------------------------------------------------------------- | |