import sys
import torch
import torch.nn as nn


def _print(s):
    """Print and flush immediately so output appears in buffered/cluster logs."""
    print(s)
    sys.stdout.flush()


def get_latents(model, tokenizer, sequence, device):
    """Return the last-layer hidden states for a single input sequence."""
    tokens = tokenizer(sequence, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**tokens)
        # Last layer hidden states (assumes the model returns `hidden_states`,
        # e.g. it was loaded with output_hidden_states=True); drop the batch dim.
        embeds = outputs.hidden_states[-1].squeeze(0)
    return embeds
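

# Illustrative usage of get_latents (a hedged sketch, not tied to a specific
# checkpoint): the model/tokenizer names below are placeholders, and the model
# is assumed to expose Hugging Face-style outputs with `hidden_states` enabled.
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     tokenizer = AutoTokenizer.from_pretrained("<model-name>")
#     model = AutoModelForCausalLM.from_pretrained(
#         "<model-name>", output_hidden_states=True
#     ).to(device).eval()
#     latents = get_latents(model, tokenizer, "MKTAYIAKQR", device)  # (seq_len, hidden_dim)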



# General model freezing
def freeze_model(model: nn.Module):
    # Disable parameter updates for all layers
    for param in model.parameters():
        param.requires_grad = False



# For the ProGen2 (GPT-J-style) architecture
def apply_gptj_freezing(model: nn.Module, N_layers: int):
    def unfreeze_n_layers(model, N_layers):
        # Total number of transformer blocks
        model_layers = len(model.transformer.h)
        # Unfreeze only the attention modules of the last N_layers blocks
        for i, h in enumerate(model.transformer.h):
            if i >= model_layers - N_layers:
                for param in h.attn.parameters():
                    param.requires_grad = True

    def check_frozen_model(model, N_layers: int):
        """
        Verify that only the last N_layers of model.transformer.h are unfrozen.
        Source: https://github.com/enijkamp/progen2/blob/main/progen/modeling_progen.py
        """
        model_layers = len(model.transformer.h)
        frozen_layers = 0
        unfrozen_layers = 0
        for i, h in enumerate(model.transformer.h):
            if i >= model_layers - N_layers:  # should be unfrozen
                if any(param.requires_grad for param in h.parameters()):
                    unfrozen_layers += 1
                else:
                    print(f"Layer {i} has all parameters frozen, but it should be unfrozen.")
            else:  # should be frozen
                if any(param.requires_grad for param in h.parameters()):
                    print(f"Layer {i} is not frozen, but it should be frozen.")
                else:
                    frozen_layers += 1

        assert frozen_layers == model_layers - N_layers and unfrozen_layers == N_layers, \
            f"frozen layers: {frozen_layers}, unfrozen layers: {unfrozen_layers}"

        print(f"frozen layers: {frozen_layers}, unfrozen layers: {unfrozen_layers}")

    freeze_model(model)
    unfreeze_n_layers(model, N_layers)
    check_frozen_model(model, N_layers)
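

# Illustrative fine-tuning setup after freezing (a hedged sketch; `progen_model`
# is a placeholder for a loaded ProGen2 checkpoint). Only parameters left
# trainable by apply_gptj_freezing are handed to the optimizer.
#
#     apply_gptj_freezing(progen_model, N_layers=2)
#     trainable = [p for p in progen_model.parameters() if p.requires_grad]
#     optimizer = torch.optim.AdamW(trainable, lr=1e-4)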





# For RDM-based architectures
def apply_rdm_freezing(model: nn.Module, N_layers: int, model_type: str):
    """
    Freeze the whole model, then unfreeze the attention key/query/value
    projections of the last N_layers encoder layers (ESM-like architectures).

    Args:
        model (nn.Module): model to freeze
        N_layers (int): number of trailing encoder layers to unfreeze
        model_type (str): one of {"esm", "evoflow", "dplm"}
    """

    # choose encoder layers based on the model type
    if model_type == "dplm":
        encoder_layers = model.net.esm.encoder.layer
    elif model_type in ("esm", "evoflow"):
        encoder_layers = model.esm.encoder.layer
    else:
        raise ValueError(f"Unknown model_type: {model_type}")

    def unfreeze_n_layers(layers, N_layers: int):
        model_layers = len(layers)
        for i, layer in enumerate(layers):
            if i >= model_layers - N_layers:
                # Unfreeze only the key/query/value projections of self-attention
                attn = layer.attention.self
                for proj in (attn.key, attn.query, attn.value):
                    for param in proj.parameters():
                        param.requires_grad = True

    def check_model(layers, N_layers: int):
        model_layers = len(layers)
        frozen_layers = 0
        unfrozen_layers = 0

        for i, layer in enumerate(layers):
            if i >= model_layers - N_layers:  # should be unfrozen
                attn = layer.attention.self
                layer_unfrozen = any(
                    param.requires_grad
                    for proj in (attn.key, attn.query, attn.value)
                    for param in proj.parameters()
                )
                if layer_unfrozen:
                    unfrozen_layers += 1
                else:
                    print(f"Layer {i} has all parameters frozen, but it should be unfrozen.")
            else:  # should be frozen
                if any(param.requires_grad for param in layer.parameters()):
                    print(f"Layer {i} is not frozen, but it should be frozen.")
                else:
                    frozen_layers += 1

        assert (frozen_layers == model_layers - N_layers) and (unfrozen_layers == N_layers), \
            f"frozen layers: {frozen_layers}, unfrozen layers: {unfrozen_layers}"


    freeze_model(model)
    unfreeze_n_layers(encoder_layers, N_layers)
    check_model(encoder_layers, N_layers)
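

# Minimal smoke test (a sketch under assumptions): builds a tiny stand-in module
# that mimics the `model.esm.encoder.layer[i].attention.self.{key,query,value}`
# layout expected for model_type="esm"/"evoflow", then applies apply_rdm_freezing.
# No real checkpoint is required; this only exercises the freezing logic.
if __name__ == "__main__":
    class _SelfAttention(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.key = nn.Linear(dim, dim)
            self.query = nn.Linear(dim, dim)
            self.value = nn.Linear(dim, dim)

    class _Attention(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.self = _SelfAttention(dim)

    class _Layer(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.attention = _Attention(dim)
            self.ffn = nn.Linear(dim, dim)

    class _Encoder(nn.Module):
        def __init__(self, dim, n_layers):
            super().__init__()
            self.layer = nn.ModuleList([_Layer(dim) for _ in range(n_layers)])

    class _Esm(nn.Module):
        def __init__(self, dim, n_layers):
            super().__init__()
            self.encoder = _Encoder(dim, n_layers)

    class _ToyModel(nn.Module):
        def __init__(self, dim=8, n_layers=4):
            super().__init__()
            self.esm = _Esm(dim, n_layers)

    toy = _ToyModel()
    apply_rdm_freezing(toy, N_layers=2, model_type="esm")
    n_trainable = sum(p.numel() for p in toy.parameters() if p.requires_grad)
    _print(f"Trainable parameters after freezing: {n_trainable}")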