| import torch | |
| import torch.nn as nn | |
| from torchprofile import profile_macs | |
| Byte = 8 | |
| KiB = 1024 * Byte | |
| MiB = 1024 * KiB | |
| GiB = 1024 * MiB | |
| def get_model_flops(model, inputs): | |
| num_macs = profile_macs(model, inputs) | |
| return num_macs | |
| def get_model_size(model: nn.Module, data_width=32): | |
| """ | |
| calculate the model size in bits | |
| :param data_width: #bits per element | |
| """ | |
| num_elements = 0 | |
| for param in model.parameters(): | |
| num_elements += param.numel() | |
| return num_elements * data_width | |
| def get_model_macs(model, inputs) -> int: | |
| return profile_macs(model, inputs) | |
| def get_sparsity(tensor: torch.Tensor) -> float: | |
| """ | |
| calculate the sparsity of the given tensor | |
| sparsity = #zeros / #elements = 1 - #nonzeros / #elements | |
| """ | |
| return 1 - float(tensor.count_nonzero()) / tensor.numel() | |
| def get_model_sparsity(model: nn.Module) -> float: | |
| """ | |
| calculate the sparsity of the given model | |
| sparsity = #zeros / #elements = 1 - #nonzeros / #elements | |
| """ | |
| num_nonzeros, num_elements = 0, 0 | |
| for param in model.parameters(): | |
| num_nonzeros += param.count_nonzero() | |
| num_elements += param.numel() | |
| return 1 - float(num_nonzeros) / num_elements | |
| def get_num_parameters(model: nn.Module, count_nonzero_only=False) -> int: | |
| """ | |
| calculate the total number of parameters of model | |
| :param count_nonzero_only: only count nonzero weights | |
| """ | |
| num_counted_elements = 0 | |
| for param in model.parameters(): | |
| if count_nonzero_only: | |
| num_counted_elements += param.count_nonzero() | |
| else: | |
| num_counted_elements += param.numel() | |
| return num_counted_elements | |