# FashionFlow/src/scripts/profile_model.py
"""
This script computes imgs/sec for a generator in the eval mode
for different batch sizes
"""
import sys; sys.path.extend(['..', '.', 'src'])
import time
import numpy as np
import torch
import torch.nn as nn
import hydra
from hydra.experimental import initialize
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm
import torch.autograd.profiler as profiler
from src import dnnlib
from src.infra.utils import recursive_instantiate
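
# Profiling settings: target device, batch sizes to benchmark, and how many
# warmup / measured forward passes to run per batch size.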
DEVICE = 'cuda'
BATCH_SIZES = [32]
NUM_WARMUP_ITERS = 5
NUM_PROFILE_ITERS = 25
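
# Builds the generator described by the Hydra config (channel widths, mapping
# net depth, optional FP16 synthesis resolutions) and moves it to the GPU in
# eval mode with gradients disabled.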
def instantiate_G(cfg: DictConfig) -> nn.Module:
    G_kwargs = dnnlib.EasyDict(class_name='training.networks.Generator', w_dim=512, mapping_kwargs=dnnlib.EasyDict(), synthesis_kwargs=dnnlib.EasyDict())
    G_kwargs.synthesis_kwargs.channel_base = int(cfg.model.generator.get('fmaps', 0.5) * 32768)
    G_kwargs.synthesis_kwargs.channel_max = 512
    G_kwargs.mapping_kwargs.num_layers = cfg.model.generator.get('mapping_net_n_layers', 2)

    if cfg.get('num_fp16_res', 0) > 0:
        G_kwargs.synthesis_kwargs.num_fp16_res = cfg.num_fp16_res
        G_kwargs.synthesis_kwargs.conv_clamp = 256

    G_kwargs.cfg = cfg.model.generator
    G_kwargs.c_dim = 0
    G_kwargs.img_resolution = cfg.get('resolution', 256)
    G_kwargs.img_channels = 3

    G = dnnlib.util.construct_class_by_name(**G_kwargs).eval().requires_grad_(False).to(DEVICE)

    return G
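
# Measures forward-pass latency for a single batch size: a few warmup passes,
# then NUM_PROFILE_ITERS timed passes under the autograd profiler. Returns the
# throughput in imgs/sec and the profiler's per-op summary table.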
@torch.no_grad()
def profile_for_batch_size(G: nn.Module, cfg: DictConfig, batch_size: int):
    z = torch.randn(batch_size, G.z_dim, device=DEVICE)
    c = torch.zeros(batch_size, G.c_dim, device=DEVICE)
    t = torch.zeros(batch_size, 2, device=DEVICE)
    times = []

    for i in tqdm(range(NUM_WARMUP_ITERS), desc='Warming up'):
        torch.cuda.synchronize()
        fake_img = G(z, c=c, t=t).contiguous()
        y = fake_img[0, 0, 0, 0].item() # sync
        torch.cuda.synchronize()

    time.sleep(1)
    torch.cuda.reset_peak_memory_stats()

    with profiler.profile(record_shapes=True, use_cuda=True) as prof:
        for i in tqdm(range(NUM_PROFILE_ITERS), desc='Profiling'):
            torch.cuda.synchronize()
            start_time = time.time()

            with profiler.record_function("forward"):
                fake_img = G(z, c=c, t=t).contiguous()
                y = fake_img[0, 0, 0, 0].item() # sync

            torch.cuda.synchronize()
            times.append(time.time() - start_time)

    torch.cuda.empty_cache()

    num_imgs_processed = len(times) * batch_size
    total_time_spent = np.sum(times)
    bandwidth = num_imgs_processed / total_time_spent
    summary = prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)

    print(f'[Batch size: {batch_size}] Mean: {np.mean(times):.05f}s/it. Std: {np.std(times):.05f}s')
    print(f'[Batch size: {batch_size}] Imgs/sec: {bandwidth:.03f}')
    print(f'[Batch size: {batch_size}] Max mem: {torch.cuda.max_memory_allocated(DEVICE) / 2**30:<6.2f} gb')

    return bandwidth, summary
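
# Hydra entry point: instantiates the generator, profiles every batch size in
# BATCH_SIZES, and prints the per-op summary for the best-performing one.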
@hydra.main(config_path="../../configs", config_name="config.yaml")
def profile(cfg: DictConfig):
    recursive_instantiate(cfg)
    G = instantiate_G(cfg)
    bandwidths = []
    summaries = []

    print(f'Number of parameters: {sum(p.numel() for p in G.parameters())}')

    for batch_size in BATCH_SIZES:
        bandwidth, summary = profile_for_batch_size(G, cfg, batch_size)
        bandwidths.append(bandwidth)
        summaries.append(summary)

    best_batch_size_idx = int(np.argmax(bandwidths))
    print(f'------------ Best batch size is {BATCH_SIZES[best_batch_size_idx]} ------------')
    print(summaries[best_batch_size_idx])
if __name__ == '__main__':
    profile()