Lolalb committed on
Commit
4fdd722
·
verified ·
1 Parent(s): 03f51a4

Upload AMPLIFY

Browse files
Files changed (4) hide show
  1. amplify.py +453 -0
  2. config.json +42 -0
  3. model.safetensors +3 -0
  4. rotary.py +28 -0
amplify.py ADDED
@@ -0,0 +1,453 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # From https://stackoverflow.com/a/23689767
2
+ # From https://github.com/pytorch/pytorch/issues/97899
3
+ # From https://github.com/facebookresearch/llama/blob/main/llama/model.py
4
+ import yaml
5
+ import os
6
+
7
+ import safetensors
8
+ import torch
9
+ from torch import nn
10
+
11
+ from torch.nn.functional import scaled_dot_product_attention
12
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func
13
+
14
+ from transformers import PreTrainedModel, PretrainedConfig
15
+ from transformers.modeling_outputs import MaskedLMOutput
16
+
17
+ from .rotary import precompute_freqs_cis, apply_rotary_emb
18
+ from .tokenizer import ProteinTokenizer
19
+
20
+
21
class DotDict(dict):
    """Dict whose items can also be read, written and deleted as attributes.

    Reading an attribute with no matching key yields ``None`` instead of
    raising; deleting an attribute with no matching key raises ``KeyError``.
    """

    def __getattr__(self, name):
        # Only consulted when regular attribute lookup fails; fall back to the key.
        return dict.get(self, name)

    def __setattr__(self, name, value):
        # Attribute assignment is stored as a dictionary item.
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        # Attribute deletion removes the dictionary item.
        dict.__delitem__(self, name)
27
+
28
+
29
class AMPLIFYConfig(PretrainedConfig):
    """Configuration holding the hyper-parameters of an AMPLIFY model.

    Every parameter has a default value. Extra keyword arguments are forwarded
    to ``PretrainedConfig``.

    Args:
        hidden_size: Width of the token embeddings and of the residual stream.
        num_hidden_layers: Number of encoder blocks stacked in the model.
        num_attention_heads: Number of attention heads per encoder block.
        intermediate_size: Nominal feed-forward width; the encoder blocks
            scale it by 2/3 and round up to a multiple of 8.
        embedding_init_range: Bound of the uniform init of embedding weights.
        decoder_init_range: Bound of the uniform init of linear weights.
        norm_eps: Epsilon passed to the RMSNorm layers.
        vocab_size: Number of tokens in the vocabulary.
        pad_token_id: Index of the padding token in the embedding table.
        max_length: Stored on the config; not read by the model code itself.
        max_protein_length: Rotary frequencies are precomputed for twice this
            many positions.
        base_scale: Initial scaling of the learnable scaling vectors of the
            normalized encoder block.
        normalized_transformer: Selects ``NEncoderBlock`` over ``EncoderBlock``.
    """

    model_type = "AMPLIFY"

    # All config parameters must have a default value.
    def __init__(
        self,
        hidden_size: int = 960,
        num_hidden_layers: int = 32,
        num_attention_heads: int = 15,
        intermediate_size: int = 3840,
        embedding_init_range: float = 0.02,
        decoder_init_range: float = 0.02,
        norm_eps: float = 1e-05,
        vocab_size: int = 32,
        pad_token_id: int = 0,
        max_length: int = 2048,
        max_protein_length: int = 50000,
        base_scale: float = 1.0 / (960.0**0.5),
        normalized_transformer: bool = False,
        **kwargs,
    ):
        # The parent initializer runs first so the values below take precedence
        # over anything it derives from ``kwargs``.
        super().__init__(**kwargs)

        # Architecture
        self.normalized_transformer = normalized_transformer
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.norm_eps = norm_eps
        self.base_scale = base_scale

        # Weight initialization
        self.embedding_init_range = embedding_init_range
        self.decoder_init_range = decoder_init_range

        # Vocabulary and sequence lengths
        self.vocab_size = vocab_size
        self.pad_token_id = pad_token_id
        self.max_length = max_length
        self.max_protein_length = max_protein_length
