| |
| |
|
|
| """Script for calculating Frechet Inception Distance (FID).""" |
|
|
| import os |
| import glob |
| import pickle |
| import re |
| import json |
| import click |
| import tqdm |
| import numpy as np |
| import scipy.linalg |
|
|
| import torch |
|
|
| from fastgen.networks.inception import InceptionV3 |
| from fastgen.datasets.class_cond_dataset import ImageFolderDataset |
| from fastgen.utils.distributed import get_rank, is_rank0, synchronize, world_size |
| import fastgen.utils.logging_utils as logger |
| from fastgen.utils.io_utils import open_url |
| from fastgen.configs.data import DATA_ROOT_DIR |
|
|
|
|
def calculate_inception_stats(
    detector_net,
    feature_dim,
    image_path,
    num_expected=None,
    seed=0,
    max_batch_size=64,
    num_workers=3,
    prefetch_factor=2,
    device=torch.device("cuda"),
):
    """Compute the feature mean and covariance of a folder of images.

    Runs every image in ``image_path`` through ``detector_net`` and accumulates
    the first and second moments of the resulting feature vectors, reducing
    across all distributed ranks.

    Args:
        detector_net: Feature extractor (e.g. InceptionV3) already on ``device``;
            called on image batches, expected to return ``feature_dim``-dim features.
        feature_dim: Dimensionality of the features produced by ``detector_net``
            (2048 for the Inception-v3 model used in ``calc``).
        image_path: Path handed to ``ImageFolderDataset`` to load the images.
        num_expected: If given, cap the dataset at this many images and raise a
            ``click.ClickException`` if fewer are available.
        seed: Random seed forwarded to the dataset for image subselection.
        max_batch_size: Upper bound on the per-rank batch size.
        num_workers: Number of DataLoader worker processes.
        prefetch_factor: DataLoader prefetch factor.
        device: Device on which features and statistics are accumulated.

    Returns:
        Tuple ``(mu, sigma)`` of numpy float64 arrays: the feature mean of shape
        ``[feature_dim]`` and the unbiased covariance of shape
        ``[feature_dim, feature_dim]``.

    Raises:
        click.ClickException: If fewer than ``num_expected`` images are found,
            or fewer than 2 images (covariance needs at least two samples).
    """
    # Non-zero ranks block here until rank 0 has built and validated the
    # dataset (see the matching synchronize() below).
    if not is_rank0():
        synchronize()

    # Load and validate the image dataset.
    logger.info(f'Loading images from "{image_path}"...')
    dataset_obj = ImageFolderDataset(path=image_path, max_size=num_expected, random_seed=seed)
    if num_expected is not None and len(dataset_obj) < num_expected:
        raise click.ClickException(f"Found {len(dataset_obj)} images, but expected at least {num_expected}")
    if len(dataset_obj) < 2:
        raise click.ClickException(f"Found {len(dataset_obj)} images, but need at least 2 to compute statistics")

    # Rank 0 releases the waiting ranks once its dataset is ready.
    if is_rank0():
        synchronize()

    # Split the dataset into batches. The batch count is rounded up to a
    # multiple of world_size() so every rank receives the same number of
    # batches (some possibly empty) and the per-batch synchronize() calls in
    # the loop below stay aligned across ranks.
    num_batches = ((len(dataset_obj) - 1) // (max_batch_size * world_size()) + 1) * world_size()
    all_batches = torch.arange(len(dataset_obj)).tensor_split(num_batches)
    rank_batches = all_batches[get_rank() :: world_size()]
    data_loader = torch.utils.data.DataLoader(
        dataset_obj, batch_sampler=rank_batches, num_workers=num_workers, prefetch_factor=prefetch_factor
    )

    # Accumulate raw first and second moments in float64 for numerical
    # stability; only rank 0 shows the progress bar.
    logger.info(f"Calculating statistics for {len(dataset_obj)} images...")
    mu = torch.zeros([feature_dim], dtype=torch.float64, device=device)
    sigma = torch.zeros([feature_dim, feature_dim], dtype=torch.float64, device=device)
    for data in tqdm.tqdm(data_loader, unit="batch", disable=(get_rank() != 0)):
        synchronize()  # keep all ranks in lockstep each batch
        images = data["real"]

        if images.shape[0] == 0:
            # This rank drew an empty (padding) batch.
            continue
        if images.shape[1] == 1:
            # Grayscale input: replicate the single channel to 3 channels.
            images = images.repeat([1, 3, 1, 1])

        with torch.no_grad():
            features = detector_net(images.to(device))
        features = features.to(torch.float64)

        mu += features.sum(0)
        sigma += features.T @ features

    # Sum the raw moments over all ranks, then convert them into the mean and
    # the unbiased (n - 1 denominator) covariance estimate.
    if world_size() > 1:
        torch.distributed.all_reduce(mu)
        torch.distributed.all_reduce(sigma)
    mu /= len(dataset_obj)
    sigma -= mu.ger(mu) * len(dataset_obj)
    sigma /= len(dataset_obj) - 1
    return mu.cpu().numpy(), sigma.cpu().numpy()
|
|
|
|
def calculate_fid_from_inception_stats(mu, sigma, mu_ref, sigma_ref):
    """Compute the Frechet distance between two Gaussians given their moments.

    Evaluates the closed-form expression
    ``||mu - mu_ref||^2 + Tr(sigma + sigma_ref - 2 * sqrtm(sigma @ sigma_ref))``
    and returns it as a plain float (the matrix square root may carry a tiny
    imaginary component, which is discarded).
    """
    mean_term = ((mu - mu_ref) ** 2).sum()
    cov_sqrt, _ = scipy.linalg.sqrtm(sigma @ sigma_ref, disp=False)
    trace_term = np.trace(sigma) + np.trace(sigma_ref) - 2 * np.trace(cov_sqrt)
    return float(np.real(mean_term + trace_term))
|
|
|
|
def calc(
    samples_dir, num_expected, seed, min_ckpt, max_ckpt, batch, dataset, regenerate=False, device=torch.device("cuda")
):
    """Calculate FID for a given set of images.

    Scans ``samples_dir`` for ``iter_<N>`` subdirectories of generated samples,
    computes Inception statistics for every checkpoint whose iteration number
    lies in ``[min_ckpt, max_ckpt]``, compares them against precomputed
    reference statistics for ``dataset``, and writes the scores (merged with
    any existing results) to ``<samples_dir>/fid.json``. Only rank 0 computes
    the final FID values and writes the JSON file.

    Args:
        samples_dir: Directory containing ``iter_<N>`` sample folders.
        num_expected: Number of images expected per checkpoint (forwarded to
            ``calculate_inception_stats``).
        seed: Random seed for image subselection.
        min_ckpt: Skip checkpoints with iteration number below this value.
        max_ckpt: Skip checkpoints with iteration number above this value.
        batch: Maximum per-rank batch size for feature extraction.
        dataset: One of ``"cifar10"``, ``"imagenet64"``, ``"imagenet64-edmv2"``,
            ``"imagenet256"``; selects the reference-statistics file.
        regenerate: If True, recompute FID even for checkpoints that already
            have an entry in ``fid.json``.
        device: Device used for the Inception network and statistics.

    Raises:
        ValueError: If ``dataset`` is not one of the known choices.
    """
    # Resolve the reference-statistics file for the requested dataset.
    ref = None
    if dataset == "cifar10":
        ref_path = f"{DATA_ROOT_DIR}/fid-refs/cifar10-32x32.npz"
    elif dataset == "imagenet64":
        ref_path = f"{DATA_ROOT_DIR}/fid-refs/imagenet-64x64.npz"
    elif dataset == "imagenet64-edmv2":
        ref_path = f"{DATA_ROOT_DIR}/fid-refs/imagenet-64x64-edmv2.npz"
    elif dataset == "imagenet256":
        ref_path = f"{DATA_ROOT_DIR}/fid-refs/imagenet_256.pkl"
    else:
        raise ValueError(f"Unknown dataset: {dataset}")
    logger.info(f'Loading dataset reference statistics from "{ref_path}"...')
    # Only rank 0 loads the reference stats, since only rank 0 computes the
    # final FID below; ``ref`` stays None on all other ranks.
    if is_rank0():
        if ref_path.endswith(".npz"):
            with open_url(ref_path) as f:
                ref = dict(np.load(f))  # expected to contain "mu" and "sigma"
        else:
            assert ref_path.endswith(".pkl"), f"Unknown file type: {ref_path}"
            with open_url(ref_path) as f:
                # NOTE(review): pickle.load of a fetched file — only safe with
                # trusted reference URLs.
                ref = pickle.load(f)["fid"]

    # Collect iter_<N> sample folders, sorted by iteration number.
    stats = glob.glob(f"{samples_dir}/iter_[0-9]*")
    stats.sort(key=lambda x: int(re.search(r"iter_(\d+)", x).group(1)))

    # Seed the result lists from any existing fid.json so previously evaluated
    # checkpoints can be skipped below.
    ckpt_num_list = []
    fid_list = []
    if os.path.exists(f"{samples_dir}/fid.json"):
        with open(f"{samples_dir}/fid.json", "r") as f:
            metric_scores = json.load(f)
        logger.info(f"metric_scores in the existing file: {metric_scores}")
        ckpt_num_list = metric_scores["ckpt_num"]
        fid_list = metric_scores["fid"]

    # Feature extractor: Inception-v3 with 2048-dimensional features.
    logger.info("Loading Inception-v3 model...")
    feature_dim = 2048

    detector_net = InceptionV3().to(device)
    detector_net.eval()

    for path in stats:
        ckpt_num = int(re.search(r"iter_(\d+)", path).group(1))

        # Skip checkpoints that already have a recorded score, unless the
        # caller asked to regenerate.
        if ckpt_num in ckpt_num_list and not regenerate:
            logger.info(f"ckpt {ckpt_num} already has metrics. Skip.")
            continue

        # Skip checkpoints outside the requested iteration range.
        if ckpt_num < min_ckpt or ckpt_num > max_ckpt:
            continue

        # All ranks participate in the distributed statistics computation.
        mu, sigma = calculate_inception_stats(
            detector_net, feature_dim, image_path=path, num_expected=num_expected, seed=seed, max_batch_size=batch
        )

        logger.info(f"Calculating FID for {path}... ")
        # Rank 0 alone holds ``ref`` and accumulates the scores.
        if is_rank0():
            fid = calculate_fid_from_inception_stats(mu, sigma, ref["mu"], ref["sigma"])
            logger.info(f"path: {path}")
            logger.info(f"FID: {fid}")
            logger.info("=" * 20)
            fid_list.append(fid)
            ckpt_num_list.append(ckpt_num)

        synchronize()

    # Rank 0 merges the new scores with whatever is currently on disk and
    # rewrites fid.json sorted by checkpoint number.
    if is_rank0():
        metric_scores = {}
        # Re-read the file (it may have been updated since startup) and turn
        # its parallel lists into a ckpt -> fid mapping.
        if os.path.exists(f"{samples_dir}/fid.json"):
            with open(f"{samples_dir}/fid.json", "r") as f:
                metric_scores = json.load(f)
            metric_scores = {ckpt: fid for ckpt, fid in zip(metric_scores["ckpt_num"], metric_scores["fid"])}

        # Newly computed results override any existing entry for the same
        # checkpoint; then store back as parallel sorted lists.
        for ckpt, fid in zip(ckpt_num_list, fid_list):
            metric_scores[ckpt] = fid
        metric_scores = sorted(metric_scores.items(), key=lambda x: x[0])
        metric_scores = {"ckpt_num": [ckpt for ckpt, _ in metric_scores], "fid": [fid for _, fid in metric_scores]}
        with open(f"{samples_dir}/fid.json", "w") as f:
            json.dump(metric_scores, f)
|
|