Change AMPLIFY API to be compatible with AMPLIFY v1
src --> input_ids
pad_mask --> attention_mask
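
For reference, a minimal before/after sketch of the renamed signature. This is illustrative only: the import path, file paths, and token ids are assumptions; AMPLIFY.load and the keyword names come from the diff below.

# Illustrative usage sketch -- not part of this commit.
import torch
from amplify.model.amplify import AMPLIFY  # assumed import path

model, tokenizer = AMPLIFY.load("checkpoint.safetensors", "config.yaml")  # placeholder paths

input_ids = torch.tensor([[1, 5, 9, 2]])  # (Batch, Length) token ids

# Before: out = model(src=input_ids, pad_mask=None)
# After (AMPLIFY v1 / HuggingFace-style names):
out = model(input_ids=input_ids, attention_mask=None)
print(out.logits.shape)  # MaskedLMOutput, as returned by forward()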
amplify.py  CHANGED  (+35 -49)
@@ -16,7 +16,6 @@ from .tokenizer import ProteinTokenizer
 from transformers import PreTrainedModel, PretrainedConfig
 from transformers.modeling_outputs import MaskedLMOutput
 
-
 class DotDict(dict):
     """Dictionary that supports the dot notation to access attributes (similarly to HuggingFace)."""
 
@@ -24,10 +23,8 @@ class DotDict(dict):
     __setattr__ = dict.__setitem__
     __delattr__ = dict.__delitem__
 
-
 class AMPLIFYConfig(PretrainedConfig):
     model_type = "AMPLIFY"
-
     # All config parameters must have a default value.
     def __init__(
         self,
@@ -51,7 +48,7 @@ class AMPLIFYConfig(PretrainedConfig):
         **kwargs,
     ):
         super().__init__(**kwargs)
-
+
         self.hidden_size = hidden_size
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
@@ -69,7 +66,7 @@ class AMPLIFYConfig(PretrainedConfig):
         self.att_bias = att_bias
         self.pad_token_id = pad_token_id
         self.max_length = max_length
-
+
 
 class EncoderBlock(nn.Module):
     """Transformer encoder block."""
@@ -111,7 +108,12 @@ class EncoderBlock(nn.Module):
             multiple_of = 8
             intermediate_size = int(2 * config.intermediate_size / 3)
             intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of)
-            self.ffn = SwiGLU(config.hidden_size, intermediate_size, config.hidden_size, bias=config.ffn_bias)
+            self.ffn = SwiGLU(
+                config.hidden_size,
+                intermediate_size,
+                config.hidden_size,
+                bias=config.ffn_bias
+            )
         elif act == "relu":
             self.ffn = nn.Sequential(
                 nn.Linear(config.hidden_size, config.intermediate_size, bias=config.ffn_bias),
@@ -127,26 +129,18 @@ class EncoderBlock(nn.Module):
         else:
             raise ValueError(f"Unsupported hidden_act: {config.hidden_act}")
 
-        self.attention_norm = (
-            RMSNorm(config.hidden_size, config.norm_eps)
-            if config.rms_norm
-            else nn.LayerNorm(config.hidden_size, config.norm_eps)
-        )
-        self.ffn_norm = (
-            RMSNorm(config.hidden_size, config.norm_eps)
-            if config.rms_norm
-            else nn.LayerNorm(config.hidden_size, config.norm_eps)
-        )
+        self.attention_norm = RMSNorm(config.hidden_size, config.norm_eps) if config.rms_norm else nn.LayerNorm(config.hidden_size, config.norm_eps)
+        self.ffn_norm = RMSNorm(config.hidden_size, config.norm_eps) if config.rms_norm else nn.LayerNorm(config.hidden_size, config.norm_eps)
 
         self.ffn_dropout = nn.Dropout(config.dropout_prob)
 
-    def forward(self, x: torch.Tensor, pad_mask: torch.Tensor, freqs_cis: torch.Tensor, output_attentions: bool):
-        attn, contact = self._att_block(self.attention_norm(x), pad_mask, freqs_cis, output_attentions)
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor, freqs_cis: torch.Tensor, output_attentions: bool):
+        attn, contact = self._att_block(self.attention_norm(x), attention_mask, freqs_cis, output_attentions)
         x = x + attn
         x = x + self._ff_block(self.ffn_norm(x))
         return x, contact
 
-    def _att_block(self, x: torch.Tensor, pad_mask: torch.Tensor, freqs_cis: torch.Tensor, output_attentions: bool):
+    def _att_block(self, x: torch.Tensor, attention_mask: torch.Tensor, freqs_cis: torch.Tensor, output_attentions: bool):
         batch_size, seq_len, _ = x.shape
         xq, xk, xv = self.q(x), self.k(x), self.v(x)
 
@@ -160,8 +154,8 @@ class EncoderBlock(nn.Module):
         attn_weights = None
         if output_attentions:
             attn_weights = xq.permute(0, 2, 1, 3) @ xk.permute(0, 2, 3, 1) / (xq.size(-1) ** 0.5)
-            if pad_mask is not None:
-                attn_weights = attn_weights + pad_mask
+            if attention_mask is not None:
+                attn_weights = attn_weights + attention_mask
             attn_weights = attn_weights.softmax(-1)
 
         # Compute the attention using xformers if the tensors are on GPU
@@ -172,7 +166,7 @@ class EncoderBlock(nn.Module):
                 query=xq,
                 key=xk,
                 value=xv,
-                attn_bias=pad_mask,
+                attn_bias=attention_mask,
                 p=self.config.dropout_prob if self.training else 0,
             )
         else:
@@ -181,13 +175,13 @@ class EncoderBlock(nn.Module):
                 query=xq.transpose(1, 2),
                 key=xk.transpose(1, 2),
                 value=xv.transpose(1, 2),
-                attn_mask=pad_mask,
+                attn_mask=attention_mask,
                 dropout_p=self.config.dropout_prob if self.training else 0,
             ).transpose(1, 2)
 
         attn_scores = self.wo(attn.reshape(batch_size, seq_len, self.config.num_attention_heads * self.d_head))
         return (self.resid_dropout(attn_scores), attn_weights)
-
+
     def _ff_block(self, x: torch.Tensor):
         return self.ffn_dropout(self.ffn(x))
 
@@ -207,10 +201,9 @@ class AMPLIFYPreTrainedModel(PreTrainedModel):
 class AMPLIFY(AMPLIFYPreTrainedModel):
     """The main model class.
 
-
-
+    Args:
+        config (amplify.model.amplify.AMPLIFYConfig): model configuration, usually defined from the Hydra configuration.
     """
-
     def __init__(self, config: AMPLIFYConfig, **kwargs):
         super().__init__(config)
 
@@ -219,30 +212,23 @@ class AMPLIFY(AMPLIFYPreTrainedModel):
         self.encoder = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
 
         if config.layer_norm_after_embedding:
-            self.layer_norm_1 = (
-                RMSNorm(config.hidden_size, config.norm_eps)
-                if config.rms_norm
-                else nn.LayerNorm(config.hidden_size, config.norm_eps)
-            )
+            self.layer_norm_1 = RMSNorm(config.hidden_size, config.norm_eps) if config.rms_norm else nn.LayerNorm(config.hidden_size, config.norm_eps)
 
         self.transformer_encoder = nn.ModuleList()
         for _ in range(config.num_hidden_layers):
             self.transformer_encoder.append(EncoderBlock(config))
 
         if config.layer_norm_before_last_layer:
-            self.layer_norm_2 = (
-                RMSNorm(config.hidden_size, config.norm_eps)
-                if config.rms_norm
-                else nn.LayerNorm(config.hidden_size, config.norm_eps)
-            )
+            self.layer_norm_2 = RMSNorm(config.hidden_size, config.norm_eps) if config.rms_norm else nn.LayerNorm(config.hidden_size, config.norm_eps)
 
         self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
 
         self.freqs_cis = precompute_freqs_cis(config.hidden_size // config.num_attention_heads, config.max_length)
-
+
         # Initialize weights and apply final processing
         self.post_init()
 
+
     @classmethod
     def load(cls, checkpoint_path: str, config_path: str):
 
@@ -254,7 +240,7 @@ class AMPLIFY(AMPLIFYPreTrainedModel):
         if checkpoint_path.endswith(".safetensors"):
             state_dict = safetensors.torch.load_file(checkpoint_path)
         elif checkpoint_path.endswith(".pt"):
-            state_dict = torch.load(checkpoint_path)
+            state_dict = torch.load(checkpoint_path, map_location="cpu")
         else:
             raise ValueError(f"Expected checkpoint to be a `.pt` or `.safetensors` file.")
 
@@ -262,29 +248,28 @@ class AMPLIFY(AMPLIFYPreTrainedModel):
         tokenizer = ProteinTokenizer(**cfg["tokenizer"])
         return model, tokenizer
 
-    def forward(self, src, pad_mask=None, output_hidden_states=False, output_attentions=False):
+
+    def forward(self, input_ids, attention_mask=None, output_hidden_states=False, output_attentions=False):
         # Initialize
         hidden_states, attentions = [], []
 
         # Expand and repeat: (Batch, Length) -> (Batch, Heads, Length, Length)
-        if pad_mask is not None:
-            assert pad_mask.dtype != torch.bool and 1.0 not in pad_mask, "AMPLIFY expects an additive pad_mask"
-            pad_mask = (
-                pad_mask.unsqueeze(1).unsqueeze(1).repeat(1, self.config.num_attention_heads, pad_mask.size(-1), 1)
-            )
+        if attention_mask is not None:
+            assert attention_mask.dtype != torch.bool and 1.0 not in attention_mask, "AMPLIFY expects an additive attention_mask"
+            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).repeat(1, self.config.num_attention_heads, attention_mask.size(-1), 1)
 
         # RoPE
-        self.freqs_cis = self.freqs_cis.to(src.device, non_blocking=True)
-        freqs_cis = self.freqs_cis[: src.shape[1]]
+        self.freqs_cis = self.freqs_cis.to(input_ids.device, non_blocking=True)
+        freqs_cis = self.freqs_cis[: input_ids.shape[1]]
 
         # Embedding
-        x = self.encoder(src)
+        x = self.encoder(input_ids)
         if self.config.layer_norm_after_embedding:
             x = self.layer_norm_1(x)
 
         # Transformer encoder
         for layer in self.transformer_encoder:
-            x, attn = layer(x, pad_mask, freqs_cis, output_attentions)
+            x, attn = layer(x, attention_mask, freqs_cis, output_attentions)
             if output_hidden_states:
                 hidden_states.append(x)
             if output_attentions:
@@ -295,3 +280,4 @@ class AMPLIFY(AMPLIFYPreTrainedModel):
 
         # Return logits or the output of the last hidden layer
         return MaskedLMOutput(logits=logits, hidden_states=hidden_states, attentions=attentions)
+
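Note on the assert added in forward(): attention_mask is expected to be additive (a float mask with 0.0 at positions to attend and a large negative value at padded positions), not a HuggingFace-style boolean/0-1 mask; that is what the dtype != torch.bool and 1.0 not in attention_mask check enforces. Below is a minimal sketch of building such a mask, assuming pad token id 0 and a model loaded as in the sketch above.

# Illustrative sketch -- not part of this commit.
import torch

pad_token_id = 0  # assumption; the real value lives in AMPLIFYConfig.pad_token_id
input_ids = torch.tensor([[4, 7, 7, 0, 0]])  # (Batch, Length); trailing 0s are padding

# Additive mask: 0.0 where attention is allowed, -inf at padded positions.
attention_mask = torch.zeros_like(input_ids, dtype=torch.float)
attention_mask = attention_mask.masked_fill(input_ids == pad_token_id, float("-inf"))

# forward() itself expands this (Batch, Length) mask to (Batch, Heads, Length, Length).
out = model(input_ids=input_ids, attention_mask=attention_mask)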