Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import torch | |
| import tops | |
| from dp2 import utils | |
| from torch_fidelity.helpers import get_kwarg, vassert | |
| from torch_fidelity.defaults import DEFAULTS as PPL_DEFAULTS | |
| from torch_fidelity.utils import sample_random, batch_interp, create_sample_similarity | |
| from torchvision.transforms.functional import resize | |
def slerp(a, b, t):
    """Spherical linear interpolation between batches of latent points.

    Both endpoints are projected onto the unit hypersphere (along the last
    dimension), and the interpolated point is renormalized before return.

    NOTE(review): assumes a and b are neither parallel nor anti-parallel;
    otherwise the orthogonal component is zero and normalizing it yields NaN.
    """
    def _unit(v):
        return v / v.norm(dim=-1, keepdim=True)

    a = _unit(a)
    b = _unit(b)
    # Cosine of the angle between the (now unit-norm) endpoints.
    cos_theta = (a * b).sum(dim=-1, keepdim=True)
    # Angle swept when moving a fraction t of the way from a towards b.
    angle = t * torch.acos(cos_theta)
    # Unit direction orthogonal to a, inside the plane spanned by a and b.
    ortho = _unit(b - cos_theta * a)
    return _unit(a * torch.cos(angle) + ortho * torch.sin(angle))
def calculate_ppl(
        dataloader,
        generator,
        latent_space=None,
        data_len=None,
        upsample_size=None,
        **kwargs) -> dict:
    """Compute Perceptual Path Length (PPL) for a conditional generator.

    Inspired by
    https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py

    Args:
        dataloader: yields dict batches containing at least an "img" tensor;
            every key of the batch is forwarded to the generator as
            conditioning input.
        generator: model exposing `z_channels`, `latent_space`, `get_w`
            (for W-space) and a call returning a dict with an "img" entry.
        latent_space: "Z" or "W"; defaults to `generator.latent_space`.
        data_len: number of samples to evaluate; defaults to
            len(dataloader) * dataloader.batch_size.
        upsample_size: (height, width) images are upsampled to before the
            similarity measure. Required — asserted to have length 2.
        **kwargs: forwarded to `create_sample_similarity`; may contain
            "rng_seed".

    Returns:
        {"ppl/mean": float, "ppl/std": float}. On non-zero ranks both
        values are -1; only rank 0 aggregates the gathered distances.
    """
    if latent_space is None:
        latent_space = generator.latent_space
    assert latent_space in ["Z", "W"], f"Not supported latent space: {latent_space}"
    assert len(upsample_size) == 2
    epsilon = PPL_DEFAULTS["ppl_epsilon"]
    interp = PPL_DEFAULTS['ppl_z_interp_mode']
    similarity_name = PPL_DEFAULTS['ppl_sample_similarity']
    sample_similarity_resize = PPL_DEFAULTS['ppl_sample_similarity_resize']
    sample_similarity_dtype = PPL_DEFAULTS['ppl_sample_similarity_dtype']
    discard_percentile_lower = PPL_DEFAULTS['ppl_discard_percentile_lower']
    discard_percentile_higher = PPL_DEFAULTS['ppl_discard_percentile_higher']
    vassert(type(epsilon) is float and epsilon > 0, 'Epsilon must be a small positive floating point number')
    vassert(discard_percentile_lower is None or 0 < discard_percentile_lower < 100, 'Invalid percentile')
    vassert(discard_percentile_higher is None or 0 < discard_percentile_higher < 100, 'Invalid percentile')
    if discard_percentile_lower is not None and discard_percentile_higher is not None:
        vassert(0 < discard_percentile_lower < discard_percentile_higher < 100, 'Invalid percentiles')

    sample_similarity = create_sample_similarity(
        similarity_name,
        sample_similarity_resize=sample_similarity_resize,
        sample_similarity_dtype=sample_similarity_dtype,
        cuda=False,
        **kwargs
    )
    sample_similarity = tops.to_cuda(sample_similarity)
    rng = np.random.RandomState(get_kwarg('rng_seed', kwargs))
    if data_len is None:
        data_len = len(dataloader) * dataloader.batch_size
    z0 = sample_random(rng, (data_len, generator.z_channels), "normal")
    z1 = sample_random(rng, (data_len, generator.z_channels), "normal")
    if latent_space == "Z":
        # In Z-space the epsilon-perturbed endpoint is computed once, up front.
        z1 = batch_interp(z0, z1, epsilon, interp)
    print("Computing PPL IN", latent_space)
    distances = torch.zeros(data_len, dtype=torch.float32, device=tops.get_device())
    end = 0
    n_samples = 0
    for batch in utils.tqdm_(dataloader, desc="Perceptual Path Length"):
        start = end
        end = start + batch["img"].shape[0]
        n_samples += batch["img"].shape[0]
        batch_lat_e0 = tops.to_cuda(z0[start:end])
        batch_lat_e1 = tops.to_cuda(z1[start:end])
        if latent_space == "W":
            w0 = generator.get_w(batch_lat_e0, update_emas=False)
            w1 = generator.get_w(batch_lat_e1, update_emas=False)
            w1 = w0.lerp(w1, epsilon)  # PPL endpoint: epsilon-step from w0 towards w1
            rgb1 = generator(**batch, w=w0)["img"]
            rgb2 = generator(**batch, w=w1)["img"]
        else:
            rgb1 = generator(**batch, z=batch_lat_e0)["img"]
            rgb2 = generator(**batch, z=batch_lat_e1)["img"]
        if rgb1.shape[-2] < upsample_size[0] or rgb1.shape[-1] < upsample_size[1]:
            # The similarity backbone expects images of at least `upsample_size`.
            rgb1 = resize(rgb1, upsample_size, antialias=True)
            rgb2 = resize(rgb2, upsample_size, antialias=True)
        # Similarity measure expects uint8 images in [0, 255].
        rgb1 = utils.denormalize_img(rgb1).mul(255).byte()
        rgb2 = utils.denormalize_img(rgb2).mul(255).byte()
        sim = sample_similarity(rgb1, rgb2)
        # Perceptual distance per unit of latent displacement.
        distances[start:end] = (sim / (epsilon ** 2)).view(-1)
    # Drop the unused tail if the loop produced fewer than data_len samples.
    distances = distances[:n_samples]
    distances = tops.all_gather_uneven(distances).cpu().numpy()
    if tops.rank() != 0:
        # Only rank 0 reports the aggregated metric.
        return {"ppl/mean": -1, "ppl/std": -1}
    cond = None
    if discard_percentile_lower is not None:
        lo = np.percentile(distances, discard_percentile_lower, interpolation='lower')
        cond = lo <= distances
    if discard_percentile_higher is not None:
        hi = np.percentile(distances, discard_percentile_higher, interpolation='higher')
        keep_below = distances <= hi
        # BUGFIX: the original did np.logical_and(None, ...) when only the
        # higher percentile was configured; handle lower-only/higher-only.
        cond = keep_below if cond is None else np.logical_and(cond, keep_below)
    if cond is not None:
        distances = np.extract(cond, distances)
    # NOTE(review): np.percentile's `interpolation=` kwarg is deprecated in
    # numpy >= 1.22 in favor of `method=`; kept for compatibility with the
    # numpy version this project pins — confirm before upgrading.
    return {
        "ppl/mean": float(np.mean(distances)),
        "ppl/std": float(np.std(distances)),
    }