File size: 1,737 Bytes
a0d95b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
import torch.nn as nn

def count_parameters(model: nn.Module) -> int:
    """Print a per-parameter table for ``model`` and return its trainable element count.

    Each unique parameter tensor is counted at most once. Trainable parameters
    (``requires_grad=True``) contribute to both the returned trainable count and
    the overall total shown in the summary line; frozen parameters contribute
    only to the overall total and are not listed. Parameters whose name contains
    ``'bias'`` are counted but omitted from the printed table to keep it short.

    Args:
        model: The module whose parameters are inspected.

    Returns:
        The number of trainable parameter elements.
    """
    print(f'Counting params in {model.__class__.__name__}')
    trainable_params = 0
    all_params = 0

    # IDs of parameter tensors already counted, so a tensor shared between
    # modules is never counted twice. named_parameters() already deduplicates
    # by default; this guard keeps both totals correct regardless, and is
    # applied to frozen and trainable parameters alike.
    counted_param_ids = set()

    # Header and data rows must use identical field widths (60 / 25) so the
    # column separators line up.
    header = f"{'Parameter Name':^60} | {'Shape':^25} | {'Num Params':^20}"
    print(header)
    print("-" * len(header))

    for name, parameter in model.named_parameters():
        param_id = id(parameter)
        if param_id in counted_param_ids:
            print(f"Skipping shared parameter: {name}")
            continue
        counted_param_ids.add(param_id)

        num_params = parameter.numel()
        all_params += num_params

        if not parameter.requires_grad:
            # Frozen: counts toward the overall total only; not listed.
            continue

        trainable_params += num_params
        if 'bias' not in name:
            shape = list(parameter.shape)
            print(f"{name:<60} | {str(shape):<25} | {num_params:,}")

    print(f"Model: {model.__class__.__name__} Total Trainable Params: {trainable_params:,} / {all_params:,}")
    return trainable_params

def mark_iba_as_trainable_only(model, prefix='hypernetxs'):
    """Make only the parameters whose name contains ``prefix`` trainable.

    Walks ``model.named_parameters()`` once and sets ``requires_grad`` to
    ``True`` for every parameter whose qualified name contains ``prefix`` as a
    substring (anywhere in the dotted name, not only at its start), and to
    ``False`` for every other parameter. The model is modified in place and
    nothing is returned.

    Args:
        model: Module whose parameters are frozen / unfrozen.
        prefix: Substring identifying the parameters that stay trainable.
    """
    for param_name, param in model.named_parameters():
        param.requires_grad = prefix in param_name