File size: 3,655 Bytes
f075308
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""
This script measures generator inference throughput (imgs/sec) in eval mode
for each batch size in BATCH_SIZES and prints the profiler summary of the fastest one.
"""
import sys; sys.path.extend(['..', '.', 'src'])
import time

import numpy as np
import torch
import torch.nn as nn
import hydra
from hydra.experimental import initialize
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm
import torch.autograd.profiler as profiler

from src import dnnlib
from src.infra.utils import recursive_instantiate


# Device that the generator and all benchmark input tensors are placed on.
DEVICE = 'cuda'
# Batch sizes to benchmark; `profile` reports the one with the highest imgs/sec.
BATCH_SIZES = [32]
# Untimed forward passes executed before measurement starts.
NUM_WARMUP_ITERS = 5
# Timed forward passes executed under the autograd profiler.
NUM_PROFILE_ITERS = 25


def instantiate_G(cfg: DictConfig) -> nn.Module:
    """
    Build the generator described by `cfg` and prepare it for inference.

    The constructor arguments are assembled as nested `dnnlib.EasyDict`s and
    handed to `dnnlib.util.construct_class_by_name`, which instantiates
    'training.networks.Generator'. The resulting module is switched to eval
    mode, has gradients disabled and is moved to DEVICE.

    Values read from `cfg`:
      - cfg.model.generator['fmaps'] (default 0.5), scaled by 32768 to give `channel_base`
      - cfg.model.generator['mapping_net_n_layers'] (default 2)
      - cfg['num_fp16_res'] (default 0); fp16 settings are only passed when it is > 0
      - cfg['resolution'] (default 256)
    """
    generator_cfg = cfg.model.generator

    synthesis_kwargs = dnnlib.EasyDict(
        channel_base=int(generator_cfg.get('fmaps', 0.5) * 32768),
        channel_max=512,
    )
    mapping_kwargs = dnnlib.EasyDict(
        num_layers=generator_cfg.get('mapping_net_n_layers', 2),
    )

    # Mixed-precision options are forwarded only when explicitly requested.
    if cfg.get('num_fp16_res', 0) > 0:
        synthesis_kwargs.num_fp16_res = cfg.num_fp16_res
        synthesis_kwargs.conv_clamp = 256

    G_kwargs = dnnlib.EasyDict(
        class_name='training.networks.Generator',
        w_dim=512,
        mapping_kwargs=mapping_kwargs,
        synthesis_kwargs=synthesis_kwargs,
        cfg=generator_cfg,
        c_dim=0,
        img_resolution=cfg.get('resolution', 256),
        img_channels=3,
    )

    network = dnnlib.util.construct_class_by_name(**G_kwargs)
    return network.eval().requires_grad_(False).to(DEVICE)


@torch.no_grad()
def profile_for_batch_size(G: nn.Module, cfg: DictConfig, batch_size: int):
    """
    Measure forward-pass throughput of `G` for a single batch size.

    Runs NUM_WARMUP_ITERS untimed forward passes, then NUM_PROFILE_ITERS timed
    passes under `torch.autograd.profiler`. Prints the mean/std time per
    iteration, the images per second and the peak allocated CUDA memory.

    Args:
        G: generator exposing `z_dim` and `c_dim` and callable as `G(z, c=c, t=t)`.
        cfg: not used by this function; kept so the call signature stays stable.
        batch_size: number of samples generated per forward pass.

    Returns:
        A `(bandwidth, summary)` tuple: `bandwidth` is the number of generated
        images divided by the total measured time in seconds, and `summary` is
        the result of `prof.key_averages().table(...)` sorted by total CPU time
        and limited to 10 rows.
    """
    # The same inputs are reused for every iteration.
    z = torch.randn(batch_size, G.z_dim, device=DEVICE)
    c = torch.zeros(batch_size, G.c_dim, device=DEVICE)
    # NOTE(review): `t` looks like a 2-column auxiliary input of the generator — confirm its meaning.
    t = torch.zeros(batch_size, 2, device=DEVICE)
    times = []

    for _ in tqdm(range(NUM_WARMUP_ITERS), desc='Warming up'):
        torch.cuda.synchronize()
        fake_img = G(z, c=c, t=t).contiguous()
        fake_img[0, 0, 0, 0].item()  # sync: reading one element back to the host
        torch.cuda.synchronize()

    time.sleep(1)

    # Start peak-memory accounting from here so warm-up does not count.
    torch.cuda.reset_peak_memory_stats()

    with profiler.profile(record_shapes=True, use_cuda=True) as prof:
        for _ in tqdm(range(NUM_PROFILE_ITERS), desc='Profiling'):
            torch.cuda.synchronize()
            # perf_counter is monotonic and high-resolution, unlike time.time(),
            # which can jump when the system clock is adjusted.
            start_time = time.perf_counter()
            with profiler.record_function("forward"):
                fake_img = G(z, c=c, t=t).contiguous()
                fake_img[0, 0, 0, 0].item()  # sync: reading one element back to the host
            torch.cuda.synchronize()
            times.append(time.perf_counter() - start_time)

    torch.cuda.empty_cache()
    num_imgs_processed = len(times) * batch_size
    total_time_spent = np.sum(times)
    bandwidth = num_imgs_processed / total_time_spent
    summary = prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)

    print(f'[Batch size: {batch_size}] Mean: {np.mean(times):.05f}s/it. Std: {np.std(times):.05f}s')
    print(f'[Batch size: {batch_size}] Imgs/sec: {bandwidth:.03f}')
    print(f'[Batch size: {batch_size}] Max mem: {torch.cuda.max_memory_allocated(DEVICE) / 2**30:<6.2f} gb')

    return bandwidth, summary


@hydra.main(config_path="../../configs", config_name="config.yaml")
def profile(cfg: DictConfig):
    """
    Hydra entry point: benchmark the generator for every size in BATCH_SIZES.

    Passes `cfg` through `recursive_instantiate` (project helper), builds the
    generator, prints its parameter count, profiles each batch size in order
    and finally prints the batch size with the highest imgs/sec together with
    its profiler summary.
    """
    recursive_instantiate(cfg)
    G = instantiate_G(cfg)

    num_params = sum(param.numel() for param in G.parameters())
    print(f'Number of parameters: {num_params}')

    # Each entry is the (bandwidth, summary) pair for one batch size.
    results = [profile_for_batch_size(G, cfg, size) for size in BATCH_SIZES]
    bandwidths = [bandwidth for bandwidth, _ in results]
    summaries = [summary for _, summary in results]

    winner = int(np.argmax(bandwidths))
    print(f'------------ Best batch size is {BATCH_SIZES[winner]} ------------')
    print(summaries[winner])


# Run the benchmark only when executed as a script. `profile` is wrapped by
# `@hydra.main`, which is why it is called here without a `cfg` argument.
if __name__ == '__main__':
    profile()