65
+
66
+
67
class EncoderBlock(nn.Module):
    """Pre-norm transformer encoder block with rotary attention and a SwiGLU feed-forward."""

    def __init__(self, config: "AMPLIFYConfig"):
        """Initialize an EncoderBlock.

        Args:
            config: Model configuration. The block reads ``hidden_size``,
                ``num_attention_heads``, ``intermediate_size`` and ``norm_eps``.
        """
        super().__init__()

        self.config = config
        self.d_head = config.hidden_size // config.num_attention_heads

        # Attention: fused query/key/value projection and output projection.
        self.qkv = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size * 3, bias=False)
        self.wo = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size, bias=False)

        # Feedforward network with SwiGLU
        # To keep the number of parameters and the amount of computation constant, we reduce the number of
        # hidden units by a factor of 2/3 (https://arxiv.org/pdf/2002.05202.pdf) and make it a multiple of 8 to
        # avoid RuntimeError due to misaligned operand
        multiple_of = 8
        intermediate_size = multiple_of * ((int(2 * config.intermediate_size / 3) + multiple_of - 1) // multiple_of)

        # ``c_fc`` produces both SwiGLU branches at once, hence the factor 2.
        self.c_fc = nn.Linear(config.hidden_size, 2 * intermediate_size, bias=False)
        self.silu = nn.SiLU()
        self.mlp_c_proj = nn.Linear(intermediate_size, config.hidden_size, bias=False)

        self.attention_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)
        self.ffn_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)

    @staticmethod
    def _eager_attention(xq, xk, xv, attention_mask):
        """Compute attention explicitly so the attention weights can be returned.

        Args:
            xq, xk, xv: Tensors laid out as (batch, length, heads, head_dim).
            attention_mask: Optional mask broadcastable to
                (batch, heads, length, length); non-zero entries are attended,
                zero entries are ignored (same convention as the SDPA branch).

        Returns:
            Tuple ``(attn, attn_weights)`` with ``attn`` laid out as
            (batch, length, heads, head_dim) and ``attn_weights`` as
            (batch, heads, length, length).
        """
        scores = xq.permute(0, 2, 1, 3) @ xk.permute(0, 2, 3, 1) / (xq.size(-1) ** 0.5)
        if attention_mask is not None:
            # Masked keys must get -inf *before* the softmax. Multiplying the
            # scores by the mask would only set them to 0, which still receives
            # probability mass and disagrees with the SDPA branch.
            scores = scores.masked_fill(~attention_mask.bool(), float("-inf"))
        attn_weights = scores.softmax(-1)
        attn = (attn_weights @ xv.permute(0, 2, 1, 3)).transpose(1, 2)
        return attn, attn_weights

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        output_attentions: bool,
        max_seqlen: int = None,
        cu_seqlens: torch.Tensor = None,
    ):
        """Apply self-attention and the feed-forward network, each with a residual connection.

        Args:
            x: Input of shape (batch, length, hidden_size).
            attention_mask: Optional mask; non-zero entries mark positions to
                attend to. Not used by the packed (flash-attention) branch.
            freqs_cis: Rotary frequencies passed to ``apply_rotary_emb``.
            output_attentions: If True (and the input is not packed), use the
                eager attention path and return the attention weights.
            max_seqlen: Longest sequence in the pack; used only with ``cu_seqlens``.
            cu_seqlens: Cumulative sequence lengths of a packed batch. When
                given, flash-attention's variable-length kernel is used.

        Returns:
            Tuple ``(x, attn_weights)``; ``attn_weights`` is None unless the
            eager path was taken.
        """
        batch_size, seq_len, _ = x.shape

        # Reshape for rotary embeddings
        xq, xk, xv = (
            self.qkv(self.attention_norm(x))
            .reshape(batch_size, seq_len, self.config.num_attention_heads, self.d_head * 3)
            .chunk(3, dim=-1)
        )
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis)

        # Attn block
        attn_weights = None

        # Flash attention if the tensors are packed
        if cu_seqlens is not None:
            attn = flash_attn_varlen_func(
                q=xq.squeeze(0),
                k=xk.squeeze(0),
                v=xv.squeeze(0),
                cu_seqlens_q=cu_seqlens.squeeze(),
                cu_seqlens_k=cu_seqlens.squeeze(),
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                dropout_p=0.0,
                causal=False,
            )

        # Eager attention if attention weights are needed in the output
        elif output_attentions:
            attn, attn_weights = self._eager_attention(xq, xk, xv, attention_mask)

        # SDPA will pick an appropriate backend otherwise
        else:
            attn = scaled_dot_product_attention(
                query=xq.transpose(1, 2),
                key=xk.transpose(1, 2),
                value=xv.transpose(1, 2),
                attn_mask=attention_mask.bool() if attention_mask is not None else None,
                dropout_p=0,
            ).transpose(1, 2)

        attn = self.wo(attn.reshape(batch_size, seq_len, self.config.num_attention_heads * self.d_head))

        # Residual stream
        x = x + attn

        # FFN block (SwiGLU: one half of the projection gates the other)
        uv = self.c_fc(self.ffn_norm(x))
        u, v = torch.chunk(uv, 2, dim=-1)
        x_mlp = u * self.silu(v)
        h_mlp = self.mlp_c_proj(x_mlp)

        # Residual stream
        x = x + h_mlp

        return x, attn_weights
