File size: 1,179 Bytes
d008243
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torchprofile import profile_macs

__all__ = ["is_parallel", "get_module_device", "trainable_param_num", "inference_macs"]


def is_parallel(model: nn.Module) -> bool:
    """Return True when ``model`` is a DataParallel or DistributedDataParallel wrapper."""
    parallel_wrappers = (
        nn.parallel.DataParallel,
        nn.parallel.DistributedDataParallel,
    )
    return isinstance(model, parallel_wrappers)


def get_module_device(module: nn.Module) -> torch.device:
    """Return the device on which ``module`` lives.

    The device is taken from the first parameter of the module. When the
    module has no parameters, the first registered buffer is used instead;
    when it has neither, the CPU device is returned.

    Args:
        module: The module to inspect.

    Returns:
        The device of the first parameter (or buffer), or ``cpu`` when the
        module holds no tensors at all.
    """
    # next(..., None) avoids leaking a bare StopIteration to the caller when
    # the module has no parameters (e.g. an activation layer).
    tensor = next(module.parameters(), None)
    if tensor is None:
        tensor = next(module.buffers(), None)
    if tensor is None:
        return torch.device("cpu")
    return tensor.device


def trainable_param_num(network: nn.Module, unit=1e6) -> float:
    """Count the elements of all parameters that require gradients.

    Args:
        network: The module whose parameters are counted.
        unit: Divisor applied to the raw count (the default of 1e6 reports
            the result in millions).

    Returns:
        The number of trainable parameter elements divided by ``unit``.
    """
    total = 0
    for param in network.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total / unit


def inference_macs(
    network: nn.Module,
    args: Tuple = (),
    data_shape: Optional[Tuple] = None,
    unit: float = 1e6,
) -> float:
    """Profile ``network`` in eval mode with ``profile_macs``.

    The inputs are given either explicitly through ``args`` or implicitly
    through ``data_shape``, in which case a single all-zero tensor of that
    shape is created on the network's device.

    Args:
        network: The module to profile. A DataParallel or
            DistributedDataParallel wrapper is unwrapped first.
        args: Positional inputs forwarded to ``profile_macs``.
        data_shape: Shape of a dummy zero tensor used as the only input.
            Mutually exclusive with a non-empty ``args``.
        unit: Divisor applied to the value returned by ``profile_macs``.

    Returns:
        The result of ``profile_macs`` divided by ``unit``.

    Raises:
        ValueError: If both ``data_shape`` and a non-empty ``args`` are given.
    """
    if is_parallel(network):
        network = network.module
    if data_shape is not None:
        if len(args) > 0:
            raise ValueError("Please provide either data_shape or args tuple.")
        args = (torch.zeros(data_shape, device=get_module_device(network)),)

    # Snapshot the mode of every submodule: restoring with a single
    # network.train(flag) would overwrite submodules that were deliberately
    # kept in a different mode than the root (e.g. frozen layers in eval).
    saved_modes = [(sub, sub.training) for sub in network.modules()]
    network.eval()
    try:
        macs = profile_macs(network, args=args) / unit
    finally:
        # Restore even when profiling fails, so the caller's model is never
        # silently left in eval mode.
        for sub, was_training in saved_modes:
            sub.training = was_training
    return macs