File size: 6,823 Bytes
0839907
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Script for calculating Frechet Inception Distance (FID)."""

import os
import glob
import pickle
import re
import json
import click
import tqdm
import numpy as np
import scipy.linalg

import torch

from fastgen.networks.inception import InceptionV3
from fastgen.datasets.class_cond_dataset import ImageFolderDataset
from fastgen.utils.distributed import get_rank, is_rank0, synchronize, world_size
import fastgen.utils.logging_utils as logger
from fastgen.utils.io_utils import open_url
from fastgen.configs.data import DATA_ROOT_DIR


def calculate_inception_stats(
    detector_net,
    feature_dim,
    image_path,
    num_expected=None,
    seed=0,
    max_batch_size=64,
    num_workers=3,
    prefetch_factor=2,
    device=torch.device("cuda"),
):
    """Compute the mean and covariance of Inception features over a folder of images.

    Runs distributed: every rank processes a disjoint subset of batches and the
    partial sums are all-reduced at the end, so all ranks return the same result.

    Args:
        detector_net: Feature-extractor network (e.g. InceptionV3) mapping an
            image batch to a ``[N, feature_dim]`` feature tensor.
            NOTE(review): assumed output shape — confirm against the detector used.
        feature_dim: Dimensionality of the detector's feature vector.
        image_path: Path passed to ``ImageFolderDataset`` to enumerate images.
        num_expected: If given, the minimum (and max) number of images to use;
            raises if fewer images are found.
        seed: Random seed forwarded to the dataset's subsampling.
        max_batch_size: Upper bound on per-rank batch size.
        num_workers: DataLoader worker processes.
        prefetch_factor: DataLoader prefetch factor.
        device: Device on which features are computed and accumulated.

    Returns:
        Tuple ``(mu, sigma)`` of numpy float64 arrays: the feature mean
        ``[feature_dim]`` and the unbiased covariance ``[feature_dim, feature_dim]``.

    Raises:
        click.ClickException: If fewer than ``num_expected`` (or fewer than 2)
            images are found.
    """
    # Rank 0 goes first: it enumerates the dataset (possibly creating caches)
    # while the other ranks wait at the barrier below.
    if not is_rank0():
        synchronize()

    # List images.
    logger.info(f'Loading images from "{image_path}"...')
    dataset_obj = ImageFolderDataset(path=image_path, max_size=num_expected, random_seed=seed)
    if num_expected is not None and len(dataset_obj) < num_expected:
        raise click.ClickException(f"Found {len(dataset_obj)} images, but expected at least {num_expected}")
    # At least 2 samples are needed for the unbiased covariance (N - 1 divisor).
    if len(dataset_obj) < 2:
        raise click.ClickException(f"Found {len(dataset_obj)} images, but need at least 2 to compute statistics")

    # Other ranks follow: release them now that rank 0 has finished listing.
    if is_rank0():
        synchronize()

    # Divide images into batches. num_batches is forced to be a multiple of
    # world_size() so every rank gets the same number of batches — required for
    # the per-batch synchronize() in the loop below not to deadlock.
    num_batches = ((len(dataset_obj) - 1) // (max_batch_size * world_size()) + 1) * world_size()
    all_batches = torch.arange(len(dataset_obj)).tensor_split(num_batches)
    rank_batches = all_batches[get_rank() :: world_size()]
    data_loader = torch.utils.data.DataLoader(
        dataset_obj, batch_sampler=rank_batches, num_workers=num_workers, prefetch_factor=prefetch_factor
    )

    # Accumulate statistics. mu holds the running feature sum and sigma the
    # running sum of outer products; both are converted to mean/covariance below.
    logger.info(f"Calculating statistics for {len(dataset_obj)} images...")
    mu = torch.zeros([feature_dim], dtype=torch.float64, device=device)
    sigma = torch.zeros([feature_dim, feature_dim], dtype=torch.float64, device=device)
    for data in tqdm.tqdm(data_loader, unit="batch", disable=(get_rank() != 0)):
        # Keep all ranks in lockstep once per batch (same batch count per rank).
        synchronize()
        images = data["real"]

        if images.shape[0] == 0:
            continue
        # Replicate single-channel (grayscale) images to 3 channels for Inception.
        if images.shape[1] == 1:
            images = images.repeat([1, 3, 1, 1])

        with torch.no_grad():
            features = detector_net(images.to(device))
            # float64 accumulation avoids precision loss over many batches.
            features = features.to(torch.float64)

        mu += features.sum(0)
        sigma += features.T @ features

    # Calculate grand totals across ranks, then convert the raw sums into the
    # sample mean and unbiased covariance: cov = (sum_outer - N * mu mu^T) / (N - 1).
    if world_size() > 1:
        torch.distributed.all_reduce(mu)
        torch.distributed.all_reduce(sigma)
    mu /= len(dataset_obj)
    sigma -= mu.ger(mu) * len(dataset_obj)
    sigma /= len(dataset_obj) - 1
    return mu.cpu().numpy(), sigma.cpu().numpy()


def calculate_fid_from_inception_stats(mu, sigma, mu_ref, sigma_ref):
    """Frechet distance between two Gaussians N(mu, sigma) and N(mu_ref, sigma_ref).

    Implements ||mu - mu_ref||^2 + Tr(sigma + sigma_ref - 2 * sqrt(sigma @ sigma_ref)).

    Args:
        mu: Feature mean of the generated samples, shape ``[D]``.
        sigma: Feature covariance of the generated samples, shape ``[D, D]``.
        mu_ref: Reference (dataset) feature mean, shape ``[D]``.
        sigma_ref: Reference feature covariance, shape ``[D, D]``.

    Returns:
        The FID value as a Python float.
    """
    mean_term = np.square(mu - mu_ref).sum()
    # Matrix square root of the covariance product; may carry a tiny imaginary
    # component from numerical error, discarded by np.real below.
    covmean, _ = scipy.linalg.sqrtm(sigma @ sigma_ref, disp=False)
    trace_term = np.trace(sigma + sigma_ref - covmean * 2)
    return float(np.real(mean_term + trace_term))


def calc(
    samples_dir, num_expected, seed, min_ckpt, max_ckpt, batch, dataset, regenerate=False, device=torch.device("cuda")
):
    """Calculate FID for every checkpoint sample directory under ``samples_dir``.

    Scans ``samples_dir`` for subdirectories named ``iter_<N>``, computes FID for
    each against the reference statistics of ``dataset``, and merges the results
    into ``samples_dir/fid.json`` (keys ``"ckpt_num"`` and ``"fid"``, sorted by
    checkpoint number). Runs distributed; only rank 0 loads the reference stats,
    computes the final FID values, and writes the json file.

    Args:
        samples_dir: Directory containing ``iter_<N>`` sample subdirectories.
        num_expected: Number of images expected per checkpoint directory.
        seed: Random seed for dataset subsampling.
        min_ckpt: Skip checkpoints with iteration number below this.
        max_ckpt: Skip checkpoints with iteration number above this.
        batch: Maximum batch size for feature extraction.
        dataset: One of ``"cifar10"``, ``"imagenet64"``, ``"imagenet64-edmv2"``,
            ``"imagenet256"`` — selects the reference statistics file.
        regenerate: If True, recompute FID even for checkpoints already present
            in ``fid.json`` (the merge below keeps the newest value).
        device: Device for the Inception network.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    # Only rank 0 needs the reference stats; other ranks keep ref = None and
    # never touch it (they only participate in feature extraction).
    ref = None
    if dataset == "cifar10":
        ref_path = f"{DATA_ROOT_DIR}/fid-refs/cifar10-32x32.npz"
    elif dataset == "imagenet64":
        ref_path = f"{DATA_ROOT_DIR}/fid-refs/imagenet-64x64.npz"
    elif dataset == "imagenet64-edmv2":
        ref_path = f"{DATA_ROOT_DIR}/fid-refs/imagenet-64x64-edmv2.npz"
    elif dataset == "imagenet256":
        ref_path = f"{DATA_ROOT_DIR}/fid-refs/imagenet_256.pkl"
    else:
        raise ValueError(f"Unknown dataset: {dataset}")
    logger.info(f'Loading dataset reference statistics from "{ref_path}"...')
    if is_rank0():
        if ref_path.endswith(".npz"):
            with open_url(ref_path) as f:
                ref = dict(np.load(f))
        else:
            assert ref_path.endswith(".pkl"), f"Unknown file type: {ref_path}"
            # NOTE(review): pickle.load on a fetched file — only safe for
            # trusted reference-statistics sources.
            with open_url(ref_path) as f:
                ref = pickle.load(f)["fid"]

    # Enumerate checkpoint sample directories in ascending iteration order.
    stats = glob.glob(f"{samples_dir}/iter_[0-9]*")
    stats.sort(key=lambda x: int(re.search(r"iter_(\d+)", x).group(1)))

    # Resume from any previously written results so finished checkpoints can
    # be skipped (unless regenerate is set).
    ckpt_num_list = []
    fid_list = []
    if os.path.exists(f"{samples_dir}/fid.json"):
        with open(f"{samples_dir}/fid.json", "r") as f:
            metric_scores = json.load(f)
        logger.info(f"metric_scores in the existing file: {metric_scores}")
        ckpt_num_list = metric_scores["ckpt_num"]
        fid_list = metric_scores["fid"]

    # Load Inception-v3 model.
    logger.info("Loading Inception-v3 model...")
    feature_dim = 2048
    detector_net = InceptionV3().to(device)
    detector_net.eval()

    for path in stats:
        ckpt_num = int(re.search(r"iter_(\d+)", path).group(1))

        if ckpt_num in ckpt_num_list and not regenerate:
            logger.info(f"ckpt {ckpt_num} already has metrics. Skip.")
            continue

        if ckpt_num < min_ckpt or ckpt_num > max_ckpt:
            continue

        # All ranks participate in the distributed feature extraction.
        mu, sigma = calculate_inception_stats(
            detector_net, feature_dim, image_path=path, num_expected=num_expected, seed=seed, max_batch_size=batch
        )

        logger.info(f"Calculating FID for {path}... ")
        if is_rank0():
            fid = calculate_fid_from_inception_stats(mu, sigma, ref["mu"], ref["sigma"])
            logger.info(f"path: {path}")
            logger.info(f"FID: {fid}")
            logger.info("=" * 20)
            fid_list.append(fid)
            ckpt_num_list.append(ckpt_num)

        synchronize()

    # dump the FID scores to a json file
    if is_rank0():
        metric_scores = {}
        # read metrics again in case another process altered file
        if os.path.exists(f"{samples_dir}/fid.json"):
            with open(f"{samples_dir}/fid.json", "r") as f:
                metric_scores = json.load(f)
            metric_scores = {ckpt: fid for ckpt, fid in zip(metric_scores["ckpt_num"], metric_scores["fid"])}

        # merge metrics; later entries (freshly computed) overwrite earlier ones,
        # so a regenerated checkpoint keeps its newest FID value
        for ckpt, fid in zip(ckpt_num_list, fid_list):
            metric_scores[ckpt] = fid
        metric_scores = sorted(metric_scores.items(), key=lambda x: x[0])
        metric_scores = {"ckpt_num": [ckpt for ckpt, _ in metric_scores], "fid": [fid for _, fid in metric_scores]}
        with open(f"{samples_dir}/fid.json", "w") as f:
            json.dump(metric_scores, f)