177
+
178
+
179
+ class NEncoderBlock(nn.Module):
180
+ """Transformer encoder block."""
181
+
182
+ def __init__(self, config: AMPLIFYConfig):
183
+ """Initialize a EncoderBlock.
184
+
185
+ Args:
186
+ hidden_size (int): _description_
187
+ num_attention_heads (int): _description_
188
+ intermediate_size (int, optional): _description_. Defaults to 2048.
189
+ activation (str, optional): _description_. Defaults to "relu".
190
+ rms_norm (bool, optional): _description_. Defaults to True.
191
+ norm_eps (float, optional): _description_. Defaults to 1e-5.
192
+ pad_token_id (int, optional): _description_. Defaults to 0.
193
+ max_length (int, optional): _description_. Defaults to 2048.
194
+ """
195
+ super().__init__()
196
+
197
+ self.config = config
198
+ self.d_head = config.hidden_size // config.num_attention_heads
199
+
200
+ # Attention
201
+ self.qkv = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size * 3, bias=False)
202
+ self.wo = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size, bias=False)
203
+
204
+ # To keep the number of parameters and the amount of computation constant, we reduce the number of
205
+ # hidden units by a factor of 2/3 (https://arxiv.org/pdf/2002.05202.pdf) and make it a multiple of 8 to
206
+ # avoid RuntimeError due to misaligned operand
207
+ multiple_of = 8
208
+ intermediate_size = multiple_of * ((int(2 * config.intermediate_size / 3) + multiple_of - 1) // multiple_of)
209
+
210
+ # Feedforward network
211
+ self.c_fc = nn.Linear(config.hidden_size, 2 * intermediate_size, bias=False)
212
+ self.silu = nn.SiLU()
213
+ self.mlp_c_proj = nn.Linear(intermediate_size, config.hidden_size, bias=False)
214
+
215
+ # Normalized Transformer
216
+ self.attn_alpha_init_value = 0.05
217
+ self.attn_alpha_init_scaling = config.base_scale
218
+ self.attn_alpha = torch.nn.Parameter(self.attn_alpha_init_scaling * torch.ones(self.config.hidden_size))
219
+
220
+ self.mlp_alpha_init_value = 0.05
221
+ self.mlp_alpha_init_scaling = config.base_scale
222
+ self.mlp_alpha = torch.nn.Parameter(self.mlp_alpha_init_scaling * torch.ones(self.config.hidden_size))
223
+
224
+ self.sqk_init_value = 1.0
225
+ self.sqk_init_scaling = config.base_scale
226
+ self.sqk = torch.nn.Parameter(self.sqk_init_scaling * torch.ones(self.config.hidden_size))
227
+
228
+ self.suv_init_value = 1.0
229
+ self.suv_init_scaling = 1.0
230
+ self.suv = torch.nn.Parameter(self.suv_init_scaling * torch.ones(2 * 4 * config.hidden_size))
231
+
232
+ def forward(
233
+ self,
234
+ x: torch.Tensor,
235
+ attention_mask: torch.Tensor,
236
+ freqs_cis: torch.Tensor,
237
+ output_attentions: bool,
238
+ max_seqlen: int = None,
239
+ cu_seqlens: torch.Tensor = None,
240
+ ):
241
+ batch_size, seq_len, _ = x.shape
242
+
243
+ # Reshape for rotary embeddings
244
+ xq, xk, xv = self.qkv(x).reshape(batch_size, seq_len, self.config.num_attention_heads, self.d_head * 3).chunk(3, axis=-1)
245
+ xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
246
+
247
+ sqk = (self.sqk * (self.sqk_init_value / self.sqk_init_scaling)).reshape(
248
+ 1, 1, self.config.num_attention_heads, self.config.hidden_size // self.config.num_attention_heads
249
+ )
250
+ xq = sqk * self.justnorm(xq)
251
+ xk = sqk * self.justnorm(xk)
252
+
253
+ softmax_scale = (self.config.hidden_size / self.config.num_attention_heads) ** 0.5
254
+
255
+ # Attn block
256
+ attn_weights = None
257
+
258
+ # Flash attention if the tensors are packed
259
+ if cu_seqlens is not None:
260
+ attn = flash_attn_varlen_func(
261
+ q=xq.squeeze(0),
262
+ k=xk.squeeze(0),
263
+ v=xv.squeeze(0),
264
+ cu_seqlens_q=cu_seqlens,
265
+ cu_seqlens_k=cu_seqlens,
266
+ max_seqlen_q=max_seqlen,
267
+ max_seqlen_k=max_seqlen,
268
+ dropout_p=0.0,
269
+ causal=False,
270
+ softmax_scale=softmax_scale,
271
+ )
272
+
273
+ # Eager attention if attention weights are needed in the output
274
+ elif output_attentions:
275
+ attn_weights = xq.permute(0, 2, 1, 3) @ xk.permute(0, 2, 3, 1) / softmax_scale
276
+ if attention_mask is not None:
277
+ attn_weights = attn_weights + attention_mask.type(attn_weights.dtype)
278
+ attn_weights = attn_weights.softmax(-1)
279
+ attn = attn_weights @ xv.permute(0, 2, 1, 3)
280
+ attn = attn.transpose(1, 2)
281
+
282
+ # SDPA will pick an appropriate backend otherwise
283
+ else:
284
+ attn = scaled_dot_product_attention(
285
+ query=xq.transpose(1, 2),
286
+ key=xk.transpose(1, 2),
287
+ value=xv.transpose(1, 2),
288
+ attn_mask=attention_mask,
289
+ dropout_p=0,
290
+ scale=softmax_scale,
291
+ ).transpose(1, 2)
292
+
293
+ attn_scores = self.wo(attn.reshape(batch_size, seq_len, self.config.num_attention_heads * self.d_head))
294
+
295
+ lr = self.attn_alpha * (self.attn_alpha_init_value / self.attn_alpha_init_scaling)
296
+ lr = torch.abs(lr)
297
+
298
+ A_norm = self.justnorm(x) # normally, normalization is not needed
299
+ B_norm = self.justnorm(attn_scores)
300
+
301
+ # Residual stream
302
+ res = A_norm + lr * (B_norm - A_norm)
303
+ x = self.justnorm(res)
304
+
305
+ # FFN block
306
+ uv = self.c_fc(x)
307
+ suv = self.suv * ((self.suv_init_value / self.suv_init_scaling) * (self.config.hidden_size**0.5))
308
+ uv = suv * uv
309
+ u, v = torch.chunk(uv, 2, dim=-1)
310
+ x_mlp = u * self.silu(v)
311
+ h_mlp = self.mlp_c_proj(x_mlp)
312
+
313
+ lr = self.mlp_alpha * (self.mlp_alpha_init_value / self.mlp_alpha_init_scaling)
314
+ lr = torch.abs(lr)
315
+
316
+ A_norm = self.justnorm(x) # normally, normalization is not needed
317
+ B_norm = self.justnorm(h_mlp)
318
+
319
+ # Residual stream
320
+ res = A_norm + lr * (B_norm - A_norm)
321
+ x = self.justnorm(res)
322
+
323
+ return (x, attn_weights)
324
+
325
+ def justnorm(self, x):
326
+ return x / x.norm(p=2, dim=-1, keepdim=True)
327
+
328
+
329
class AMPLIFYPreTrainedModel(PreTrainedModel):
    """Base class tying AMPLIFY models to ``AMPLIFYConfig`` and defining weight initialization."""

    config_class = AMPLIFYConfig

    def _init_weights(self, module):
        """Draw the weights of linear and embedding layers from a symmetric uniform distribution.

        Linear layers use ``config.decoder_init_range`` as bound, embeddings use
        ``config.embedding_init_range``. Other module types are left untouched.
        """
        if isinstance(module, nn.Linear):
            bound = self.config.decoder_init_range
        elif isinstance(module, nn.Embedding):
            bound = self.config.embedding_init_range
        else:
            return
        module.weight.data.uniform_(-bound, bound)
337
+
338
+
339
class AMPLIFY(AMPLIFYPreTrainedModel):
    """The main model class.

    Embeds the input tokens, runs them through a stack of encoder blocks with
    rotary position embeddings and projects the result onto the vocabulary.

    Args:
        config (amplify.model.amplify.AMPLIFYConfig): model configuration, usually defined from the Hydra configuration.
    """

    def __init__(self, config: AMPLIFYConfig, **kwargs):
        super().__init__(config)

        self.config = config

        self.encoder = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)

        # The block type is chosen once for the whole stack.
        block_cls = NEncoderBlock if self.config.normalized_transformer else EncoderBlock
        self.transformer_encoder = nn.ModuleList()
        for _ in range(config.num_hidden_layers):
            self.transformer_encoder.append(block_cls(config))

        # The normalized transformer has no final normalization layer.
        if not self.config.normalized_transformer:
            self.layer_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)

        self.decoder = nn.Linear(config.hidden_size, config.vocab_size)

        freqs_cis = precompute_freqs_cis(config.hidden_size // config.num_attention_heads, config.max_protein_length * 2)

        # Ensures freqs_cis is moved to the same devices as the model. Non-persistent buffers are not saved in the state_dict.
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)

        # Initialize weights and apply final processing
        self.post_init()

    @classmethod
    def load(cls, checkpoint_path: str, config_path: str, vocab_path: str = None, tag: str = None):
        """Build a model and its tokenizer from a YAML configuration and a checkpoint.

        Args:
            checkpoint_path: A DeepSpeed checkpoint directory, a ``.safetensors``
                file or a ``.pt`` file.
            config_path: YAML file with ``model``, ``tokenizer`` and ``trainer``
                sections.
            vocab_path: If given, overrides ``tokenizer.vocab_path`` of the
                configuration.
            tag: Checkpoint tag forwarded to DeepSpeed when ``checkpoint_path``
                is a directory.

        Returns:
            Tuple ``(model, tokenizer)``.

        Raises:
            ValueError: If ``checkpoint_path`` is none of the supported formats.
        """
        with open(config_path, "r") as file:
            cfg = yaml.safe_load(file)

        if vocab_path is not None:
            cfg["tokenizer"]["vocab_path"] = vocab_path

        # ``cls`` rather than a hard-coded class so subclasses load as themselves.
        model = cls(AMPLIFYConfig(**cfg["model"], **cfg["tokenizer"]))

        if os.path.isdir(checkpoint_path):
            from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

            state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_path, tag=tag)
        elif checkpoint_path.endswith(".safetensors"):
            # Import the submodule explicitly: ``import safetensors`` alone does
            # not guarantee that ``safetensors.torch`` is loaded.
            from safetensors.torch import load_file

            state_dict = load_file(checkpoint_path)
        elif checkpoint_path.endswith(".pt"):
            # NOTE(security): ``torch.load`` unpickles the file; only load
            # checkpoints from trusted sources.
            # ``map_location`` lets checkpoints saved on GPU load on CPU-only hosts.
            state_dict = torch.load(checkpoint_path, map_location="cpu")
        else:
            raise ValueError("Expected checkpoint to be a deepspeed folder, `.pt`, or `.safetensors` file.")

        # Normalize parameter names: drop the ``_orig_mod.`` prefix and map the
        # ``ffn.w12`` / ``ffn.w3`` names onto ``c_fc`` / ``mlp_c_proj``.
        for key in list(state_dict.keys()):
            if key.startswith("_orig_mod."):
                new_key = key[len("_orig_mod.") :]
                state_dict[new_key] = state_dict.pop(key)
                key = new_key
            if "ffn.w12" in key:
                new_key = key.replace("ffn.w12", "c_fc")
                state_dict[new_key] = state_dict.pop(key)
            elif "ffn.w3" in key:
                new_key = key.replace("ffn.w3", "mlp_c_proj")
                state_dict[new_key] = state_dict.pop(key)

        model.load_state_dict(state_dict)
        tokenizer = ProteinTokenizer(**cfg["tokenizer"], max_length=cfg["trainer"]["train"]["max_length"])
        return model, tokenizer

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor = None,
        max_seqlen: int = None,
        cu_seqlens: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
    ):
        """Compute per-token vocabulary logits.

        Args:
            input_ids: Token indices of shape (Batch, Length). For packed
                sequences the batch dimension must be 1.
            position_ids: Optional positions used to index the precomputed
                rotary frequencies; defaults to ``0..Length-1`` for every row.
            max_seqlen: Longest sequence in the pack; required with ``cu_seqlens``.
            cu_seqlens: Cumulative sequence lengths of a packed batch.
            attention_mask: Optional mask of shape (Batch, Length); it is
                expanded to (Batch, Heads, Length, Length) before being passed
                to the encoder blocks.
            output_hidden_states: Collect the output of every encoder block.
            output_attentions: Collect the attention weights of every encoder
                block. Not supported for packed sequences.

        Returns:
            ``MaskedLMOutput`` with ``logits``, ``hidden_states`` and
            ``attentions``; the last two are empty lists when not requested.
        """
        # Initialize
        hidden_states, attentions = [], []

        # Expand and repeat: (Batch, Length) -> (Batch, Heads, Length, Length)
        if attention_mask is not None:
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).repeat(1, self.config.num_attention_heads, attention_mask.size(-1), 1)

        # Checks to be done if inputs are packed sequences
        if cu_seqlens is not None:
            assert not output_attentions, "Output attentions is not supported when sequences are packed."
            assert max_seqlen is not None, "Missing max_seqlen. It must be provided when cu_seqlens are not None."
            assert input_ids.shape[0] == 1, "Cumulative sequence lengths are provided but input_ids are not packed."
            assert input_ids.is_cuda, "Packing uses an implementation of flash-attention and is only supported on GPU."

        # RoPE
        if position_ids is not None:
            freqs_cis = self.freqs_cis[position_ids]
        else:
            freqs_cis = self.freqs_cis[: input_ids.shape[1]].unsqueeze(0).repeat(input_ids.shape[0], 1, 1)

        # Embedding
        x = self.encoder(input_ids)

        # Transformer encoder
        for layer in self.transformer_encoder:
            x, attn = layer(x, attention_mask, freqs_cis, output_attentions, max_seqlen, cu_seqlens)
            if output_hidden_states:
                hidden_states.append(x)
            if output_attentions:
                attentions.append(attn)

        # Classification head with layer norm
        logits = self.decoder(self.layer_norm(x) if not self.config.normalized_transformer else x)

        # Return logits or the output of the last hidden layer
        return MaskedLMOutput(logits=logits, hidden_states=hidden_states, attentions=attentions)
