# LD3 / compute_fid.py
# Uploaded by vinhtong97 via huggingface_hub (commit d382778).
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn import functional as F
import numpy as np
import pickle
from tqdm import tqdm
from utils import (
parse_arguments,
check_fid_file,
prepare_paths,
adjust_hyper,
get_solvers,
set_seed_everything,
)
from models import prepare_stuff, prepare_condition_loader
import math
import dnnlib
import pickle
import scipy
from torch.nn.functional import adaptive_avg_pool2d
from pytorch_fid.inception import InceptionV3
from gen_data import Generator, get_data_inverse_scaler
def calculate_fid_from_inception_stats(mu, sigma, mu_ref, sigma_ref, eps=1e-6):
    """Compute the Frechet Inception Distance between two Gaussian fits.

    FID = ||mu - mu_ref||^2 + Tr(sigma + sigma_ref - 2 * sqrtm(sigma @ sigma_ref))

    Args:
        mu: mean feature vector of the evaluated samples, shape (D,).
        sigma: feature covariance of the evaluated samples, shape (D, D).
        mu_ref: mean feature vector of the reference set, shape (D,).
        sigma_ref: feature covariance of the reference set, shape (D, D).
        eps: diagonal offset added to both covariances only when the matrix
            square root comes back non-finite (near-singular product).

    Returns:
        The FID as a Python float; any imaginary round-off left by ``sqrtm``
        is discarded.
    """
    # A bare ``import scipy`` does not load ``scipy.linalg`` on older SciPy
    # releases, so import the submodule explicitly.
    import scipy.linalg

    m = np.square(mu - mu_ref).sum()
    s, _ = scipy.linalg.sqrtm(np.dot(sigma, sigma_ref), disp=False)
    if not np.isfinite(s).all():
        # Near-singular covariance product: regularize and retry instead of
        # silently returning NaN/inf.
        offset = np.eye(np.shape(sigma)[0]) * eps
        s, _ = scipy.linalg.sqrtm(np.dot(sigma + offset, sigma_ref + offset), disp=False)
    fid = m + np.trace(sigma + sigma_ref - s * 2)
    return float(np.real(fid))
def _finalize_running_moments(feature_sum, gram_sum, num_samples):
    """Convert running feature sums into a (mean, unbiased covariance) pair.

    Args:
        feature_sum: sum of per-sample feature vectors, shape (D,).
        gram_sum: sum of per-sample outer products ``f f^T``, shape (D, D).
        num_samples: number of feature vectors that were accumulated.

    Returns:
        ``(mu, sigma)`` as NumPy arrays, ``sigma`` normalized by
        ``num_samples - 1``.
    """
    mu = feature_sum / num_samples
    # sum(f f^T) - N * mu mu^T == sum((f - mu)(f - mu)^T)
    sigma = (gram_sum - torch.outer(mu, mu) * num_samples) / (num_samples - 1)
    return mu.cpu().numpy(), sigma.cpu().numpy()


