Spaces:
Sleeping
Sleeping
| # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # NVIDIA CORPORATION and its licensors retain all intellectual property | |
| # and proprietary rights in and to this software, related documentation | |
| # and any modifications thereto. Any use, reproduction, disclosure or | |
| # distribution of this software and related documentation without an express | |
| # license agreement from NVIDIA CORPORATION is strictly prohibited. | |
| """Calculate quality metrics for previous training run or pretrained network pickle.""" | |
| import sys; sys.path.extend(['.', 'src']) | |
| import os | |
| import re | |
| import click | |
| import json | |
| import tempfile | |
| import copy | |
| import torch | |
| from src import dnnlib | |
| from omegaconf import OmegaConf | |
| import legacy | |
| from metrics import metric_main | |
| from metrics import metric_utils | |
| from src.torch_utils import training_stats | |
| from src.torch_utils import custom_ops | |
| from src.torch_utils import misc | |
| #---------------------------------------------------------------------------- | |
def subprocess_fn(rank, args, temp_dir):
    """Worker entry point: evaluate every requested metric on one GPU.

    Args:
        rank: Index of this process; also used as the CUDA device index.
        args: Option bundle read here for ``num_gpus``, ``verbose``, ``G``,
            ``metrics``, ``dataset_kwargs``, ``use_cache``, ``num_runs``,
            ``run_dir`` and ``network_pkl``.
        temp_dir: Directory holding the file used to initialise
            ``torch.distributed`` when more than one GPU is requested.
    """
    dnnlib.util.Logger(should_flush=True)

    # Set up torch.distributed through a shared file when running multi-GPU.
    multi_gpu = args.num_gpus > 1
    if multi_gpu:
        init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
        if os.name == 'nt':
            backend = 'gloo'
            init_method = 'file:///' + init_file.replace('\\', '/')
        else:
            backend = 'nccl'
            init_method = f'file://{init_file}'
        torch.distributed.init_process_group(
            backend=backend,
            init_method=init_method,
            rank=rank,
            world_size=args.num_gpus,
        )

    # Set up torch_utils.
    training_stats.init_multiprocessing(
        rank=rank,
        sync_device=torch.device('cuda', rank) if multi_gpu else None,
    )
    is_main = rank == 0
    chatty = is_main and args.verbose  # only rank 0 prints, and only when verbose
    if not chatty:
        custom_ops.verbosity = 'none'

    # Move a frozen copy of the generator to this rank's device.
    device = torch.device('cuda', rank)
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    G = copy.deepcopy(args.G).eval().requires_grad_(False).to(device)

    # Print network summary using placeholder inputs with a batch of 8.
    if chatty:
        batch = 8
        dummy_inputs = [
            torch.empty([batch, G.z_dim], device=device),
            torch.empty([batch, G.c_dim], device=device),
            torch.zeros([batch, G.cfg.sampling.num_frames_per_video], device=device).long(),
        ]
        misc.print_module_summary(G, dummy_inputs)

    # Evaluate the metrics one after another.
    for metric in args.metrics:
        if chatty:
            print(f'Calculating {metric}...')
        monitor = metric_utils.ProgressMonitor(verbose=args.verbose)
        # 'fid50k_full' is always computed with a single run.
        if metric == 'fid50k_full':
            runs = 1
        else:
            runs = args.num_runs
        result_dict = metric_main.calc_metric(
            metric=metric,
            G=G,
            dataset_kwargs=args.dataset_kwargs,
            num_gpus=args.num_gpus,
            rank=rank,
            device=device,
            progress=monitor,
            cache=args.use_cache,
            num_runs=runs,
        )
        if is_main:
            metric_main.report_metric(result_dict, run_dir=args.run_dir, snapshot_pkl=args.network_pkl)
        if chatty:
            print()

    # Done.
    if chatty:
        print('Exiting...')