config.json ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_": "AMPLIFY",
3
+ "ambiguous_token_ids": [
4
+ 1,
5
+ 6,
6
+ 7,
7
+ 8,
8
+ 9,
9
+ 10,
10
+ 11
11
+ ],
12
+ "architectures": [
13
+ "AMPLIFY"
14
+ ],
15
+ "auto_map": {
16
+ "AutoConfig": "amplify.AMPLIFYConfig",
17
+ "AutoModel": "amplify.AMPLIFY"
18
+ },
19
+ "base_scale": 0.03227486121839514,
20
+ "bos_token_id": 3,
21
+ "decoder_init_range": 0.02,
22
+ "embedding_init_range": 0.02,
23
+ "eos_token_id": 4,
24
+ "hidden_size": 640,
25
+ "intermediate_size": 2560,
26
+ "mask_token_id": 2,
27
+ "max_length": 2048,
28
+ "max_protein_length": 50000,
29
+ "model_type": "AMPLIFY",
30
+ "norm_eps": 1e-05,
31
+ "normalized_transformer": false,
32
+ "num_attention_heads": 10,
33
+ "num_hidden_layers": 24,
34
+ "other_special_token_ids": null,
35
+ "pad_token_id": 0,
36
+ "remove_ambiguous": true,
37
+ "torch_dtype": "float32",
38
+ "transformers_version": "4.49.0",
39
+ "unk_token_id": 1,
40
+ "vocab_path": "/home/mila/l/lola.lebreton/AMPLIFY-private/conf/tokenizer/amplify_vocab.txt",
41
+ "vocab_size": 32
42
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:938da9b3999b2168c99d3629525df19020558e90e90a57a85964532a9ee6b286
3
+ size 473147704
rotary.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Tuple
3
+
4
+
5
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Precompute the complex rotary factors for positions ``0..end-1``.

    Args:
        dim: Head dimension; one frequency is produced per pair of channels.
        end: Number of positions to precompute.
        theta: Base of the geometric frequency progression.

    Returns:
        Complex tensor of shape ``(end, dim // 2)`` whose entries have unit
        magnitude and angle ``position * frequency``.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta**exponents)
    positions = torch.arange(end, device=inv_freq.device, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)
    # Unit-magnitude complex numbers e^{i * angle} (complex64).
    return torch.polar(torch.ones_like(angles), angles)
