# FashionFlow / src/scripts/calc_metrics.py
# tasin
# init
# f075308
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Calculate quality metrics for previous training run or pretrained network pickle."""
import sys; sys.path.extend(['.', 'src'])
import os
import re
import click
import json
import tempfile
import copy
import torch
from src import dnnlib
from omegaconf import OmegaConf
import legacy
from metrics import metric_main
from metrics import metric_utils
from src.torch_utils import training_stats
from src.torch_utils import custom_ops
from src.torch_utils import misc
#----------------------------------------------------------------------------
def subprocess_fn(rank, args, temp_dir):
    """Worker entry point: prepare torch for process `rank` and evaluate every requested metric.

    Args:
        rank: Index of this process; also used as the CUDA device index.
        args: Options object assembled by `calc_metrics`. This function reads
            `num_gpus`, `verbose`, `G`, `metrics`, `dataset_kwargs`, `use_cache`,
            `num_runs`, `run_dir` and `network_pkl` from it.
        temp_dir: Directory in which the torch.distributed rendezvous file is placed.
    """
    dnnlib.util.Logger(should_flush=True)

    multi_gpu = args.num_gpus > 1
    is_master = rank == 0
    chatty = is_master and args.verbose

    # Init torch.distributed (only needed when more than one process participates).
    if multi_gpu:
        rendezvous = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
        if os.name == 'nt':
            # Windows: gloo backend, and the file URL needs forward slashes.
            backend = 'gloo'
            init_method = 'file:///' + rendezvous.replace('\\', '/')
        else:
            backend = 'nccl'
            init_method = f'file://{rendezvous}'
        torch.distributed.init_process_group(
            backend=backend,
            init_method=init_method,
            rank=rank,
            world_size=args.num_gpus,
        )

    # Init torch_utils.
    device = torch.device('cuda', rank)
    training_stats.init_multiprocessing(rank=rank, sync_device=(device if multi_gpu else None))
    if not chatty:
        # Only the verbose rank-0 process keeps custom-op build output.
        custom_ops.verbosity = 'none'

    # Backend flags: cudnn autotuning on, TF32 off for both matmul and cudnn.
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False

    # Work on a frozen, eval-mode copy of the generator placed on this process's device.
    G = copy.deepcopy(args.G).eval().requires_grad_(False).to(device)

    # Print network summary.
    if chatty:
        batch = 8
        z = torch.empty([batch, G.z_dim], device=device)
        c = torch.empty([batch, G.c_dim], device=device)
        t = torch.zeros([batch, G.cfg.sampling.num_frames_per_video], device=device).long()
        misc.print_module_summary(G, [z, c, t])

    # Calculate each metric.
    for metric_name in args.metrics:
        if chatty:
            print(f'Calculating {metric_name}...')
        progress = metric_utils.ProgressMonitor(verbose=args.verbose)

        # 'fid50k_full' is always evaluated once, regardless of the requested number of runs.
        if metric_name == 'fid50k_full':
            runs = 1
        else:
            runs = args.num_runs

        result_dict = metric_main.calc_metric(
            metric=metric_name,
            G=G,
            dataset_kwargs=args.dataset_kwargs,
            num_gpus=args.num_gpus,
            rank=rank,
            device=device,
            progress=progress,
            cache=args.use_cache,
            num_runs=runs,
        )

        # Only rank 0 reports the result.
        if is_master:
            metric_main.report_metric(result_dict, run_dir=args.run_dir, snapshot_pkl=args.network_pkl)
        if chatty:
            print()

    # Done.
    if chatty:
        print('Exiting...')
#----------------------------------------------------------------------------
class CommaSeparatedList(click.ParamType):
    """Click parameter type that turns a comma-separated string into a list of strings.

    `None`, the empty string and the word "none" (any letter case) all map to an
    empty list. A value that is already a list or tuple is returned as a list.
    """
    name = 'list'

    def convert(self, value, param, ctx):
        """Convert `value` to a list of strings.

        Args:
            value: Raw option value (string), `None`, or an already-converted sequence.
            param: Click parameter being converted (unused).
            ctx: Click context (unused).

        Returns:
            A list of strings; empty when no items were specified.
        """
        _ = param, ctx
        # Click may call convert() with a value that already has the target type
        # (for example a default supplied as a list), so do not assume a string here.
        if isinstance(value, (list, tuple)):
            return list(value)
        if value is None or value.lower() == 'none' or value == '':
            return []
        return value.split(',')
#----------------------------------------------------------------------------
def _select_best_snapshot(networks_dir, metric_name='fvd2048_16f'):
    """Pick the snapshot with the lowest `metric_name` score from a run's metric log.

    Reads `metric-<metric_name>.jsonl` inside `networks_dir`; every non-blank line
    must be a JSON object with a `results` mapping containing `metric_name` and a
    `snapshot_pkl` entry (a path relative to `networks_dir`).

    Args:
        networks_dir: Directory containing the metric log and the snapshots.
        metric_name: Key of the score to minimize.

    Returns:
        Tuple `(snapshot_path, score)` for the best entry. On ties the first entry
        in the file wins.

    Raises:
        OSError: If the metric log cannot be opened.
        ValueError: If the log has no entries or a line is not valid JSON.
        KeyError: If an entry lacks the expected keys.
    """
    metrics_file = os.path.join(networks_dir, f'metric-{metric_name}.jsonl')
    with open(metrics_file, 'r') as f:
        # Skip blank lines so a trailing newline or manual edit does not break JSON parsing.
        entries = [json.loads(line) for line in f if line.strip()]
    if not entries:
        raise ValueError(f'no metric entries found in {metrics_file}')
    best = min(entries, key=lambda entry: entry['results'][metric_name])
    return os.path.join(networks_dir, best['snapshot_pkl']), best['results'][metric_name]


@click.command()
@click.pass_context
@click.option('--network_pkl', '--network', help='Network pickle filename or URL', metavar='PATH')
@click.option('--networks_dir', help='Path to the experiment directory if the latest checkpoint is requested.', metavar='PATH')
@click.option('--metrics', help='Comma-separated list or "none"', type=CommaSeparatedList(), default='fid50k_full', show_default=True)
@click.option('--data', help='Dataset to evaluate metrics against (directory or zip) [default: same as training data]', metavar='PATH')
@click.option('--mirror', help='Whether the dataset was augmented with x-flips during training [default: look up]', type=bool, metavar='BOOL')
@click.option('--gpus', help='Number of GPUs to use', type=int, default=1, metavar='INT', show_default=True)
@click.option('--cfg_path', help='Path to the experiments config', type=str, default="auto", metavar='PATH')
@click.option('--verbose', help='Print optional information', type=bool, default=False, metavar='BOOL', show_default=True)
@click.option('--use_cache', help='Should we use the cache file?', type=bool, default=True, metavar='BOOL', show_default=True)
@click.option('--num_runs', help='Number of runs', type=int, default=1, metavar='INT', show_default=True)
def calc_metrics(ctx, network_pkl, networks_dir, metrics, data, mirror, gpus, cfg_path, verbose, use_cache: bool, num_runs: int):
    """Calculate quality metrics for previous training run or pretrained network pickle.

    Examples:

    \b
    # Previous training run: look up options automatically, save result to JSONL file.
    python calc_metrics.py --metrics=pr50k3_full \\
        --network=~/training-runs/00000-ffhq10k-res64-auto1/network-snapshot-000000.pkl

    \b
    # Pre-trained network pickle: specify dataset explicitly, print result to stdout.
    python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq.zip --mirror=1 \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl

    Available metrics:

    \b
      ADA paper:
        fid50k_full  Frechet inception distance against the full dataset.
        kid50k_full  Kernel inception distance against the full dataset.
        pr50k3_full  Precision and recall against the full dataset.
        is50k        Inception score for CIFAR-10.

    \b
      StyleGAN and StyleGAN2 papers:
        fid50k       Frechet inception distance against 50k real images.
        kid50k       Kernel inception distance against 50k real images.
        pr50k3       Precision and recall against 50k real images.
        ppl2_wend    Perceptual path length in W at path endpoints against full image.
        ppl_zfull    Perceptual path length in Z for full paths against cropped image.
        ppl_wfull    Perceptual path length in W for full paths against cropped image.
        ppl_zend     Perceptual path length in Z at path endpoints against cropped image.
        ppl_wend     Perceptual path length in W at path endpoints against cropped image.
    """
    # NOTE: the docstring above is rendered by Click as the `--help` text, so
    # parameter documentation lives here instead:
    #   ctx          Click context, used to abort with a usage error.
    #   network_pkl  Snapshot file or URL; when omitted, the best snapshot in
    #                `networks_dir` (lowest fvd2048_16f) is used.
    #   networks_dir Directory with snapshots and `metric-fvd2048_16f.jsonl`.
    #   metrics      List of metric names to compute.
    #   data         Optional dataset path overriding the one stored in the pickle.
    #   mirror       Optional override for the dataset `xflip` option.
    #   gpus         Number of worker processes / GPUs.
    #   cfg_path     Experiment config path, or "auto" to derive it from `network_pkl`.
    #   verbose, use_cache, num_runs  Forwarded to the workers via `args`.
    dnnlib.util.Logger(should_flush=True)

    # Resolve the checkpoint: fall back to the snapshot with the best score.
    if network_pkl is None:
        if networks_dir is None:
            ctx.fail('Either --network_pkl or --networks_dir must be specified')
        try:
            network_pkl, best_score = _select_best_snapshot(networks_dir)
        except (OSError, ValueError, KeyError) as err:
            ctx.fail(f'Could not select the best checkpoint from "{networks_dir}": {err}')
        print(f'Using checkpoint: {network_pkl} with FVD16 of', best_score)

    # Validate arguments.
    args = dnnlib.EasyDict(metrics=metrics, num_gpus=gpus, network_pkl=network_pkl, verbose=verbose)
    if cfg_path == "auto":
        # Assuming that `network_pkl` has the structure /path/to/experiment/output/network-X.pkl
        output_path = os.path.dirname(network_pkl)
        if os.path.basename(output_path) != "output":
            # Explicit check rather than `assert`, which is stripped under `python -O`.
            ctx.fail(f"Unknown path structure: {output_path}")
        experiment_path = os.path.dirname(output_path)
        cfg_path = os.path.join(experiment_path, 'experiment_config.yaml')
    cfg = OmegaConf.load(cfg_path)

    if not all(metric_main.is_valid_metric(metric) for metric in args.metrics):
        ctx.fail('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
    if args.num_gpus < 1:
        ctx.fail('--gpus must be at least 1')

    # Load network.
    if not dnnlib.util.is_url(network_pkl, allow_file_urls=True) and not os.path.isfile(network_pkl):
        ctx.fail('--network must point to a file or URL')
    if args.verbose:
        print(f'Loading network from "{network_pkl}"...')
    with dnnlib.util.open_url(network_pkl, verbose=args.verbose) as f:
        network_dict = legacy.load_network_pkl(f)
        args.G = network_dict['G_ema'] # subclass of torch.nn.Module

    # Re-instantiate the generator from the current `Generator` class and copy the
    # loaded weights into it.
    # NOTE(review): presumably so the evaluated module runs the current source code
    # rather than the pickled one; resolution/channels are hard-coded here - confirm.
    from src.training.networks import Generator
    G = args.G
    G.cfg.z_dim = G.z_dim
    G_new = Generator(
        w_dim=G.cfg.w_dim,
        mapping_kwargs=dnnlib.EasyDict(num_layers=G.cfg.get('mapping_net_n_layers', 2), cfg=G.cfg),
        synthesis_kwargs=dnnlib.EasyDict(
            channel_base=int(G.cfg.get('fmaps', 0.5) * 32768),
            channel_max=G.cfg.get('channel_max', 512),
            num_fp16_res=4,
            conv_clamp=256,
        ),
        cfg=G.cfg,
        img_resolution=256,
        img_channels=3,
        c_dim=G.cfg.c_dim,
    ).eval()
    G_new.load_state_dict(G.state_dict())
    args.G = G_new

    # Initialize dataset options.
    if data is not None:
        args.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.VideoFramesFolderDataset', cfg=cfg.dataset, path=data)
    elif network_dict['training_set_kwargs'] is not None:
        args.dataset_kwargs = dnnlib.EasyDict(network_dict['training_set_kwargs'])
    else:
        ctx.fail('Could not look up dataset options; please specify --data')

    # Finalize dataset options.
    args.dataset_kwargs.resolution = args.G.img_resolution
    args.dataset_kwargs.use_labels = (args.G.c_dim != 0)
    if mirror is not None:
        args.dataset_kwargs.xflip = mirror
    args.use_cache = use_cache
    args.num_runs = num_runs

    # Print dataset options.
    if args.verbose:
        print('Dataset options:')
        print(cfg.dataset)

    # Locate run dir: only a local snapshot that sits next to `training_options.json`
    # gets a run directory, which is later passed to `metric_main.report_metric`.
    args.run_dir = None
    if os.path.isfile(network_pkl):
        pkl_dir = os.path.dirname(network_pkl)
        if os.path.isfile(os.path.join(pkl_dir, 'training_options.json')):
            args.run_dir = pkl_dir

    # Launch processes.
    if args.verbose:
        print('Launching processes...')
    torch.multiprocessing.set_start_method('spawn')
    with tempfile.TemporaryDirectory() as temp_dir:
        if args.num_gpus == 1:
            # Single GPU: run the worker in this process.
            subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
        else:
            torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus)
#----------------------------------------------------------------------------
# Script entry point: Click parses the command line and supplies every parameter
# of `calc_metrics`, hence the call without arguments.
if __name__ == "__main__":
    calc_metrics() # pylint: disable=no-value-for-parameter
#----------------------------------------------------------------------------