lhallee committed (verified)
Commit 1d5119c · Parent(s): 4209cf0

Upload dplm.py with huggingface_hub

Files changed (1):
  dplm.py (+810, -0)
dplm.py ADDED
@@ -0,0 +1,810 @@
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: Apache-2.0
"""
FastPLMs-compatible DPLM implementation.

This module is based on:
https://github.com/bytedance/dplm/blob/main/src/byprot/models/lm/esm_dplm.py
"""

import entrypoint_setup
import torch
import torch.nn as nn
from torch.nn import functional as F
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

from transformers import AutoTokenizer, EsmTokenizer
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    ModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from transformers.models.esm.configuration_esm import EsmConfig
from transformers.models.esm.modeling_esm import (
    EsmAttention,
    EsmClassificationHead,
    EsmContactPredictionHead,
    EsmEmbeddings,
    EsmEncoder,
    EsmIntermediate,
    EsmLayer,
    EsmLMHead,
    EsmOutput,
    EsmPooler,
    EsmPreTrainedModel,
    EsmSelfAttention,
    EsmSelfOutput,
)

try:
    from torch.nn.attention.flex_attention import create_block_mask
    from torch.nn.attention.flex_attention import flex_attention
except ImportError:
    create_block_mask = None
    flex_attention = None

try:
    from .base_tokenizer import BaseSequenceTokenizer
except ImportError:
    from base_tokenizer import BaseSequenceTokenizer

try:
    from .embedding_mixin import EmbeddingMixin
except ImportError:
    try:
        from ..embedding_mixin import EmbeddingMixin
    except ImportError:
        from embedding_mixin import EmbeddingMixin


def _create_pad_block_mask(attention_mask_2d: torch.Tensor):
    assert create_block_mask is not None, "Flex attention block mask requires create_block_mask."
    token_valid = attention_mask_2d.bool()
    batch_size, seq_len = token_valid.shape

    def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
        return token_valid[batch_idx, q_idx] & token_valid[batch_idx, kv_idx]

    return create_block_mask(
        mask_mod,
        batch_size,
        1,
        seq_len,
        seq_len,
        device=attention_mask_2d.device,
    )


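# Worked example for _create_pad_block_mask (above). The values are illustrative
# only: given a 2-D padding mask such as
#
#     attention_mask_2d = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
#
# mask_mod keeps a (q_idx, kv_idx) pair only when both positions are real tokens,
# so sequence 0 attends within its first three positions, sequence 1 within its
# first two, and padded rows/columns are excluded from the block mask.
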
@dataclass
class DPLMMaskedLMOutput(ModelOutput):
    loss: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None
    last_hidden_state: Optional[torch.Tensor] = None
    hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
    attentions: Optional[Tuple[torch.Tensor, ...]] = None


class DPLMConfig(EsmConfig):
    model_type = "dplm"

    def __init__(
        self,
        attn_backend: str = "sdpa",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.attn_backend = attn_backend
        self.tie_word_embeddings = False


class DPLMPreTrainedModel(EsmPreTrainedModel):
    config_class = DPLMConfig
    base_model_prefix = "dplm"
    supports_gradient_checkpointing = True
    tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
    all_tied_weights_keys = {}

    def get_input_embeddings(self) -> nn.Module:
        try:
            return self.embeddings.word_embeddings
        except AttributeError:
            return self.esm.embeddings.word_embeddings


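# Note on DPLMConfig.attn_backend: "sdpa" (the default) routes attention through
# F.scaled_dot_product_attention, while "flex" uses torch.nn.attention.flex_attention
# and is therefore only usable when that import succeeded, the inputs are
# float16/bfloat16, and neither cross-attention nor a KV cache is in play (see the
# asserts in ModifiedEsmSelfAttention.forward below). A hedged construction sketch,
# assuming an ESM-2 style checkpoint config is used as the starting point:
#
#     config = DPLMConfig.from_pretrained("facebook/esm2_t6_8M_UR50D", attn_backend="flex")
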
class ModifiedEsmSelfAttention(EsmSelfAttention):
    def __init__(self, config, position_embedding_type=None):
        super().__init__(config, position_embedding_type)
        self.attn_backend = config.attn_backend

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        flex_block_mask: Optional[object] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor]:
        if past_key_values is not None:
            past_key_value = past_key_values

        mixed_query_layer = self.query(hidden_states)
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer) * self.attention_head_size**-0.5

        if self.is_decoder:
            past_key_value = (key_layer, value_layer)

        if self.position_embedding_type == "rotary":
            query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)

        if self.position_embedding_type in ["relative_key", "relative_key_query"]:
            raise NotImplementedError

        query_layer = query_layer.contiguous()
        key_layer = key_layer.contiguous()
        value_layer = value_layer.contiguous()

        if output_attentions:
            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
            if attention_mask is not None:
                attention_scores = attention_scores + attention_mask
            attention_probs = F.softmax(attention_scores, dim=-1, dtype=torch.float32).to(query_layer.dtype)
            context_layer = torch.matmul(attention_probs, value_layer)
        else:
            attention_probs = None
            if self.attn_backend == "flex":
                assert flex_attention is not None, "Flex attention backend requested but torch.flex_attention is unavailable."
                assert query_layer.dtype in (torch.float16, torch.bfloat16), (
                    f"Flex attention backend requires float16 or bfloat16, got {query_layer.dtype}."
                )
                assert is_cross_attention is False, "Flex attention backend currently does not support cross-attention."
                assert past_key_value is None, "Flex attention backend currently does not support KV caching."
                if attention_mask is not None:
                    assert flex_block_mask is not None, (
                        "Flex attention backend requires a block mask when attention_mask is provided."
                    )
                context_layer = flex_attention(
                    query_layer,
                    key_layer,
                    value_layer,
                    block_mask=flex_block_mask,
                    scale=1.0,
                )
            else:
                context_layer = F.scaled_dot_product_attention(
                    query_layer,
                    key_layer,
                    value_layer,
                    attn_mask=attention_mask,
                    scale=1.0,
                )

        if head_mask is not None and torch.is_tensor(head_mask):
            context_layer = context_layer * head_mask

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs)
        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs


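# Scaling note for ModifiedEsmSelfAttention (above): query_layer is pre-multiplied
# by attention_head_size ** -0.5 before either backend is called, so flex_attention
# and F.scaled_dot_product_attention both receive scale=1.0; leaving their default
# scale would apply the 1/sqrt(head_dim) factor twice.
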
class ModifiedEsmAttention(EsmAttention):
    def __init__(self, config):
        nn.Module.__init__(self)
        self.self = ModifiedEsmSelfAttention(config)
        self.output = EsmSelfOutput(config)
        self.pruned_heads = set()
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        flex_block_mask=None,
    ):
        hidden_states_ln = self.LayerNorm(hidden_states)
        self_outputs = self.self(
            hidden_states_ln,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
            flex_block_mask=flex_block_mask,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]
        return outputs


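# Residual note for ModifiedEsmAttention (above): the input is normalized with
# self.LayerNorm before self-attention, and EsmSelfOutput then adds the residual
# from the un-normalized hidden_states, i.e. the pre-LayerNorm layout used by ESM-2.
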
class ModifiedEsmLayer(EsmLayer):
    def __init__(self, config):
        nn.Module.__init__(self)
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = ModifiedEsmAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            if self.is_decoder is False:
                raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added")
            self.crossattention = ModifiedEsmAttention(config)
        self.intermediate = EsmIntermediate(config)
        self.output = EsmOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        flex_block_mask=None,
    ):
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
            flex_block_mask=flex_block_mask,
        )
        attention_output = self_attention_outputs[0]

        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            outputs = self_attention_outputs[1:]

        if self.is_decoder and encoder_hidden_states is not None:
            if self.add_cross_attention is False:
                raise AttributeError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention "
                    "layers by setting `config.add_cross_attention=True`"
                )

            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
                flex_block_mask=None,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]
            present_key_value = present_key_value + cross_attention_outputs[-1]

        layer_output = self.feed_forward_chunk(attention_output)
        outputs = (layer_output,) + outputs

        if self.is_decoder:
            outputs = outputs + (present_key_value,)
        return outputs


class ModifiedEsmEncoder(EsmEncoder):
    def __init__(self, config):
        nn.Module.__init__(self)
        self.config = config
        self.layer = nn.ModuleList([ModifiedEsmLayer(config) for _ in range(config.num_hidden_layers)])
        self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        flex_block_mask=None,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
        next_decoder_cache = () if use_cache else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    flex_block_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    flex_block_mask,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache = next_decoder_cache + (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if self.emb_layer_norm_after:
            hidden_states = self.emb_layer_norm_after(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if return_dict is False:
            return tuple(
                value
                for value in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if value is not None
            )

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )


class DPLMModel(DPLMPreTrainedModel, EmbeddingMixin):
    config_class = DPLMConfig

    def __init__(self, config, add_pooling_layer=True):
        DPLMPreTrainedModel.__init__(self, config)
        self.config = config
        self.embeddings = EsmEmbeddings(config)
        self.encoder = ModifiedEsmEncoder(config)
        self.pooler = EsmPooler(config) if add_pooling_layer else None
        self.contact_head = EsmContactPredictionHead(
            in_features=config.num_hidden_layers * config.num_attention_heads,
            bias=True,
        )
        self.post_init()

    def _convert_head_mask_to_5d(self, head_mask: torch.Tensor, num_hidden_layers: int) -> torch.Tensor:
        if head_mask.dim() == 1:
            head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
        elif head_mask.dim() == 2:
            head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
        assert head_mask.dim() == 5, f"head_mask.dim != 5, got {head_mask.dim()}"
        head_mask = head_mask.to(dtype=self.dtype)
        return head_mask

    def get_head_mask(
        self,
        head_mask: Optional[torch.Tensor],
        num_hidden_layers: int,
        is_attention_chunked: bool = False,
    ) -> Union[torch.Tensor, List[None]]:
        if head_mask is None:
            return [None] * num_hidden_layers
        head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
        if is_attention_chunked:
            head_mask = head_mask.unsqueeze(-1)
        return head_mask

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        if attention_mask is None:
            attention_mask = input_ids.ne(self.config.pad_token_id)
        outputs = self(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=False,
            output_attentions=False,
            return_dict=True,
        )
        return outputs.last_hidden_state

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        attns = self(input_ids, attention_mask=attention_mask, output_attentions=True).attentions
        attns = torch.stack(attns, dim=1)
        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3)
        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(4)
        return self.contact_head(input_ids, attns)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)

        token_attention_mask = None
        if attention_mask.dim() == 2:
            token_attention_mask = attention_mask.bool()
            if self.config.attn_backend == "flex" and output_attentions is False:
                extended_attention_mask = None
            else:
                extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
        elif attention_mask.dim() == 4:
            if self.config.attn_backend == "flex" and output_attentions is False:
                extended_attention_mask = None
            else:
                extended_attention_mask = attention_mask
            if input_ids is not None:
                token_attention_mask = input_ids.ne(self.config.pad_token_id)
        else:
            raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}")

        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = encoder_attention_mask

        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_attention_mask = token_attention_mask
        if embedding_attention_mask is None and input_ids is not None:
            embedding_attention_mask = input_ids.ne(self.config.pad_token_id)

        flex_block_mask = None
        if (
            self.config.attn_backend == "flex"
            and token_attention_mask is not None
            and output_attentions is False
        ):
            assert create_block_mask is not None, (
                "Flex attention backend requested but torch.create_block_mask is unavailable."
            )
            flex_block_mask = _create_pad_block_mask(token_attention_mask)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=embedding_attention_mask,
            inputs_embeds=inputs_embeds,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            flex_block_mask=flex_block_mask,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if return_dict is False:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=None,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )


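# Hedged usage sketch for DPLMModel (above). The repository id is a placeholder
# and the call assumes a checkpoint that ships this module:
#
#     model = DPLMModel.from_pretrained("<dplm-checkpoint>")
#     batch = model.tokenizer(["MKTVRQERLK"], return_tensors="pt")
#     hidden = model._embed(batch["input_ids"], batch["attention_mask"])
#     # hidden: (batch_size, seq_len, hidden_size)
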
class DPLMForMaskedLM(DPLMPreTrainedModel, EmbeddingMixin):
    config_class = DPLMConfig

    def __init__(self, config, dropout: float = 0.1):
        config.hidden_dropout_prob = dropout
        DPLMPreTrainedModel.__init__(self, config)
        self.esm = DPLMModel(config, add_pooling_layer=False)
        self.lm_head = EsmLMHead(config)
        self.loss_fct = nn.CrossEntropyLoss()
        self.post_init()

        self.tokenizer = self.__class__.tokenizer
        if isinstance(config._name_or_path, str) and len(config._name_or_path) > 0:
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
            except Exception:
                self.tokenizer = self.__class__.tokenizer

        self.mask_id = self.tokenizer.mask_token_id
        self.pad_id = self.tokenizer.pad_token_id
        self.bos_id = self.tokenizer.cls_token_id
        self.eos_id = self.tokenizer.eos_token_id
        self.x_id = self.tokenizer.convert_tokens_to_ids("X")
        self.contact_head = None

    def get_output_embeddings(self):
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        self.lm_head.decoder = new_embeddings

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        return self.esm._embed(input_ids, attention_mask)

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        return self.esm.predict_contacts(input_ids, attention_mask=attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.Tensor] = None,
        decoder_attention_mask: Optional[torch.Tensor] = None,
        decoder_inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> Union[Tuple[torch.Tensor], DPLMMaskedLMOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if attention_mask is None and input_ids is not None:
            attention_mask = input_ids.ne(self.pad_id)

        outputs = self.esm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        sequence_output = outputs.last_hidden_state
        logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            loss = self.loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        if return_dict is False:
            output = (logits, sequence_output, outputs.hidden_states, outputs.attentions)
            if loss is not None:
                return (loss,) + output
            return output

        return DPLMMaskedLMOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=sequence_output,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


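# Hedged sketch of single-pass masked-token prediction with DPLMForMaskedLM
# (above). The repository id is a placeholder, and DPLM's actual generation
# procedure is iterative; this only illustrates one forward pass:
#
#     model = DPLMForMaskedLM.from_pretrained("<dplm-checkpoint>")
#     batch = model.tokenizer(["MKT<mask>RQERLK"], return_tensors="pt")
#     with torch.no_grad():
#         logits = model(**batch).logits
#     masked = batch["input_ids"] == model.mask_id
#     predicted_ids = logits[masked].argmax(dim=-1)
#     print(model.tokenizer.decode(predicted_ids))
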
class DPLMForSequenceClassification(DPLMPreTrainedModel, EmbeddingMixin):
    config_class = DPLMConfig

    def __init__(self, config):
        DPLMPreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels
        self.esm = DPLMModel(config, add_pooling_layer=False)
        self.classifier = EsmClassificationHead(config)
        self.mse = nn.MSELoss()
        self.ce = nn.CrossEntropyLoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.post_init()

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        return self.esm._embed(input_ids, attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        outputs = self.esm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        sequence_output = outputs.last_hidden_state
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                if self.num_labels == 1:
                    loss = self.mse(logits.squeeze(), labels.squeeze())
                else:
                    loss = self.mse(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss = self.bce(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


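# Problem-type note for DPLMForSequenceClassification (above): when
# config.problem_type is unset it is inferred from the labels, e.g. num_labels == 1
# selects "regression" (MSE), integer labels with num_labels > 1 select
# "single_label_classification" (cross-entropy), and anything else (e.g. float
# multi-hot labels) selects "multi_label_classification" (BCE-with-logits).
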
class DPLMForTokenClassification(DPLMPreTrainedModel, EmbeddingMixin):
    config_class = DPLMConfig

    def __init__(self, config):
        DPLMPreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels
        self.esm = DPLMModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.loss_fct = nn.CrossEntropyLoss()
        self.post_init()

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        return self.esm._embed(input_ids, attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        outputs = self.esm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        sequence_output = self.dropout(outputs.last_hidden_state)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
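
A minimal loading sketch for the classes above (hedged: the repository id is a placeholder, and it assumes the checkpoint's config.json maps the Auto classes to these DPLM classes via auto_map):

import torch
from transformers import AutoModelForMaskedLM

# Placeholder repo id; substitute the actual DPLM checkpoint.
model = AutoModelForMaskedLM.from_pretrained("<user>/<dplm-checkpoint>", trust_remote_code=True)
model.eval()

batch = model.tokenizer(["MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVAT"], return_tensors="pt")
with torch.no_grad():
    out = model(**batch)
print(out.logits.shape)             # (1, seq_len, vocab_size)
print(out.last_hidden_state.shape)  # (1, seq_len, hidden_size)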