davidhd committed on
Commit
4422087
·
verified ·
1 Parent(s): 935d59e

Upload AMPLIFY

Files changed (5)
  1. amplify.py +297 -0
  2. config.json +2 -3
  3. rmsnorm.py +38 -0
  4. rotary.py +80 -0
  5. tokenizer.py +133 -0
amplify.py ADDED
@@ -0,0 +1,297 @@
+ # From https://stackoverflow.com/a/23689767
+ # From https://github.com/pytorch/pytorch/issues/97899
+ # From https://github.com/facebookresearch/llama/blob/main/llama/model.py
+ import yaml
+
+ import safetensors.torch  # import the submodule so safetensors.torch.load_file resolves
+ import torch
+ from torch import nn
+ from torch.nn.functional import scaled_dot_product_attention
+ from xformers.ops import SwiGLU, memory_efficient_attention
+
+ from .rmsnorm import RMSNorm
+ from .rotary import precompute_freqs_cis, apply_rotary_emb
+ from .tokenizer import ProteinTokenizer
+
+ from transformers import PreTrainedModel, PretrainedConfig
+ from transformers.modeling_outputs import MaskedLMOutput
+
+
+ class DotDict(dict):
+     """Dictionary that supports the dot notation to access attributes (similarly to HuggingFace)."""
+
+     __getattr__ = dict.get
+     __setattr__ = dict.__setitem__
+     __delattr__ = dict.__delitem__
+
+
+ class AMPLIFYConfig(PretrainedConfig):
+     model_type = "AMPLIFY"
+
+     # All config parameters must have a default value.
+     def __init__(
+         self,
+         hidden_size: int = 960,
+         num_hidden_layers: int = 32,
+         num_attention_heads: int = 15,
+         intermediate_size: int = 3840,
+         dropout_prob: float = 0,
+         embedding_init_range: float = 0.02,
+         decoder_init_range: float = 0.02,
+         rms_norm: bool = True,
+         norm_eps: float = 1e-05,
+         hidden_act: str = "SwiGLU",
+         layer_norm_after_embedding: bool = False,
+         layer_norm_before_last_layer: bool = True,
+         vocab_size: int = 27,
+         ffn_bias: bool = False,
+         att_bias: bool = False,
+         pad_token_id: int = 0,
+         max_length: int = 2048,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         self.hidden_size = hidden_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.intermediate_size = intermediate_size
+         self.dropout_prob = dropout_prob
+         self.embedding_init_range = embedding_init_range
+         self.decoder_init_range = decoder_init_range
+         self.rms_norm = rms_norm
+         self.norm_eps = norm_eps
+         self.hidden_act = hidden_act
+         self.layer_norm_after_embedding = layer_norm_after_embedding
+         self.layer_norm_before_last_layer = layer_norm_before_last_layer
+         self.vocab_size = vocab_size
+         self.ffn_bias = ffn_bias
+         self.att_bias = att_bias
+         self.pad_token_id = pad_token_id
+         self.max_length = max_length
+
+
+ class EncoderBlock(nn.Module):
+     """Transformer encoder block."""
+
+     def __init__(self, config: AMPLIFYConfig):
+         """Initialize an EncoderBlock.
+
+         Args:
+             config (AMPLIFYConfig): Model configuration, providing the hidden size, number of
+                 attention heads, intermediate size, dropout probability, activation function,
+                 normalization type (RMSNorm or LayerNorm) and epsilon, and the attention/FFN
+                 bias flags.
+         """
+         super().__init__()
+
+         self.config = config
+         self.d_head = config.hidden_size // config.num_attention_heads
+
+         # Attention
+         self.q = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size, bias=config.att_bias)
+         self.k = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size, bias=config.att_bias)
+         self.v = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size, bias=config.att_bias)
+         self.wo = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size, bias=config.att_bias)
+         self.resid_dropout = nn.Dropout(config.dropout_prob)
+
+         # Feedforward network
+         act = config.hidden_act.lower()
+         if act == "swiglu":
+             # To keep the number of parameters and the amount of computation constant, we reduce the number of
+             # hidden units by a factor of 2/3 (https://arxiv.org/pdf/2002.05202.pdf) and make it a multiple of 8 to
+             # avoid RuntimeError due to misaligned operand
+             multiple_of = 8
+             intermediate_size = int(2 * config.intermediate_size / 3)
+             intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of)
+             self.ffn = SwiGLU(config.hidden_size, intermediate_size, config.hidden_size, bias=config.ffn_bias)
+         elif act == "relu":
+             self.ffn = nn.Sequential(
+                 nn.Linear(config.hidden_size, config.intermediate_size, bias=config.ffn_bias),
+                 nn.ReLU(),
+                 nn.Linear(config.intermediate_size, config.hidden_size, bias=config.ffn_bias),
+             )
+         elif act == "gelu":
+             self.ffn = nn.Sequential(
+                 nn.Linear(config.hidden_size, config.intermediate_size, bias=config.ffn_bias),
+                 nn.GELU(),
+                 nn.Linear(config.intermediate_size, config.hidden_size, bias=config.ffn_bias),
+             )
+         else:
+             raise ValueError(f"Unsupported hidden_act: {config.hidden_act}")
+
+         self.attention_norm = (
+             RMSNorm(config.hidden_size, config.norm_eps)
+             if config.rms_norm
+             else nn.LayerNorm(config.hidden_size, config.norm_eps)
+         )
+         self.ffn_norm = (
+             RMSNorm(config.hidden_size, config.norm_eps)
+             if config.rms_norm
+             else nn.LayerNorm(config.hidden_size, config.norm_eps)
+         )
+
+         self.ffn_dropout = nn.Dropout(config.dropout_prob)
+
+     def forward(self, x: torch.Tensor, pad_mask: torch.Tensor, freqs_cis: torch.Tensor, output_attentions: bool):
+         attn, contact = self._att_block(self.attention_norm(x), pad_mask, freqs_cis, output_attentions)
+         x = x + attn
+         x = x + self._ff_block(self.ffn_norm(x))
+         return x, contact
+
+     def _att_block(self, x: torch.Tensor, pad_mask: torch.Tensor, freqs_cis: torch.Tensor, output_attentions: bool):
+         batch_size, seq_len, _ = x.shape
+         xq, xk, xv = self.q(x), self.k(x), self.v(x)
+
+         # Reshape for rotary embeddings
+         xq = xq.view(batch_size, seq_len, self.config.num_attention_heads, self.d_head)
+         xk = xk.view(batch_size, seq_len, self.config.num_attention_heads, self.d_head)
+         xv = xv.view(batch_size, seq_len, self.config.num_attention_heads, self.d_head)
+         xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
+
+         # Compute the attention weight
+         attn_weights = None
+         if output_attentions:
+             attn_weights = xq.permute(0, 2, 1, 3) @ xk.permute(0, 2, 3, 1) / (xq.size(-1) ** 0.5)
+             if pad_mask is not None:
+                 attn_weights = attn_weights + pad_mask
+             attn_weights = attn_weights.softmax(-1)
+
+         # Compute the attention using xformers if the tensors are on GPU
+         if x.is_cuda:
+             # Input and output are of dimension (B, M, H, K) where B is the batch size, M the sequence length,
+             # H the number of heads, and K the embedding size per head
+             attn = memory_efficient_attention(
+                 query=xq,
+                 key=xk,
+                 value=xv,
+                 attn_bias=pad_mask,
+                 p=self.config.dropout_prob if self.training else 0,
+             )
+         else:
+             # Input and output are of dimension (B, H, M, K)
+             attn = scaled_dot_product_attention(
+                 query=xq.transpose(1, 2),
+                 key=xk.transpose(1, 2),
+                 value=xv.transpose(1, 2),
+                 attn_mask=pad_mask,
+                 dropout_p=self.config.dropout_prob if self.training else 0,
+             ).transpose(1, 2)
+
+         attn_scores = self.wo(attn.reshape(batch_size, seq_len, self.config.num_attention_heads * self.d_head))
+         return (self.resid_dropout(attn_scores), attn_weights)
+
+     def _ff_block(self, x: torch.Tensor):
+         return self.ffn_dropout(self.ffn(x))
+
+
+ class AMPLIFYPreTrainedModel(PreTrainedModel):
+     config_class = AMPLIFYConfig
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             module.weight.data.uniform_(-self.config.decoder_init_range, self.config.decoder_init_range)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.Embedding):
+             module.weight.data.uniform_(-self.config.embedding_init_range, self.config.embedding_init_range)
+
+
+ class AMPLIFY(AMPLIFYPreTrainedModel):
+     """The main model class.
+
+     Args:
+         config (amplify.model.amplify.AMPLIFYConfig): model configuration, usually defined from the Hydra configuration.
+     """
+
+     def __init__(self, config: AMPLIFYConfig, **kwargs):
+         super().__init__(config)
+
+         self.config = config
+
+         self.encoder = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+
+         if config.layer_norm_after_embedding:
+             self.layer_norm_1 = (
+                 RMSNorm(config.hidden_size, config.norm_eps)
+                 if config.rms_norm
+                 else nn.LayerNorm(config.hidden_size, config.norm_eps)
+             )
+
+         self.transformer_encoder = nn.ModuleList()
+         for _ in range(config.num_hidden_layers):
+             self.transformer_encoder.append(EncoderBlock(config))
+
+         if config.layer_norm_before_last_layer:
+             self.layer_norm_2 = (
+                 RMSNorm(config.hidden_size, config.norm_eps)
+                 if config.rms_norm
+                 else nn.LayerNorm(config.hidden_size, config.norm_eps)
+             )
+
+         self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
+
+         self.freqs_cis = precompute_freqs_cis(config.hidden_size // config.num_attention_heads, config.max_length)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     @classmethod
+     def load(cls, checkpoint_path: str, config_path: str):
+         with open(config_path, "r") as file:
+             cfg = yaml.safe_load(file)
+
+         model = cls(AMPLIFYConfig(**cfg["model"], **cfg["tokenizer"]))
+
+         if checkpoint_path.endswith(".safetensors"):
+             state_dict = safetensors.torch.load_file(checkpoint_path)
+         elif checkpoint_path.endswith(".pt"):
+             state_dict = torch.load(checkpoint_path)
+         else:
+             raise ValueError("Expected checkpoint to be a `.pt` or `.safetensors` file.")
+
+         model.load_state_dict(state_dict)
+         tokenizer = ProteinTokenizer(**cfg["tokenizer"])
+         return model, tokenizer
+
+     def forward(self, src, pad_mask=None, output_hidden_states=False, output_attentions=False):
+         # Initialize
+         hidden_states, attentions = [], []
+
+         # Expand and repeat: (Batch, Length) -> (Batch, Heads, Length, Length)
+         if pad_mask is not None:
+             assert pad_mask.dtype != torch.bool and 1.0 not in pad_mask, "AMPLIFY expects an additive pad_mask"
+             pad_mask = (
+                 pad_mask.unsqueeze(1).unsqueeze(1).repeat(1, self.config.num_attention_heads, pad_mask.size(-1), 1)
+             )
+
+         # RoPE
+         self.freqs_cis = self.freqs_cis.to(src.device, non_blocking=True)
+         freqs_cis = self.freqs_cis[: src.shape[1]]
+
+         # Embedding
+         x = self.encoder(src)
+         if self.config.layer_norm_after_embedding:
+             x = self.layer_norm_1(x)
+
+         # Transformer encoder
+         for layer in self.transformer_encoder:
+             x, attn = layer(x, pad_mask, freqs_cis, output_attentions)
+             if output_hidden_states:
+                 hidden_states.append(x)
+             if output_attentions:
+                 attentions.append(attn)
+
+         # Classification head with layer norm
+         logits = self.decoder(self.layer_norm_2(x) if self.config.layer_norm_before_last_layer else x)
+
+         # Return logits or the output of the last hidden layer
+         return MaskedLMOutput(logits=logits, hidden_states=hidden_states, attentions=attentions)
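For orientation (not part of the commit), a minimal sketch of using the `AMPLIFY.load` classmethod above. It assumes the files are installed as a package importable as `amplify`; the local file names and the example protein sequence are hypothetical placeholders, and the YAML config must contain "model" and "tokenizer" sections:

import torch
from amplify import AMPLIFY  # assumption: files packaged under the name `amplify`

model, tokenizer = AMPLIFY.load("checkpoint.safetensors", "config.yaml")  # hypothetical paths
model.eval()

ids = tokenizer.encode(list("MSVLTPLLLRGLTGSARRLPVPRAK"))  # one token per residue
with torch.no_grad():
    out = model(ids.unsqueeze(0))  # (1, L) -> MaskedLMOutput
print(out.logits.shape)  # (1, L, vocab_size), where L already includes <bos>/<eos>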
config.json CHANGED
@@ -1,13 +1,12 @@
  {
    "_name_": "PLM",
-   "_name_or_path": "davidhd/SaAMPLIFY_120M",
    "architectures": [
      "AMPLIFY"
    ],
    "att_bias": false,
    "auto_map": {
-     "AutoConfig": "davidhd/SaAMPLIFY_120M--amplify.AMPLIFYConfig",
-     "AutoModel": "davidhd/SaAMPLIFY_120M--amplify.AMPLIFY"
+     "AutoConfig": "amplify.AMPLIFYConfig",
+     "AutoModel": "amplify.AMPLIFY"
    },
    "bos_token_id": 3,
    "decoder_init_range": 0.02,
rmsnorm.py ADDED
@@ -0,0 +1,38 @@
+ import torch
+ from torch import nn
+
+
+ class RMSNorm(nn.Module):
+     def __init__(self, dim: int, eps: float = 1e-6):
+         """
+         Initialize the RMSNorm normalization layer.
+
+         Args:
+             dim (int): The dimension of the input tensor.
+             eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+         Attributes:
+             eps (float): A small value added to the denominator for numerical stability.
+             weight (nn.Parameter): Learnable scaling parameter.
+         """
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def _norm(self, x):
+         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+     def forward(self, x):
+         """
+         Forward pass through the RMSNorm layer.
+
+         Args:
+             x (torch.Tensor): The input tensor.
+
+         Returns:
+             torch.Tensor: The output tensor after applying RMSNorm.
+         """
+         output = self._norm(x.float()).type_as(x)  # Avoids mixed precision issues as in https://github.com/chandar-lab/AMPLIFY/issues/19
+         return output * self.weight
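For intuition (a sketch, not part of the upload): with the default all-ones weight, RMSNorm divides each feature vector by the root mean square of its components, so the forward pass can be checked directly against x / sqrt(mean(x^2) + eps):

import torch
from rmsnorm import RMSNorm  # assumption: file importable as a top-level module

x = torch.randn(2, 5, 8)
norm = RMSNorm(dim=8)  # weight is initialized to ones
expected = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + norm.eps)
assert torch.allclose(norm(x), expected, atol=1e-6)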
rotary.py ADDED
@@ -0,0 +1,80 @@
+ import torch
+ from typing import Tuple
+
+
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+     """
+     Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
+
+     This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
+     and the end index 'end'. The 'theta' parameter scales the frequencies.
+     The returned tensor contains complex values in complex64 data type.
+
+     Args:
+         dim (int): Dimension of the frequency tensor.
+         end (int): End index for precomputing frequencies.
+         theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
+
+     Returns:
+         torch.Tensor: Precomputed frequency tensor with complex exponentials.
+     """
+     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+     t = torch.arange(end, device=freqs.device)  # type: ignore
+     freqs = torch.outer(t, freqs).float()  # type: ignore
+     return torch.polar(torch.ones_like(freqs), freqs)  # complex64
+
+
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+     """
+     Reshape frequency tensor for broadcasting it with another tensor.
+
+     This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
+     for the purpose of broadcasting the frequency tensor during element-wise operations.
+
+     Args:
+         freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
+         x (torch.Tensor): Target tensor for broadcasting compatibility.
+
+     Returns:
+         torch.Tensor: Reshaped frequency tensor.
+
+     Raises:
+         AssertionError: If the frequency tensor doesn't match the expected shape.
+         AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
+     """
+     ndim = x.ndim
+     assert 0 <= 1 < ndim
+     assert freqs_cis.shape == (x.shape[1], x.shape[-1])
+     shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+     return freqs_cis.view(*shape)
+
+
+ def apply_rotary_emb(
+     xq: torch.Tensor,
+     xk: torch.Tensor,
+     freqs_cis: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """
+     Apply rotary embeddings to input tensors using the given frequency tensor.
+
+     This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
+     frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
+     is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
+     returned as real tensors.
+
+     Args:
+         xq (torch.Tensor): Query tensor to apply rotary embeddings.
+         xk (torch.Tensor): Key tensor to apply rotary embeddings.
+         freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.
+
+     Returns:
+         Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
+     """
+     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+     freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+     xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
+     xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+     return xq_out.type_as(xq), xk_out.type_as(xk)
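A small shape sketch for the helpers above (dimensions are hypothetical): `precompute_freqs_cis` returns a complex64 tensor of shape (end, dim // 2), and `apply_rotary_emb` expects queries and keys shaped (batch, seq_len, heads, head_dim), with `freqs_cis` sliced to the current sequence length as done in `AMPLIFY.forward`:

import torch
from rotary import precompute_freqs_cis, apply_rotary_emb  # assumption: top-level import

head_dim, max_len = 64, 2048
freqs_cis = precompute_freqs_cis(head_dim, max_len)  # complex64, shape (2048, 32)

xq = torch.randn(2, 10, 15, head_dim)  # (batch, seq_len, heads, head_dim)
xk = torch.randn(2, 10, 15, head_dim)
xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis[:10])  # slice to seq_len
assert xq_rot.shape == xq.shape and xk_rot.shape == xk.shape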
tokenizer.py ADDED
@@ -0,0 +1,133 @@
+ import torch
+ from typing import List, Optional
+ from torch import Tensor
+
+
+ class ProteinTokenizer(object):
+     def __init__(
+         self,
+         vocab_path: str,
+         pad_token_id: int,
+         mask_token_id: int,
+         bos_token_id: int,
+         eos_token_id: int,
+         unk_token_id: int,
+         other_special_token_ids: Optional[List[int]],
+         **kwargs,
+     ):
+         """Vocabulary comprising the amino acids, and the special tokens <unk>, <bos>, <eos>, <pad> and <mask>.
+
+         Args:
+             vocab_path (str): Path to the vocabulary file to load.
+             pad_token_id (int): <PAD> token index.
+             mask_token_id (int): <MASK> token index.
+             bos_token_id (int): <BOS> token index.
+             eos_token_id (int): <EOS> token index.
+             unk_token_id (int): <UNK> token index.
+             other_special_token_ids (Optional[List[int]]): List of additional special tokens.
+         """
+         self._token_to_id = dict()
+         self._id_to_token = dict()
+
+         with open(vocab_path, "r") as vocab_file:
+             for i, token in enumerate(vocab_file):
+                 token = token.strip()
+                 self._token_to_id[token] = i
+                 self._id_to_token[i] = token
+
+         # Token strings are looked up by id via _id_to_token
+
+         # Padding token
+         self.pad_token_id = pad_token_id
+         self.pad_token = self._id_to_token.get(pad_token_id)
+
+         # Beginning and end of sequence
+         self.bos_token_id = bos_token_id
+         self.eos_token_id = eos_token_id
+         self.bos_token = self._id_to_token.get(bos_token_id)
+         self.eos_token = self._id_to_token.get(eos_token_id)
+
+         # Mask token
+         self.mask_token_id = mask_token_id
+         self.mask_token = self._id_to_token.get(mask_token_id)
+
+         # Unknown token
+         self.unk_token_id = unk_token_id
+         self.unk_token = self._id_to_token.get(unk_token_id)
+
+         # Set of all special token indices
+         self.special_token_ids = set()
+         self.special_token_ids.add(pad_token_id)
+         self.special_token_ids.add(mask_token_id)
+         self.special_token_ids.add(bos_token_id)
+         self.special_token_ids.add(eos_token_id)
+         self.special_token_ids.add(unk_token_id)
+         if other_special_token_ids is not None:
+             self.special_token_ids.update(other_special_token_ids)
+
+     def __len__(self) -> int:
+         return len(self._token_to_id)
+
+     def token_to_id(self, token: str) -> int:
+         return self._token_to_id.get(token, self.unk_token_id)
+
+     def id_to_token(self, index: int) -> str:
+         return self._id_to_token.get(index, self.unk_token)
+
+     def encode(
+         self,
+         tokens: List[str],
+         max_length: Optional[int] = None,
+         add_special_tokens: bool = True,
+         random_truncate: bool = True,
+         **kwargs,
+     ) -> Tensor:
+         """Encodes a list of tokens into a tensor of token indices.
+
+         Args:
+             tokens (List[str]): Sequence of tokens to encode.
+             max_length (Optional[int], optional): Truncate the sequence to the specified length. Defaults to None.
+             add_special_tokens (bool, optional): Add the special tokens <bos> and <eos> at the start and end. Defaults to True.
+             random_truncate (bool, optional): Truncate the sequence to a random subsequence if it is longer than
+                 max_length. Defaults to True.
+
+         Returns:
+             Tensor: Token indices.
+         """
+         token_ids = list(map(self.token_to_id, tokens))
+         if add_special_tokens:
+             token_ids = [self.bos_token_id] + token_ids + [self.eos_token_id]
+         if max_length is not None and max_length < len(token_ids):
+             if random_truncate:
+                 offset = int(torch.randint(0, len(token_ids) - max_length, (1,)).item())
+             else:
+                 offset = 0
+             token_ids = token_ids[offset : offset + max_length]
+         return torch.as_tensor(token_ids, dtype=torch.long)
+
+     def decode(
+         self,
+         token_ids: List[int],
+         skip_special_tokens: bool = True,
+         **kwargs,
+     ) -> str:
+         """Decodes a list or tensor of token ids into a string of space-separated tokens.
+
+         Args:
+             token_ids (List[int]): Token indices to decode.
+             skip_special_tokens (bool, optional): Skip the special tokens <bos> and <eos> at the start and end.
+                 Defaults to True.
+
+         Returns:
+             str: Protein sequence as space-separated tokens.
+         """
+         if torch.is_tensor(token_ids):
+             token_ids = token_ids.tolist()
+
+         if skip_special_tokens:
+             if len(token_ids) > 0 and token_ids[0] in self.special_token_ids:
+                 token_ids = token_ids[1:]
+             if len(token_ids) > 0 and token_ids[-1] in self.special_token_ids:
+                 token_ids = token_ids[:-1]
+
+         tokens = " ".join(map(self.id_to_token, token_ids))
+
+         return tokens
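An encode/decode round trip for the tokenizer above (a sketch; the vocabulary file and all token ids below are hypothetical, not the ones shipped with the model):

from tokenizer import ProteinTokenizer  # assumption: top-level import

# Hypothetical vocab.txt: one token per line, special tokens first.
with open("vocab.txt", "w") as f:
    f.write("\n".join(["<pad>", "<unk>", "<mask>", "<bos>", "<eos>"] + list("ACDEFGHIKLMNPQRSTVWY")))

tok = ProteinTokenizer(
    vocab_path="vocab.txt",
    pad_token_id=0, unk_token_id=1, mask_token_id=2, bos_token_id=3, eos_token_id=4,
    other_special_token_ids=None,
)
ids = tok.encode(list("MSKV"), max_length=8)  # <bos> M S K V <eos>, under max_length
print(tok.decode(ids))                        # "M S K V"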