Spaces:
Sleeping
Sleeping
| """ | |
| This script computes imgs/sec for a generator in the eval mode | |
| for different batch sizes | |
| """ | |
| import sys; sys.path.extend(['..', '.', 'src']) | |
| import time | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import hydra | |
| from hydra.experimental import initialize | |
| from omegaconf import DictConfig, OmegaConf | |
| from tqdm import tqdm | |
| import torch.autograd.profiler as profiler | |
| from src import dnnlib | |
| from src.infra.utils import recursive_instantiate | |
# Device on which the generator is instantiated and benchmarked.
DEVICE = 'cuda'
# Batch sizes to benchmark; the script reports the one with the highest imgs/sec.
BATCH_SIZES = [32]
# Forward passes executed (and discarded) before timing, to warm up CUDA kernels.
NUM_WARMUP_ITERS = 5
# Timed forward passes used to compute the mean/std/bandwidth statistics.
NUM_PROFILE_ITERS = 25
def instantiate_G(cfg: DictConfig, device: str = DEVICE) -> nn.Module:
    """Build the generator described by ``cfg`` for inference-only profiling.

    Args:
        cfg: Hydra/omegaconf config. Reads ``cfg.model.generator`` (keys
            ``fmaps`` and ``mapping_net_n_layers``), ``cfg.num_fp16_res``
            and ``cfg.resolution``.
        device: Device to place the model on. Defaults to the module-level
            ``DEVICE`` constant, so existing callers keep the old behavior.

    Returns:
        The generator in eval mode with autograd disabled, moved to ``device``.
    """
    G_kwargs = dnnlib.EasyDict(
        class_name='training.networks.Generator',
        w_dim=512,
        mapping_kwargs=dnnlib.EasyDict(),
        synthesis_kwargs=dnnlib.EasyDict(),
    )
    # `fmaps` scales the channel multiplier; 0.5 * 32768 reproduces the
    # default StyleGAN2 channel base.
    G_kwargs.synthesis_kwargs.channel_base = int(cfg.model.generator.get('fmaps', 0.5) * 32768)
    G_kwargs.synthesis_kwargs.channel_max = 512
    G_kwargs.mapping_kwargs.num_layers = cfg.model.generator.get('mapping_net_n_layers', 2)
    if cfg.get('num_fp16_res', 0) > 0:
        # Run the highest-resolution synthesis blocks in fp16 and clamp
        # activations to avoid overflow (256 is the StyleGAN2-ADA default).
        # NOTE(review): the mangled source makes the conditional's extent
        # ambiguous; `conv_clamp` is grouped under it per upstream convention.
        G_kwargs.synthesis_kwargs.num_fp16_res = cfg.num_fp16_res
        G_kwargs.synthesis_kwargs.conv_clamp = 256
    G_kwargs.cfg = cfg.model.generator
    G_kwargs.c_dim = 0  # unconditional generation
    G_kwargs.img_resolution = cfg.get('resolution', 256)
    G_kwargs.img_channels = 3
    # Inference-only: eval mode and no autograd bookkeeping.
    G = dnnlib.util.construct_class_by_name(**G_kwargs).eval().requires_grad_(False).to(device)
    return G
def profile_for_batch_size(G: nn.Module, cfg: DictConfig, batch_size: int):
    """Time the generator's forward pass at a fixed batch size.

    Runs NUM_WARMUP_ITERS untimed forward passes, then NUM_PROFILE_ITERS
    timed ones under the autograd profiler, and prints per-iteration time,
    throughput, and peak GPU memory.

    Args:
        G: Generator in eval mode on DEVICE.
        cfg: Run config (unused here; kept for a uniform call signature).
        batch_size: Number of samples per forward pass.

    Returns:
        Tuple of (imgs/sec bandwidth, profiler summary table string).
    """
    latent = torch.randn(batch_size, G.z_dim, device=DEVICE)
    labels = torch.zeros(batch_size, G.c_dim, device=DEVICE)
    timesteps = torch.zeros(batch_size, 2, device=DEVICE)
    iter_times = []

    # Untimed warmup so CUDA kernel compilation/caching doesn't skew timings.
    for _ in tqdm(range(NUM_WARMUP_ITERS), desc='Warming up'):
        torch.cuda.synchronize()
        images = G(latent, c=labels, t=timesteps).contiguous()
        _ = images[0, 0, 0, 0].item()  # sync

    torch.cuda.synchronize()
    time.sleep(1)  # let the GPU settle before measuring
    torch.cuda.reset_peak_memory_stats()

    with profiler.profile(record_shapes=True, use_cuda=True) as prof:
        for _ in tqdm(range(NUM_PROFILE_ITERS), desc='Profiling'):
            torch.cuda.synchronize()
            tic = time.time()
            with profiler.record_function("forward"):
                images = G(latent, c=labels, t=timesteps).contiguous()
                _ = images[0, 0, 0, 0].item()  # sync
            torch.cuda.synchronize()
            iter_times.append(time.time() - tic)

    torch.cuda.empty_cache()

    total_imgs = len(iter_times) * batch_size
    elapsed = np.sum(iter_times)
    bandwidth = total_imgs / elapsed
    summary = prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)

    print(f'[Batch size: {batch_size}] Mean: {np.mean(iter_times):.05f}s/it. Std: {np.std(iter_times):.05f}s')
    print(f'[Batch size: {batch_size}] Imgs/sec: {bandwidth:.03f}')
    print(f'[Batch size: {batch_size}] Max mem: {torch.cuda.max_memory_allocated(DEVICE) / 2**30:<6.2f} gb')

    return bandwidth, summary
def profile(cfg: DictConfig):
    """Benchmark the generator over every batch size in BATCH_SIZES.

    Instantiates the generator from ``cfg``, measures imgs/sec for each
    candidate batch size, then prints the best batch size and its profiler
    summary table.
    """
    recursive_instantiate(cfg)
    G = instantiate_G(cfg)
    print(f'Number of parameters: {sum(p.numel() for p in G.parameters())}')

    # Collect (bandwidth, profiler summary) for each candidate batch size.
    results = []
    for bs in BATCH_SIZES:
        results.append(profile_for_batch_size(G, cfg, bs))

    throughputs = [bw for bw, _ in results]
    best_idx = int(np.argmax(throughputs))
    print(f'------------ Best batch size is {BATCH_SIZES[best_idx]} ------------')
    print(results[best_idx][1])
if __name__ == '__main__':
    # BUG FIX: the original called `profile()` with no arguments, but
    # `profile` requires a DictConfig, so the script crashed on launch.
    # Compose a config via hydra's experimental API, matching the
    # `initialize` import at the top of the file.
    # NOTE(review): config_path/config_name are a best guess — confirm
    # against the project's configs directory before relying on this.
    from hydra.experimental import compose
    with initialize(config_path='../configs'):
        profile(compose(config_name='config'))