# NOTE: removed non-source extraction artifact (file-size / commit-hash / line-number dump).
import sys
import torch
import torch.nn as nn
def _print(s):
    """Print ``s`` and flush stdout so the message appears immediately."""
    # print(flush=True) flushes the same stream the two-step print/flush did.
    print(s, flush=True)
def get_latents(model, tokenizer, sequence, device):
    """Encode ``sequence`` with ``model`` and return its final hidden states.

    The batch dimension added by the tokenizer is squeezed away, so the
    result is the last layer's hidden states for the single input sequence.
    NOTE(review): assumes the model output exposes ``hidden_states`` (e.g.
    configured with ``output_hidden_states=True``) — confirm with the caller.
    """
    batch = tokenizer(sequence, return_tensors="pt").to(device)
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        result = model(**batch)
    # Last layer's hidden states, batch dimension dropped.
    return result.hidden_states[-1].squeeze(0)
# General model freezing
def freeze_model(model: nn.Module):
    """Disable gradient updates for every parameter of ``model``."""
    # parameters() is recursive, so this covers all submodules too.
    for weight in model.parameters():
        weight.requires_grad = False
# For ProGen2 architecture
def apply_gptj_freezing(model, N_layers):
    """Freeze a GPT-J-style (ProGen2) model except the attention parameters
    of its last ``N_layers`` transformer blocks.

    Args:
        model: module exposing ``model.transformer.h`` (sequence of blocks,
            each with an ``attn`` submodule).
        N_layers: number of trailing blocks whose attention stays trainable.

    Raises:
        AssertionError: if the resulting frozen/unfrozen layer counts do not
            match the request (also fires when N_layers > len(transformer.h)).
    """
    def unfreeze_n_layers(model, N_layers):
        # Re-enable gradients on the attention stack of the last N_layers blocks.
        model_layers = len(model.transformer.h)
        for i, h in enumerate(model.transformer.h):
            if i >= model_layers - N_layers:
                # .parameters() is already recursive; the previous
                # .modules()/.parameters() nesting re-assigned every parameter
                # once per ancestor module for no benefit.
                for param in h.attn.parameters():
                    param.requires_grad = True

    def check_frozen_model(model, N_layers: int):
        """
        Verify that only the last N_layers of model.transformer.h are unfrozen.
        Source: https://github.com/enijkamp/progen2/blob/main/progen/modeling_progen.py
        """
        model_layers = len(model.transformer.h)
        frozen_layers = 0
        unfrozen_layers = 0
        for i, h in enumerate(model.transformer.h):
            if i >= model_layers - N_layers:  # should be unfrozen
                if any(param.requires_grad for param in h.parameters()):
                    unfrozen_layers += 1
                else:
                    print(f"Layer {i} has all parameters frozen, but it should be unfrozen.")
            else:  # should be frozen
                if any(param.requires_grad for param in h.parameters()):
                    print(f"Layer {i} is not frozen, but it should be frozen.")
                else:
                    frozen_layers += 1
        assert frozen_layers == model_layers - N_layers and unfrozen_layers == N_layers, \
            f"frozen layers: {frozen_layers}, unfrozen layers: {unfrozen_layers}"
        print(f"frozen layers: {frozen_layers}, unfrozen layers: {unfrozen_layers}")

    # Freeze everything first (equivalent to freeze_model, inlined so this
    # routine is self-contained), then thaw the tail attention and verify.
    for param in model.parameters():
        param.requires_grad = False
    unfreeze_n_layers(model, N_layers)
    check_frozen_model(model, N_layers)
# For RDM-based architectures
def apply_rdm_freezing(model: nn.Module, N_layers: int, model_type: str):
    """
    Freeze all layers except the attention key/query/value projections of the
    last N encoder layers, for esm-like architectures.

    Args:
        model (nn.Module): model to freeze
        N_layers (int): num encoder layers to unfreeze
        model_type (str): one of {"esm", "evoflow", "dplm"}

    Raises:
        ValueError: if model_type is not a supported architecture (raised
            before any requires_grad flag is touched).
        AssertionError: if the post-hoc frozen/unfrozen counts do not match.
    """
    # Resolve the encoder layer stack first so an unknown model_type fails
    # before we mutate any parameter flags.
    if model_type == "dplm":
        encoder_layers = model.net.esm.encoder.layer
    elif model_type in ("esm", "evoflow"):
        encoder_layers = model.esm.encoder.layer
    else:
        raise ValueError(f"Unknown model_type: {model_type}")

    def _kqv_projections(layer):
        # The three attention projection submodules this routine toggles.
        att = layer.attention.self
        return (att.key, att.query, att.value)

    def unfreeze_n_layers(layers, N_layers: int):
        # Re-enable gradients on the k/q/v projections of the last N_layers layers.
        model_layers = len(layers)
        for i, layer in enumerate(layers):
            if i >= model_layers - N_layers:
                for proj in _kqv_projections(layer):
                    for param in proj.parameters():
                        param.requires_grad = True

    def check_model(layers, N_layers: int):
        # Sanity check: exactly the last N_layers layers have trainable k/q/v.
        model_layers = len(layers)
        frozen_layers = 0
        unfrozen_layers = 0
        for i, layer in enumerate(layers):
            if i >= model_layers - N_layers:  # should be unfrozen
                trainable = any(
                    param.requires_grad
                    for proj in _kqv_projections(layer)
                    for param in proj.parameters()
                )
                if trainable:
                    unfrozen_layers += 1
                else:
                    print(f"layer {i} has all parameters frozen, but it should be unfrozen.")
            else:  # should be fully frozen
                if any(param.requires_grad for param in layer.parameters()):
                    print(f"layer {i} is not frozen, but it should be frozen.")
                else:
                    frozen_layers += 1
        assert (frozen_layers == model_layers - N_layers) and (unfrozen_layers == N_layers), \
            f"frozen layers: {frozen_layers}, unfrozen layers: {unfrozen_layers}"

    # Freeze every parameter (equivalent to freeze_model, inlined so this
    # routine is self-contained), then selectively thaw and verify.
    for param in model.parameters():
        param.requires_grad = False
    unfreeze_n_layers(encoder_layers, N_layers)
    check_model(encoder_layers, N_layers)