lhallee committed
Commit 498ceff · verified · 1 Parent(s): 2c79b44

Upload modeling_dplm.py with huggingface_hub
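
The commit message indicates the file was pushed with the huggingface_hub client. A minimal sketch of how such an upload is typically performed is shown below; the repository id is a placeholder, not the actual target repo, and the call assumes an authenticated session (huggingface-cli login or HF_TOKEN).

from huggingface_hub import HfApi

api = HfApi()
api.upload_file(
    path_or_fileobj="modeling_dplm.py",      # local file to push
    path_in_repo="modeling_dplm.py",         # destination path in the repo
    repo_id="user/dplm-model",               # placeholder repository id
    commit_message="Upload modeling_dplm.py with huggingface_hub",
)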

Files changed (1)
  1. modeling_dplm.py +809 -0
modeling_dplm.py ADDED
@@ -0,0 +1,809 @@
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: Apache-2.0
"""
FastPLMs-compatible DPLM implementation.

This module is based on:
https://github.com/bytedance/dplm/blob/main/src/byprot/models/lm/esm_dplm.py
"""

import entrypoint_setup
import torch
import torch.nn as nn
from torch.nn import functional as F
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

from transformers import AutoTokenizer, EsmTokenizer
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    ModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from transformers.models.esm.configuration_esm import EsmConfig
from transformers.models.esm.modeling_esm import (
    EsmAttention,
    EsmClassificationHead,
    EsmContactPredictionHead,
    EsmEmbeddings,
    EsmEncoder,
    EsmIntermediate,
    EsmLayer,
    EsmLMHead,
    EsmOutput,
    EsmPooler,
    EsmPreTrainedModel,
    EsmSelfAttention,
    EsmSelfOutput,
)

try:
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention
except (ImportError, AttributeError):
    create_block_mask = None
    flex_attention = None

try:
    from .base_tokenizer import BaseSequenceTokenizer
except ImportError:
    from base_tokenizer import BaseSequenceTokenizer

from embedding_mixin import EmbeddingMixin


def _create_pad_block_mask(attention_mask_2d: torch.Tensor):
    assert create_block_mask is not None, "Flex attention block mask requires create_block_mask."
    token_valid = attention_mask_2d.bool()
    batch_size, seq_len = token_valid.shape

    def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
        return token_valid[batch_idx, q_idx] & token_valid[batch_idx, kv_idx]

    return create_block_mask(
        mask_mod,
        batch_size,
        1,
        seq_len,
        seq_len,
        device=attention_mask_2d.device,
    )


@dataclass
class DPLMMaskedLMOutput(ModelOutput):
    loss: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None
    last_hidden_state: Optional[torch.Tensor] = None
    hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
    attentions: Optional[Tuple[torch.Tensor, ...]] = None


class DPLMConfig(EsmConfig):
    model_type = "dplm"

    def __init__(
        self,
        attn_backend: str = "sdpa",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.attn_backend = attn_backend
        self.tie_word_embeddings = False


class DPLMPreTrainedModel(EsmPreTrainedModel):
    config_class = DPLMConfig
    base_model_prefix = "dplm"
    supports_gradient_checkpointing = True
    tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
    all_tied_weights_keys = {}


class ModifiedEsmSelfAttention(EsmSelfAttention):
    def __init__(self, config, position_embedding_type=None):
        super().__init__(config, position_embedding_type)
        self.attn_backend = config.attn_backend

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        flex_block_mask: Optional[object] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor]:
        if past_key_values is not None:
            past_key_value = past_key_values

        mixed_query_layer = self.query(hidden_states)
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer) * self.attention_head_size**-0.5

        if self.is_decoder:
            past_key_value = (key_layer, value_layer)

        if self.position_embedding_type == "rotary":
            query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)

        if self.position_embedding_type in ["relative_key", "relative_key_query"]:
            raise NotImplementedError

        query_layer = query_layer.contiguous()
        key_layer = key_layer.contiguous()
        value_layer = value_layer.contiguous()

        if output_attentions:
            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
            if attention_mask is not None:
                attention_scores = attention_scores + attention_mask
            attention_probs = F.softmax(attention_scores, dim=-1, dtype=torch.float32).to(query_layer.dtype)
            context_layer = torch.matmul(attention_probs, value_layer)
        else:
            attention_probs = None
            if self.attn_backend == "flex":
                assert flex_attention is not None, "Flex attention backend requested but torch.flex_attention is unavailable."
                assert query_layer.dtype in (torch.float16, torch.bfloat16), (
                    f"Flex attention backend requires float16 or bfloat16, got {query_layer.dtype}."
                )
                assert is_cross_attention is False, "Flex attention backend currently does not support cross-attention."
                assert past_key_value is None, "Flex attention backend currently does not support KV caching."
                if attention_mask is not None:
                    assert flex_block_mask is not None, (
                        "Flex attention backend requires a block mask when attention_mask is provided."
                    )
                context_layer = flex_attention(
                    query_layer,
                    key_layer,
                    value_layer,
                    block_mask=flex_block_mask,
                    scale=1.0,
                )
            else:
                context_layer = F.scaled_dot_product_attention(
                    query_layer,
                    key_layer,
                    value_layer,
                    attn_mask=attention_mask,
                    scale=1.0,
                )

        if head_mask is not None and torch.is_tensor(head_mask):
            context_layer = context_layer * head_mask

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs)
        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs


class ModifiedEsmAttention(EsmAttention):
    def __init__(self, config):
        nn.Module.__init__(self)
        self.self = ModifiedEsmSelfAttention(config)
        self.output = EsmSelfOutput(config)
        self.pruned_heads = set()
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        flex_block_mask=None,
    ):
        hidden_states_ln = self.LayerNorm(hidden_states)
        self_outputs = self.self(
            hidden_states_ln,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
            flex_block_mask=flex_block_mask,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]
        return outputs


class ModifiedEsmLayer(EsmLayer):
    def __init__(self, config):
        nn.Module.__init__(self)
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = ModifiedEsmAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            if self.is_decoder is False:
                raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added")
            self.crossattention = ModifiedEsmAttention(config)
        self.intermediate = EsmIntermediate(config)
        self.output = EsmOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        flex_block_mask=None,
    ):
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
            flex_block_mask=flex_block_mask,
        )
        attention_output = self_attention_outputs[0]

        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            outputs = self_attention_outputs[1:]

        if self.is_decoder and encoder_hidden_states is not None:
            if self.add_cross_attention is False:
                raise AttributeError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention "
                    "layers by setting `config.add_cross_attention=True`"
                )

            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
                flex_block_mask=None,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]
            present_key_value = present_key_value + cross_attention_outputs[-1]

        layer_output = self.feed_forward_chunk(attention_output)
        outputs = (layer_output,) + outputs

        if self.is_decoder:
            outputs = outputs + (present_key_value,)
        return outputs


class ModifiedEsmEncoder(EsmEncoder):
    def __init__(self, config):
        nn.Module.__init__(self)
        self.config = config
        self.layer = nn.ModuleList([ModifiedEsmLayer(config) for _ in range(config.num_hidden_layers)])
        self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        flex_block_mask=None,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
        next_decoder_cache = () if use_cache else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    flex_block_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    flex_block_mask,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache = next_decoder_cache + (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if self.emb_layer_norm_after:
            hidden_states = self.emb_layer_norm_after(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if return_dict is False:
            return tuple(
                value
                for value in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if value is not None
            )

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )


class DPLMModel(DPLMPreTrainedModel, EmbeddingMixin):
    config_class = DPLMConfig

    def get_input_embeddings(self) -> nn.Module:
        return self.embeddings.word_embeddings

    def __init__(self, config, add_pooling_layer=True):
        DPLMPreTrainedModel.__init__(self, config)
        self.config = config
        self.embeddings = EsmEmbeddings(config)
        self.encoder = ModifiedEsmEncoder(config)
        self.pooler = EsmPooler(config) if add_pooling_layer else None
        self.contact_head = EsmContactPredictionHead(
            in_features=config.num_hidden_layers * config.num_attention_heads,
            bias=True,
        )
        self.post_init()

    def _convert_head_mask_to_5d(self, head_mask: torch.Tensor, num_hidden_layers: int) -> torch.Tensor:
        if head_mask.dim() == 1:
            head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
        elif head_mask.dim() == 2:
            head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
        assert head_mask.dim() == 5, f"head_mask.dim != 5, got {head_mask.dim()}"
        head_mask = head_mask.to(dtype=self.dtype)
        return head_mask

    def get_head_mask(
        self,
        head_mask: Optional[torch.Tensor],
        num_hidden_layers: int,
        is_attention_chunked: bool = False,
    ) -> Union[torch.Tensor, List[None]]:
        if head_mask is None:
            return [None] * num_hidden_layers
        head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
        if is_attention_chunked:
            head_mask = head_mask.unsqueeze(-1)
        return head_mask

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        if attention_mask is None:
            attention_mask = input_ids.ne(self.config.pad_token_id)
        outputs = self(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=False,
            output_attentions=False,
            return_dict=True,
        )
        return outputs.last_hidden_state

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        attns = self(input_ids, attention_mask=attention_mask, output_attentions=True).attentions
        attns = torch.stack(attns, dim=1)
        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3)
        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(4)
        return self.contact_head(input_ids, attns)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)

        token_attention_mask = None
        if attention_mask.dim() == 2:
            token_attention_mask = attention_mask.bool()
            if self.config.attn_backend == "flex" and output_attentions is False:
                extended_attention_mask = None
            else:
                extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
        elif attention_mask.dim() == 4:
            if self.config.attn_backend == "flex" and output_attentions is False:
                extended_attention_mask = None
            else:
                extended_attention_mask = attention_mask
            if input_ids is not None:
                token_attention_mask = input_ids.ne(self.config.pad_token_id)
        else:
            raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}")

        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = encoder_attention_mask

        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_attention_mask = token_attention_mask
        if embedding_attention_mask is None and input_ids is not None:
            embedding_attention_mask = input_ids.ne(self.config.pad_token_id)

        flex_block_mask = None
        if (
            self.config.attn_backend == "flex"
            and token_attention_mask is not None
            and output_attentions is False
        ):
            assert create_block_mask is not None, (
                "Flex attention backend requested but torch.create_block_mask is unavailable."
            )
            flex_block_mask = _create_pad_block_mask(token_attention_mask)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=embedding_attention_mask,
            inputs_embeds=inputs_embeds,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            flex_block_mask=flex_block_mask,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if return_dict is False:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=None,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )


class DPLMForMaskedLM(DPLMPreTrainedModel, EmbeddingMixin):
    config_class = DPLMConfig

    def __init__(self, config, dropout: float = 0.1):
        config.hidden_dropout_prob = dropout
        DPLMPreTrainedModel.__init__(self, config)
        self.esm = DPLMModel(config, add_pooling_layer=False)
        self.lm_head = EsmLMHead(config)
        self.loss_fct = nn.CrossEntropyLoss()
        self.post_init()

        self.tokenizer = self.__class__.tokenizer
        if isinstance(config._name_or_path, str) and len(config._name_or_path) > 0:
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
            except Exception:
                self.tokenizer = self.__class__.tokenizer

        self.mask_id = self.tokenizer.mask_token_id
        self.pad_id = self.tokenizer.pad_token_id
        self.bos_id = self.tokenizer.cls_token_id
        self.eos_id = self.tokenizer.eos_token_id
        self.x_id = self.tokenizer.convert_tokens_to_ids("X")
        self.contact_head = None

    def get_input_embeddings(self) -> nn.Module:
        return self.esm.embeddings.word_embeddings

    def get_output_embeddings(self):
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        self.lm_head.decoder = new_embeddings

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        return self.esm._embed(input_ids, attention_mask)

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        return self.esm.predict_contacts(input_ids, attention_mask=attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.Tensor] = None,
        decoder_attention_mask: Optional[torch.Tensor] = None,
        decoder_inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> Union[Tuple[torch.Tensor], DPLMMaskedLMOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if attention_mask is None and input_ids is not None:
            attention_mask = input_ids.ne(self.pad_id)

        outputs = self.esm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        sequence_output = outputs.last_hidden_state
        logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            loss = self.loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        if return_dict is False:
            output = (logits, sequence_output, outputs.hidden_states, outputs.attentions)
            if loss is not None:
                return (loss,) + output
            return output

        return DPLMMaskedLMOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=sequence_output,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class DPLMForSequenceClassification(DPLMPreTrainedModel, EmbeddingMixin):
    config_class = DPLMConfig

    def get_input_embeddings(self) -> nn.Module:
        return self.esm.embeddings.word_embeddings

    def __init__(self, config):
        DPLMPreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels
        self.esm = DPLMModel(config, add_pooling_layer=False)
        self.classifier = EsmClassificationHead(config)
        self.mse = nn.MSELoss()
        self.ce = nn.CrossEntropyLoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.post_init()

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        return self.esm._embed(input_ids, attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        outputs = self.esm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        sequence_output = outputs.last_hidden_state
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                if self.num_labels == 1:
                    loss = self.mse(logits.squeeze(), labels.squeeze())
                else:
                    loss = self.mse(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss = self.bce(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class DPLMForTokenClassification(DPLMPreTrainedModel, EmbeddingMixin):
    config_class = DPLMConfig

    def get_input_embeddings(self) -> nn.Module:
        return self.esm.embeddings.word_embeddings

    def __init__(self, config):
        DPLMPreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels
        self.esm = DPLMModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.loss_fct = nn.CrossEntropyLoss()
        self.post_init()

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        return self.esm._embed(input_ids, attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        outputs = self.esm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        sequence_output = self.dropout(outputs.last_hidden_state)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
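
For context, a minimal usage sketch follows (not part of the uploaded file). It assumes the hosting repository registers these classes through auto_map so they can be loaded with trust_remote_code=True; the repository id and the example sequence are placeholders.

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

repo_id = "user/dplm-model"  # placeholder repository id
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForMaskedLM.from_pretrained(repo_id, trust_remote_code=True)
model.eval()

# Tokenize a toy protein sequence and run a masked-LM forward pass.
inputs = tokenizer("MKTAYIAKQR", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs)
print(out.logits.shape)  # (batch, sequence length incl. special tokens, vocab size)

# The config's attn_backend field defaults to "sdpa"; setting it to "flex" switches to the
# FlexAttention path, which requires torch.nn.attention.flex_attention and fp16/bf16 weights.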