import torch
import torch.nn as nn
def count_parameters(model: nn.Module) -> int:
    """Print a per-parameter breakdown of *model* and return the trainable total.

    Shared (tied) parameter tensors are counted exactly once, using the
    tensor object's ``id()`` to detect aliases.  Rows whose name contains
    ``'bias'`` are counted but omitted from the printed table to keep it
    short.

    Args:
        model: module whose parameters are tallied.

    Returns:
        Number of trainable parameters (deduplicated).  The final printed
        line also shows the grand total (trainable + frozen).
    """
    print(f'Counting params in {model.__class__.__name__}')
    total_params = 0  # trainable only, deduplicated
    grand_total = 0   # trainable + frozen, deduplicated
    # IDs of parameter tensors already counted (handles weight tying).
    counted_param_ids = set()
    # Shape column width matches the data rows below (was 20 vs 25 before,
    # which misaligned the header).
    print(f"{'Parameter Name':^60} | {'Shape':^25} | {'Num Params':^20}")
    print("-" * 110)
    for name, parameter in model.named_parameters():
        # Dedupe BEFORE the requires_grad split: previously a shared frozen
        # tensor was added to the grand total once per alias.
        param_id = id(parameter)
        if param_id in counted_param_ids:
            print(f"Skipping shared parameter: {name}")
            continue
        counted_param_ids.add(param_id)
        num_params = parameter.numel()
        if not parameter.requires_grad:
            # Frozen parameters count toward the grand total only.
            grand_total += num_params
            continue
        shape = list(parameter.shape)
        if 'bias' not in name:  # biases counted but not printed
            print(f"{name:<60} | {str(shape):<25} | {num_params:,}")
        total_params += num_params
        grand_total += num_params
    print(f"Model: {model.__class__.__name__} Total Trainable Params: {total_params:,} / {grand_total:,}")
    return total_params
def mark_iba_as_trainable_only(model, prefix='hypernetxs'):
    """Freeze every parameter of *model* except those matching *prefix*.

    A parameter is trainable iff its qualified name contains ``prefix``;
    all other parameters get ``requires_grad = False``.

    Args:
        model: module whose parameters are toggled in place.
        prefix: substring selecting the parameters to keep trainable.
    """
    for param_name, param in model.named_parameters():
        # Single boolean assignment replaces the freeze-then-unfreeze branches.
        param.requires_grad = prefix in param_name