import humanfriendly
import numpy as np
import torch


def get_human_readable_count(number: int) -> str:
    """Return a human-readable count string.

    Originated from:
    https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/core/memory.py

    Abbreviates an integer number with K, M, B, T for thousands, millions,
    billions and trillions, respectively.

    Examples:
        >>> get_human_readable_count(123)
        '123.00  '
        >>> get_human_readable_count(1234)  # (one thousand)
        '1.23 K'
        >>> get_human_readable_count(2e6)  # (two million)
        '2.00 M'
        >>> get_human_readable_count(3e9)  # (three billion)
        '3.00 B'
        >>> get_human_readable_count(4e12)  # (four trillion)
        '4.00 T'
        >>> get_human_readable_count(5e15)  # (more than a trillion)
        '5000.00 T'

    Args:
        number: a non-negative number.

    Returns:
        A string formatted according to the pattern described above.
    """
    assert number >= 0
    labels = [" ", "K", "M", "B", "T"]
    num_digits = int(np.floor(np.log10(number)) + 1 if number > 0 else 1)
    num_groups = int(np.ceil(num_digits / 3))
    # Don't abbreviate beyond trillions; larger numbers stay in units of T.
    num_groups = min(num_groups, len(labels))
    shift = -3 * (num_groups - 1)
    number = number * (10**shift)
    index = num_groups - 1
    return f"{number:.2f} {labels[index]}"
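
# Worked trace (illustrative), following the arithmetic above for 1234:
# num_digits = floor(log10(1234)) + 1 = 4, num_groups = ceil(4 / 3) = 2,
# shift = -3 * (2 - 1) = -3, so 1234 * 10**-3 = 1.234 is formatted with
# labels[1] == "K", yielding '1.23 K'.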


def to_bytes(dtype) -> int:
    """Return the size in bytes of one element of the given torch dtype."""
    # Read the element size from an empty tensor rather than parsing the bit
    # width out of str(dtype); string parsing breaks for dtypes whose names do
    # not end in a two-digit bit count (e.g. torch.int8, torch.bool).
    return torch.tensor([], dtype=dtype).element_size()
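
# Illustrative values for PyTorch's standard dtypes:
#     to_bytes(torch.float32) -> 4
#     to_bytes(torch.float16) -> 2
#     to_bytes(torch.int64) -> 8
#     to_bytes(torch.int8) -> 1
#     to_bytes(torch.bool) -> 1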


def model_summary(model: torch.nn.Module) -> str:
    """Return a human-readable summary of ``model``.

    The summary lists the model structure, the total and trainable parameter
    counts, the size in bytes of the trainable parameters, and the parameter
    dtype.
    """
    message = "Model structure:\n"
    message += str(model)
    tot_params = sum(p.numel() for p in model.parameters())
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    percent_trainable = "{:.1f}".format(num_params * 100.0 / tot_params)
    tot_params = get_human_readable_count(tot_params)
    num_params = get_human_readable_count(num_params)
    message += "\n\nModel summary:\n"
    message += f"    Class Name: {model.__class__.__name__}\n"
    message += f"    Total Number of model parameters: {tot_params}\n"
    message += (
        f"    Number of trainable parameters: {num_params} ({percent_trainable}%)\n"
    )
    # Size of the trainable parameters only, formatted e.g. as "12.34 MB".
    num_bytes = humanfriendly.format_size(
        sum(
            p.numel() * to_bytes(p.dtype) for p in model.parameters() if p.requires_grad
        )
    )
    message += f"    Size: {num_bytes}\n"
    # dtype of the first parameter; assumes all parameters share one dtype.
    dtype = next(iter(model.parameters())).dtype
    message += f"    Type: {dtype}"
    return message
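

# Usage sketch (illustrative; the toy model and its layer sizes are arbitrary
# assumptions): running this module directly prints the structure, parameter
# counts, trainable fraction, size, and dtype of a small feed-forward model.
if __name__ == "__main__":
    toy = torch.nn.Sequential(
        torch.nn.Linear(80, 256),
        torch.nn.ReLU(),
        torch.nn.Linear(256, 4),
    )
    print(model_summary(toy))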