| #---------------------------------------------------------------------------- | |
class CommaSeparatedList(click.ParamType):
    """Click parameter type that turns a comma-separated string into a list.

    ``None``, the empty string and the word "none" (in any letter case) all
    produce an empty list. Any other string is split on commas.
    """

    name = 'list'

    def convert(self, value, param, ctx):
        """Convert ``value`` to a list of strings.

        Args:
            value: Raw option value: a string, ``None``, or a list/tuple that
                has already been converted (e.g. a non-string default).
            param: Unused.
            ctx: Unused.

        Returns:
            A list of the comma-separated items, or ``[]`` for an empty value.
        """
        _ = param, ctx
        # Values that are already sequences are passed through instead of
        # failing on the string-only `.lower()` call below.
        if isinstance(value, (list, tuple)):
            return list(value)
        if value is None or value == '' or value.lower() == 'none':
            return []
        return value.split(',')
| #---------------------------------------------------------------------------- | |
def calc_metrics(ctx, network_pkl, networks_dir, metrics, data, mirror, gpus, cfg_path, verbose, use_cache: bool, num_runs: int):
    """Calculate quality metrics for previous training run or pretrained network pickle.

    Examples:

    \b
    # Previous training run: look up options automatically, save result to JSONL file.
    python calc_metrics.py --metrics=pr50k3_full \\
        --network=~/training-runs/00000-ffhq10k-res64-auto1/network-snapshot-000000.pkl

    \b
    # Pre-trained network pickle: specify dataset explicitly, print result to stdout.
    python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq.zip --mirror=1 \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl

    Available metrics:

    \b
      ADA paper:
        fid50k_full  Frechet inception distance against the full dataset.
        kid50k_full  Kernel inception distance against the full dataset.
        pr50k3_full  Precision and recall against the full dataset.
        is50k        Inception score for CIFAR-10.

    \b
      StyleGAN and StyleGAN2 papers:
        fid50k       Frechet inception distance against 50k real images.
        kid50k       Kernel inception distance against 50k real images.
        pr50k3       Precision and recall against 50k real images.
        ppl2_wend    Perceptual path length in W at path endpoints against full image.
        ppl_zfull    Perceptual path length in Z for full paths against cropped image.
        ppl_wfull    Perceptual path length in W for full paths against cropped image.
        ppl_zend     Perceptual path length in Z at path endpoints against cropped image.
        ppl_wend     Perceptual path length in W at path endpoints against cropped image.
    """
    # Implementation notes are kept out of the docstring, which reads as
    # command-line help text (presumably rendered by click -- confirm).
    #
    #   ctx          -- object whose `fail(message)` reports invalid arguments.
    #   network_pkl  -- file path or URL of the network pickle; when None, the
    #                   snapshot with the lowest 'fvd2048_16f' recorded under
    #                   `networks_dir` is used.
    #   networks_dir -- directory holding 'metric-fvd2048_16f.jsonl' and the
    #                   snapshot pickles it refers to.
    #   metrics      -- iterable of metric names, checked with metric_main.
    #   data         -- dataset path; when None, the dataset options stored in
    #                   the pickle are used.
    #   mirror       -- when not None, stored as the dataset's `xflip` option.
    #   gpus         -- number of worker processes (must be >= 1).
    #   cfg_path     -- YAML config path, or "auto" to derive it from the
    #                   location of `network_pkl`.
    #   verbose      -- print progress information.
    #   use_cache / num_runs -- stored on `args` for the worker processes.
    dnnlib.util.Logger(should_flush=True)

    # Select the checkpoint with the best score when none was given explicitly.
    if network_pkl is None:
        if networks_dir is None:
            ctx.fail('either a network pickle (--network) or a networks directory must be specified')
        metrics_file = os.path.join(networks_dir, 'metric-fvd2048_16f.jsonl')
        best_snapshot = _select_best_snapshot(metrics_file)
        if best_snapshot is None:
            ctx.fail(f'no metric records found in "{metrics_file}"')
        network_pkl = os.path.join(networks_dir, best_snapshot['snapshot_pkl'])
        print(f'Using checkpoint: {network_pkl} with FVD16 of', best_snapshot['results']['fvd2048_16f'])

    # Validate arguments.
    args = dnnlib.EasyDict(metrics=metrics, num_gpus=gpus, network_pkl=network_pkl, verbose=verbose)
    if cfg_path == "auto":
        # Assuming that `network_pkl` has the structure /path/to/experiment/output/network-X.pkl
        output_path = os.path.dirname(network_pkl)
        if os.path.basename(output_path) != "output":
            # Explicit failure rather than `assert`, which is stripped under -O.
            ctx.fail(f"Unknown path structure: {output_path}")
        experiment_path = os.path.dirname(output_path)
        cfg_path = os.path.join(experiment_path, 'experiment_config.yaml')
    cfg = OmegaConf.load(cfg_path)
    if not all(metric_main.is_valid_metric(metric) for metric in args.metrics):
        ctx.fail('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
    if not args.num_gpus >= 1:
        ctx.fail('--gpus must be at least 1')

    # Load network.
    if not dnnlib.util.is_url(network_pkl, allow_file_urls=True) and not os.path.isfile(network_pkl):
        ctx.fail('--network must point to a file or URL')
    if args.verbose:
        print(f'Loading network from "{network_pkl}"...')
    with dnnlib.util.open_url(network_pkl, verbose=args.verbose) as f:
        network_dict = legacy.load_network_pkl(f)
    args.G = network_dict['G_ema'] # subclass of torch.nn.Module

    # Rebuild the generator from the current `Generator` class and copy the
    # loaded weights into it.
    # NOTE(review): resolution (256), channels (3), num_fp16_res and conv_clamp
    # are hard-coded here -- confirm they match the checkpoints being evaluated.
    from src.training.networks import Generator
    G = args.G
    G.cfg.z_dim = G.z_dim
    G_new = Generator(
        w_dim=G.cfg.w_dim,
        mapping_kwargs=dnnlib.EasyDict(num_layers=G.cfg.get('mapping_net_n_layers', 2), cfg=G.cfg),
        synthesis_kwargs=dnnlib.EasyDict(
            channel_base=int(G.cfg.get('fmaps', 0.5) * 32768),
            channel_max=G.cfg.get('channel_max', 512),
            num_fp16_res=4,
            conv_clamp=256,
        ),
        cfg=G.cfg,
        img_resolution=256,
        img_channels=3,
        c_dim=G.cfg.c_dim,
    ).eval()
    G_new.load_state_dict(G.state_dict())
    args.G = G_new

    # Initialize dataset options.
    if data is not None:
        args.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.VideoFramesFolderDataset', cfg=cfg.dataset, path=data)
    elif network_dict['training_set_kwargs'] is not None:
        args.dataset_kwargs = dnnlib.EasyDict(network_dict['training_set_kwargs'])
    else:
        ctx.fail('Could not look up dataset options; please specify --data')

    # Finalize dataset options.
    args.dataset_kwargs.resolution = args.G.img_resolution
    args.dataset_kwargs.use_labels = (args.G.c_dim != 0)
    if mirror is not None:
        args.dataset_kwargs.xflip = mirror
    args.use_cache = use_cache
    args.num_runs = num_runs

    # Print dataset options.
    if args.verbose:
        print('Dataset options:')
        print(cfg.dataset)

    # Locate run dir: only a local pickle that sits next to
    # 'training_options.json' counts as belonging to a training run.
    args.run_dir = None
    if os.path.isfile(network_pkl):
        pkl_dir = os.path.dirname(network_pkl)
        if os.path.isfile(os.path.join(pkl_dir, 'training_options.json')):
            args.run_dir = pkl_dir

    # Launch processes.
    if args.verbose:
        print('Launching processes...')
    torch.multiprocessing.set_start_method('spawn')
    with tempfile.TemporaryDirectory() as temp_dir:
        if args.num_gpus == 1:
            subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
        else:
            torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus)


def _select_best_snapshot(metrics_file, metric_name='fvd2048_16f'):
    """Return the JSONL record with the lowest ``results[metric_name]``, or None if the file has no records."""
    with open(metrics_file, 'r') as f:
        # Skip blank lines so a trailing newline does not break JSON parsing.
        records = [json.loads(line) for line in f if line.strip()]
    if not records:
        return None
    # `min` returns the first of several equal minima, like a stable sort's first element.
    return min(records, key=lambda record: record['results'][metric_name])
| #---------------------------------------------------------------------------- | |
# Script entry point.
# NOTE(review): calc_metrics is called with no arguments although its
# signature requires several; this presumably relies on a command-line wrapper
# (e.g. click decorators) that is not visible in this view -- confirm.
if __name__ == "__main__":
    calc_metrics() # pylint: disable=no-value-for-parameter
| #---------------------------------------------------------------------------- | |