def main(args):
    """Sample ``args.total_samples`` images and print two FID estimates.

    Images come from ``Generator.sample`` and are mapped through
    ``get_data_inverse_scaler(centered=True)``. Two feature extractors are used:

    * NVIDIA's pickled Inception detector (fed uint8 images in [0, 255]); its
      statistics are accumulated as running sums and reported as "FID EDM".
    * ``pytorch_fid``'s ``InceptionV3`` (fed float images clamped to [0, 1]);
      its activations are stored and reduced at the end, reported as "FID LD".

    Both are compared against the ``"mu"``/``"sigma"`` arrays loaded from
    ``args.ref_path``. Note: ``args.use_ema`` is forced to True in place.

    Raises:
        ValueError: if fewer than two samples are requested, since a sample
            covariance needs at least two points.
    """
    if not args.use_ema:
        print("Auto update use_ema to True for evaluation")
        args.use_ema = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Start sampling...")

    # Latent-diffusion style evaluation: pytorch_fid InceptionV3 pool features.
    FEATURE_DIM = 2048
    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[FEATURE_DIM]
    fid_model = InceptionV3([block_idx]).to(device)
    fid_model.eval()

    # EDM-style evaluation: NVIDIA's Inception detector.
    # NOTE(security): pickle.load can execute arbitrary code; acceptable only
    # because DETECTOR_URL is a fixed, trusted NVIDIA-hosted artifact.
    DETECTOR_URL = "https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl"
    with dnnlib.util.open_url(DETECTOR_URL, verbose=True) as f:
        detector_net = pickle.load(f).to(device)
    # Reference statistics; only the "mu" and "sigma" entries are read below.
    with dnnlib.util.open_url(args.ref_path) as f:
        ref = dict(np.load(f))

    wrapped_model, model, decoding_fn, noise_schedule, latent_resolution, latent_channel, _, _ = prepare_stuff(args)
    condition_loader = prepare_condition_loader(
        model_type=args.model,
        model=model,
        scale=args.scale if hasattr(args, "scale") else None,
        condition=args.prompt_path or "uniform",
        sampling_batch_size=args.sampling_batch_size,
        num_prompt=None,
    )
    adjust_hyper(args, latent_resolution, latent_channel)
    _, _, skip_type = prepare_paths(args)
    solver, steps, solver_extra_params = get_solvers(
        args.solver_name,
        NFEs=args.steps,
        order=args.order,
        noise_schedule=noise_schedule,
        unipc_variant=args.unipc_variant,
    )
    generator = Generator(
        noise_schedule=noise_schedule,
        solver=solver,
        order=args.order,
        skip_type=skip_type,
        load_from=args.load_from,
        timesteps_1=args.custom_ts_1,
        timesteps_2=args.custom_ts_2,
        steps=steps,
        solver_extra_params=solver_extra_params,
        device=device,
    )
    print(generator.timesteps, generator.timesteps2)
    inverse_scaler = get_data_inverse_scaler(centered=True)

    # Read after adjust_hyper/prepare_paths in case they update args.
    total_samples = args.total_samples
    batch_size = args.sampling_batch_size
    if total_samples < 2:
        raise ValueError(
            f"total_samples must be at least 2 to estimate a covariance, got {total_samples}"
        )
    num_batches = math.ceil(total_samples / batch_size)

    # Running sums for the EDM statistics (float64 to limit round-off).
    feature_sum = torch.zeros([FEATURE_DIM], dtype=torch.float64, device=device)
    gram_sum = torch.zeros([FEATURE_DIM, FEATURE_DIM], dtype=torch.float64, device=device)
    num_edm = 0
    # Sized by total_samples (not num_batches * batch_size): the last batch is
    # truncated, and extra np.empty rows would be uninitialized garbage.
    act_arr = np.empty((total_samples, FEATURE_DIM))
    start_idx = 0
    with torch.no_grad():
        for index in tqdm(range(num_batches)):
            current_batch_size = min(batch_size, total_samples - index * batch_size)
            sampling_shape = (current_batch_size, latent_channel, latent_resolution, latent_resolution)
            latents = torch.randn(sampling_shape, device=device)
            if condition_loader is not None:
                # NOTE(review): the loader was built with the full
                # sampling_batch_size; confirm the truncated final batch is
                # compatible with its conditioning.
                conditioning, conditioned_unconditioning = next(condition_loader)
            else:
                conditioning = None
                conditioned_unconditioning = None
            img_teacher = generator.sample(wrapped_model, decoding_fn, latents, conditioning, conditioned_unconditioning)
            img_teacher = inverse_scaler(img_teacher)

            # EDM detector expects uint8 pixels in [0, 255].
            images = torch.clip(255 * img_teacher, 0, 255).to(torch.uint8)
            features = detector_net(images.to(device), return_features=True).to(torch.float64)
            feature_sum += features.sum(0)
            gram_sum += features.T @ features
            num_edm += features.shape[0]

            # pytorch_fid InceptionV3 expects floats in [0, 1].
            samples_latent_diff = torch.clamp(img_teacher, min=0.0, max=1.0)
            pred = fid_model(samples_latent_diff.float())[0]
            # If model output is not scalar, apply global spatial average pooling.
            # This happens if you choose a dimensionality not equal 2048.
            if pred.size(2) != 1 or pred.size(3) != 1:
                pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
            pred = pred.squeeze(3).squeeze(2).cpu().numpy()
            act_arr[start_idx:start_idx + pred.shape[0]] = pred
            start_idx += pred.shape[0]

    mu_edm, sigma_edm = _finalize_running_moments(feature_sum, gram_sum, num_edm)
    fid_edm = calculate_fid_from_inception_stats(mu_edm, sigma_edm, ref["mu"], ref["sigma"])

    # Only use rows that were actually filled.
    activations = act_arr[:start_idx]
    mu_ld = np.mean(activations, axis=0)
    sigma_ld = np.cov(activations, rowvar=False)
    fid_latent_diff = calculate_fid_from_inception_stats(mu_ld, sigma_ld, ref["mu"], ref["sigma"])
    print("FID EDM: {:.4f}".format(fid_edm))
    print("FID LD: {:.4f}".format(fid_latent_diff))
# Script entry point: parse CLI options, seed via set_seed_everything, then run main().
if __name__ == "__main__":
    args = parse_arguments()
    set_seed_everything(args.seed)
    main(args)