ProArd / utils /profile.py
smi08's picture
Upload folder using huggingface_hub
d008243 verified
from typing import Optional, Tuple
import torch
import torch.nn as nn
from torchprofile import profile_macs
# Public API of this module: parallel-wrapper detection, device lookup,
# trainable-parameter counting and MAC profiling.
__all__ = ["is_parallel", "get_module_device", "trainable_param_num", "inference_macs"]
def is_parallel(model: nn.Module) -> bool:
    """Tell whether *model* is wrapped by a data-parallel container.

    Args:
        model: Module to inspect.

    Returns:
        True if *model* is an instance of ``nn.parallel.DataParallel`` or
        ``nn.parallel.DistributedDataParallel`` (subclasses included),
        False otherwise.
    """
    wrapper_types = (
        nn.parallel.DataParallel,
        nn.parallel.DistributedDataParallel,
    )
    return isinstance(model, wrapper_types)
def get_module_device(module: nn.Module) -> torch.device:
    """Return the device on which *module* stores its tensors.

    The device of the first parameter is used. If the module has no
    parameters, the device of its first buffer is used instead. A module
    with neither (e.g. ``nn.ReLU``) has no device of its own and is reported
    as CPU, rather than leaking a bare ``StopIteration`` to the caller.

    Args:
        module: Module to inspect.

    Returns:
        The ``torch.device`` of the first parameter or buffer, or
        ``torch.device("cpu")`` when the module owns no tensors.
    """
    first = next(module.parameters(), None)
    if first is None:
        # Parameter-free modules (e.g. BatchNorm with affine=False) may still
        # hold buffers that pin them to a device.
        first = next(module.buffers(), None)
    if first is None:
        return torch.device("cpu")
    return first.device
def trainable_param_num(network: nn.Module, unit=1e6) -> float:
    """Count the parameters of *network* that require gradients.

    Args:
        network: Module whose parameters are counted.
        unit: Divisor applied to the raw element count (the default 1e6
            reports the result in millions).

    Returns:
        Total number of elements across all parameters with
        ``requires_grad=True``, divided by *unit*.
    """
    count = 0
    for param in network.parameters():
        if not param.requires_grad:
            continue
        count += param.numel()
    return count / unit
def inference_macs(
    network: nn.Module,
    args: Tuple = (),
    data_shape: Optional[Tuple] = None,
    unit: float = 1e6,
) -> float:
    """Count multiply-accumulate operations for one forward pass.

    The count comes from ``torchprofile.profile_macs`` and is taken with the
    network in eval mode. Afterwards, every submodule's train/eval mode is
    restored to exactly what it was before the call, even if profiling
    raises.

    Args:
        network: Model to profile. A DataParallel / DistributedDataParallel
            wrapper is replaced by its ``.module`` first.
        args: Positional inputs passed to the forward pass. Mutually
            exclusive with ``data_shape``.
        data_shape: If given, a single all-zeros tensor of this shape is
            created on the network's device and used as the only input.
        unit: Divisor applied to the raw MAC count (1e6 reports millions).

    Returns:
        The MAC count reported by ``profile_macs`` divided by ``unit``.

    Raises:
        ValueError: If both ``data_shape`` and a non-empty ``args`` are given.
    """
    if is_parallel(network):
        network = network.module

    if data_shape is not None:
        if len(args) > 0:
            raise ValueError("Please provide either data_shape or args tuple.")
        args = (torch.zeros(data_shape, device=get_module_device(network)),)

    # Snapshot every submodule's mode, not only the root's. Calling
    # network.train(flag) afterwards would force the same flag onto all
    # submodules, clobbering ones deliberately kept in a different mode
    # (e.g. frozen BatchNorm layers inside a model that is training).
    saved_modes = [(m, m.training) for m in network.modules()]
    network.eval()
    try:
        macs = profile_macs(network, args=args) / unit
    finally:
        # Assign the flag directly so each module gets back its own mode
        # without train() recursing into (and overwriting) its children.
        for m, was_training in saved_modes:
            m.training = was_training
    return macs