11
+
12
+
13
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Insert a singleton dimension at index 2 so ``freqs_cis`` broadcasts against ``x``.

    ``freqs_cis`` must have shape ``(x.shape[0], x.shape[1], x.shape[-1])``;
    the result has shape ``(x.shape[0], x.shape[1], 1, x.shape[-1])``.
    """
    expected_shape = (x.shape[0], x.shape[1], x.shape[-1])
    assert freqs_cis.shape == expected_shape
    contiguous = freqs_cis.contiguous()
    return contiguous.unsqueeze(2)
16
+
17
+
18
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate queries and keys by the given complex rotary factors.

    Adjacent pairs of the last dimension are interpreted as the real and
    imaginary parts of a complex number, multiplied by ``freqs_cis`` and
    unpacked again. The computation runs in float32 and the results are cast
    back to the dtypes of ``xq`` and ``xk``.

    Args:
        xq: Query tensor.
        xk: Key tensor.
        freqs_cis: Complex factors; must match the first two dimensions and the
            halved last dimension of ``xq`` (checked by ``reshape_for_broadcast``).

    Returns:
        Tuple of the rotated queries and keys.
    """

    def _as_complex(t: torch.Tensor) -> torch.Tensor:
        # Group the last dimension into (real, imag) pairs.
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_complex = _as_complex(xq)
    k_complex = _as_complex(xk)
    rotation = reshape_for_broadcast(freqs_cis, q_complex)
    rotated_q = torch.view_as_real(q_complex * rotation).flatten(3)
    rotated_k = torch.view_as_real(k_complex * rotation).flatten(3)
    return rotated_q.type_as(xq), rotated_k.type_as(xk)