lhallee committed
Commit ab0be26 · verified · 1 parent: aee957f

Upload modeling_dplm2.py with huggingface_hub

Files changed (1): modeling_dplm2.py (+918, −0)
modeling_dplm2.py (new file):
"""
FastPLMs-compatible DPLM2 implementation.

This module is based on:
https://github.com/bytedance/dplm
"""

import entrypoint_setup
import torch
import torch.nn as nn
from torch.nn import functional as F
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

from transformers import EsmTokenizer
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    ModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from transformers.models.esm.configuration_esm import EsmConfig
from transformers.models.esm.modeling_esm import (
    EsmAttention,
    EsmClassificationHead,
    EsmEmbeddings,
    EsmEncoder,
    EsmIntermediate,
    EsmLayer,
    EsmLMHead,
    EsmOutput,
    EsmPooler,
    EsmPreTrainedModel,
    EsmSelfAttention,
    EsmSelfOutput,
    RotaryEmbedding,
    apply_rotary_pos_emb,
)

try:
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention
except (ImportError, AttributeError):
    create_block_mask = None
    flex_attention = None

# Support both package-relative and flat (single-directory) imports.
try:
    from .base_tokenizer import BaseSequenceTokenizer
except ImportError:
    from base_tokenizer import BaseSequenceTokenizer

try:
    from .embedding_mixin import EmbeddingMixin
except ImportError:
    from embedding_mixin import EmbeddingMixin


def _create_pad_block_mask(attention_mask_2d: torch.Tensor):
    # Build a flex-attention BlockMask that only lets queries attend to keys
    # when both positions are non-padding; H=1 broadcasts across heads.
    assert create_block_mask is not None, "Flex attention block mask requires create_block_mask."
    token_valid = attention_mask_2d.bool()
    batch_size, seq_len = token_valid.shape

    def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
        return token_valid[batch_idx, q_idx] & token_valid[batch_idx, kv_idx]

    return create_block_mask(
        mask_mod,
        batch_size,
        1,
        seq_len,
        seq_len,
        device=attention_mask_2d.device,
    )
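
# Usage sketch (illustrative, not part of the original module): for a
# hypothetical padding mask, the resulting BlockMask confines attention to
# valid query/key pairs on both sides:
#
#     pad_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]])
#     block_mask = _create_pad_block_mask(pad_mask)
#     # passed through to flex_attention(..., block_mask=block_mask)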


def _infer_modality_type(input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Ids below 33 fall in the amino-acid vocabulary (aa_type == 1); the rest
    # are structure tokens (struct_type == 0); padding positions are marked 2.
    input_mask = attention_mask.bool()
    modality_type = ((input_ids < 33) & input_mask).int()
    modality_type[~input_mask] = 2
    return modality_type
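
# Example (sketch): with the hypothetical ids below, where ids < 33 are
# amino-acid tokens and ids >= 33 are structure tokens,
#
#     ids  = torch.tensor([[5, 12, 40, 1]])  # aa, aa, struct, pad
#     mask = torch.tensor([[1, 1, 1, 0]])
#     _infer_modality_type(ids, mask)         # -> tensor([[1, 1, 0, 2]], dtype=torch.int32)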


@dataclass
class DPLM2MaskedLMOutput(ModelOutput):
    loss: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None
    last_hidden_state: Optional[torch.Tensor] = None
    hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
    attentions: Optional[Tuple[torch.Tensor, ...]] = None


class DPLM2Config(EsmConfig):
    model_type = "dplm2"

    def __init__(
        self,
        attn_backend: str = "sdpa",
        aa_type: int = 1,
        struct_type: int = 0,
        pad_type: int = 2,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.attn_backend = attn_backend  # "sdpa" or "flex"
        self.aa_type = aa_type
        self.struct_type = struct_type
        self.pad_type = pad_type
        self.tie_word_embeddings = False
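
# Construction sketch (illustrative): DPLM2Config accepts all EsmConfig
# keyword arguments plus the DPLM2-specific fields above, e.g.
#
#     config = DPLM2Config(attn_backend="flex", vocab_size=..., num_hidden_layers=...)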


class DPLM2PreTrainedModel(EsmPreTrainedModel):
    config_class = DPLM2Config
    base_model_prefix = "dplm2"
    supports_gradient_checkpointing = True
    # Shared tokenizer, resolved once at class-definition time.
    tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
    all_tied_weights_keys = {}


class ModifiedRotaryEmbedding(RotaryEmbedding):
    def __init__(self, dim: int, aa_type: int, struct_type: int):
        super().__init__(dim)
        self.aa_type = aa_type
        self.struct_type = struct_type

    def _has_multimodal_tokens(self, type_ids: Optional[torch.Tensor]) -> bool:
        if type_ids is None:
            return False
        aa_present = (type_ids == self.aa_type).any()
        struct_present = (type_ids == self.struct_type).any()
        return bool(aa_present and struct_present)

    def _update_cos_sin_tables(
        self,
        x: torch.Tensor,
        type_ids: Optional[torch.Tensor],
        seq_dimension: int = 2,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        seq_len = x.shape[seq_dimension]
        # When both modalities are present the sequence is a concatenation of
        # two equal-length halves, so each half reuses the same position range.
        if self._has_multimodal_tokens(type_ids):
            seq_len = seq_len // 2

        cache_is_stale = (
            self._cos_cached is None
            or self._sin_cached is None
            or seq_len != self._seq_len_cached
            or self._cos_cached.device != x.device
        )
        if cache_is_stale:
            self._seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            freqs = torch.outer(t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self._cos_cached = emb.cos()[None, None, :, :]
            self._sin_cached = emb.sin()[None, None, :, :]

        return self._cos_cached, self._sin_cached

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        type_ids: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
            k,
            type_ids=type_ids,
            seq_dimension=-2,
        )

        if self._has_multimodal_tokens(type_ids):
            # Apply the same rotary positions to each half so aligned tokens
            # across the two modalities share a position embedding.
            q_1, q_2 = q.chunk(2, dim=-2)
            k_1, k_2 = k.chunk(2, dim=-2)
            q_1 = apply_rotary_pos_emb(q_1, self._cos_cached, self._sin_cached)
            q_2 = apply_rotary_pos_emb(q_2, self._cos_cached, self._sin_cached)
            k_1 = apply_rotary_pos_emb(k_1, self._cos_cached, self._sin_cached)
            k_2 = apply_rotary_pos_emb(k_2, self._cos_cached, self._sin_cached)
            return torch.cat((q_1, q_2), dim=-2), torch.cat((k_1, k_2), dim=-2)

        return (
            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
        )
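
# Position layout note: for a multimodal batch laid out as
# [struct_1 .. struct_N, aa_1 .. aa_N], both halves receive rotary positions
# 0 .. N-1, so struct_i and aa_i share a position instead of the positions
# running 0 .. 2N-1 across the concatenation.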


class ModifiedEsmSelfAttention(EsmSelfAttention):
    def __init__(self, config, position_embedding_type=None):
        super().__init__(config, position_embedding_type)
        self.attn_backend = config.attn_backend
        self.rotary_embeddings = ModifiedRotaryEmbedding(
            dim=self.attention_head_size,
            aa_type=config.aa_type,
            struct_type=config.struct_type,
        )

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
        type_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        flex_block_mask: Optional[object] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor]:
        if past_key_values is not None:
            past_key_value = past_key_values

        mixed_query_layer = self.query(hidden_states)
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # Reuse cached cross-attention keys/values.
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Queries are pre-scaled, so every attention path below runs with
        # scale=1.0 (or no further scaling in the manual path).
        query_layer = self.transpose_for_scores(mixed_query_layer) * self.attention_head_size**-0.5

        if self.is_decoder:
            past_key_value = (key_layer, value_layer)

        if self.position_embedding_type == "rotary":
            query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer, type_ids)

        if self.position_embedding_type in ["relative_key", "relative_key_query"]:
            raise NotImplementedError

        query_layer = query_layer.contiguous()
        key_layer = key_layer.contiguous()
        value_layer = value_layer.contiguous()

        if output_attentions:
            # Manual attention so the probabilities can be returned.
            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
            if attention_mask is not None:
                attention_scores = attention_scores + attention_mask
            attention_probs = F.softmax(attention_scores, dim=-1, dtype=torch.float32).to(query_layer.dtype)
            context_layer = torch.matmul(attention_probs, value_layer)
        else:
            attention_probs = None
            if self.attn_backend == "flex":
                assert flex_attention is not None, "Flex attention backend requested but torch.flex_attention is unavailable."
                assert query_layer.dtype in (torch.float16, torch.bfloat16), (
                    f"Flex attention backend requires float16 or bfloat16, got {query_layer.dtype}."
                )
                assert is_cross_attention is False, "Flex attention backend currently does not support cross-attention."
                assert past_key_value is None, "Flex attention backend currently does not support KV caching."
                if attention_mask is not None:
                    assert flex_block_mask is not None, (
                        "Flex attention backend requires a block mask when attention_mask is provided."
                    )
                context_layer = flex_attention(
                    query_layer,
                    key_layer,
                    value_layer,
                    block_mask=flex_block_mask,
                    scale=1.0,
                )
            else:
                context_layer = F.scaled_dot_product_attention(
                    query_layer,
                    key_layer,
                    value_layer,
                    attn_mask=attention_mask,
                    scale=1.0,
                )

        if head_mask is not None and torch.is_tensor(head_mask):
            context_layer = context_layer * head_mask

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs)
        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs
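
# Backend selection note: output_attentions=True always takes the manual
# matmul/softmax path (so probabilities can be returned); otherwise
# attn_backend == "flex" dispatches to torch's flex_attention with an optional
# pad block mask, and anything else uses F.scaled_dot_product_attention.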


class ModifiedEsmAttention(EsmAttention):
    def __init__(self, config):
        nn.Module.__init__(self)
        self.self = ModifiedEsmSelfAttention(config)
        self.output = EsmSelfOutput(config)
        self.pruned_heads = set()
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        type_ids=None,
        flex_block_mask=None,
    ):
        # Pre-LayerNorm attention; the residual connection is applied inside
        # self.output against the un-normalized hidden states.
        hidden_states_ln = self.LayerNorm(hidden_states)
        self_outputs = self.self(
            hidden_states_ln,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
            type_ids,
            flex_block_mask=flex_block_mask,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]
        return outputs


class ModifiedEsmLayer(EsmLayer):
    def __init__(self, config):
        nn.Module.__init__(self)
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = ModifiedEsmAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            if self.is_decoder is False:
                raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added")
            self.crossattention = ModifiedEsmAttention(config)
        self.intermediate = EsmIntermediate(config)
        self.output = EsmOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        type_ids=None,
        flex_block_mask=None,
    ):
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
            type_ids=type_ids,
            flex_block_mask=flex_block_mask,
        )
        attention_output = self_attention_outputs[0]

        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            outputs = self_attention_outputs[1:]

        if self.is_decoder and encoder_hidden_states is not None:
            if self.add_cross_attention is False:
                raise AttributeError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention "
                    "layers by setting `config.add_cross_attention=True`"
                )

            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
                type_ids=None,
                flex_block_mask=None,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]
            present_key_value = present_key_value + cross_attention_outputs[-1]

        layer_output = self.feed_forward_chunk(attention_output)
        outputs = (layer_output,) + outputs

        if self.is_decoder:
            outputs = outputs + (present_key_value,)
        return outputs


class ModifiedEsmEncoder(EsmEncoder):
    def __init__(self, config):
        nn.Module.__init__(self)
        self.config = config
        self.layer = nn.ModuleList([ModifiedEsmLayer(config) for _ in range(config.num_hidden_layers)])
        self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        type_ids=None,
        flex_block_mask=None,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
        next_decoder_cache = () if use_cache else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    type_ids,
                    flex_block_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    type_ids,
                    flex_block_mask,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache = next_decoder_cache + (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if self.emb_layer_norm_after is not None:
            hidden_states = self.emb_layer_norm_after(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if return_dict is False:
            return tuple(
                value
                for value in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if value is not None
            )

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )


class DPLM2Model(DPLM2PreTrainedModel, EmbeddingMixin):
    config_class = DPLM2Config

    def __init__(self, config, add_pooling_layer=True):
        DPLM2PreTrainedModel.__init__(self, config)
        self.config = config
        self.embeddings = EsmEmbeddings(config)
        self.encoder = ModifiedEsmEncoder(config)
        self.pooler = EsmPooler(config) if add_pooling_layer else None
        self.post_init()

    def _convert_head_mask_to_5d(self, head_mask: torch.Tensor, num_hidden_layers: int) -> torch.Tensor:
        if head_mask.dim() == 1:
            head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
        elif head_mask.dim() == 2:
            head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
        assert head_mask.dim() == 5, f"head_mask.dim != 5, got {head_mask.dim()}"
        head_mask = head_mask.to(dtype=self.dtype)
        return head_mask

    def get_head_mask(
        self,
        head_mask: Optional[torch.Tensor],
        num_hidden_layers: int,
        is_attention_chunked: bool = False,
    ) -> Union[torch.Tensor, List[None]]:
        if head_mask is None:
            return [None] * num_hidden_layers
        head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
        if is_attention_chunked:
            head_mask = head_mask.unsqueeze(-1)
        return head_mask

    def get_input_embeddings(self) -> nn.Module:
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        if attention_mask is None:
            attention_mask = input_ids.ne(self.config.pad_token_id)
        type_ids = _infer_modality_type(input_ids, attention_mask)
        outputs = self(
            input_ids=input_ids,
            attention_mask=attention_mask,
            type_ids=type_ids,
            output_hidden_states=False,
            output_attentions=False,
            return_dict=True,
        )
        return outputs.last_hidden_state

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        type_ids: Optional[torch.Tensor] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)

        # Keep a 2D boolean token mask for the flex block mask and embeddings;
        # the extended (additive) mask is only built for non-flex paths.
        token_attention_mask = None
        if attention_mask.dim() == 2:
            token_attention_mask = attention_mask.bool()
            if self.config.attn_backend == "flex" and output_attentions is False:
                extended_attention_mask = None
            else:
                extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
        elif attention_mask.dim() == 4:
            if self.config.attn_backend == "flex" and output_attentions is False:
                extended_attention_mask = None
            else:
                extended_attention_mask = attention_mask
            if input_ids is not None:
                token_attention_mask = input_ids.ne(self.config.pad_token_id)
        else:
            raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}")

        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = encoder_attention_mask

        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_attention_mask = token_attention_mask
        if embedding_attention_mask is None and input_ids is not None:
            embedding_attention_mask = input_ids.ne(self.config.pad_token_id)

        flex_block_mask = None
        if (
            self.config.attn_backend == "flex"
            and token_attention_mask is not None
            and output_attentions is False
        ):
            assert create_block_mask is not None, (
                "Flex attention backend requested but torch.create_block_mask is unavailable."
            )
            flex_block_mask = _create_pad_block_mask(token_attention_mask)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=embedding_attention_mask,
            inputs_embeds=inputs_embeds,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            type_ids=type_ids,
            flex_block_mask=flex_block_mask,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if return_dict is False:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=None,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
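
# Embedding sketch (illustrative): `_embed` infers modality type ids from the
# token ids, so per-token representations come from a single call on a
# hypothetical tokenized batch:
#
#     hidden = model._embed(input_ids)  # (batch, seq_len, hidden_size)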


class DPLM2ForMaskedLM(DPLM2PreTrainedModel, EmbeddingMixin):
    config_class = DPLM2Config

    def __init__(self, config, dropout: float = 0.1, vocab_size: Optional[int] = None):
        config.hidden_dropout_prob = dropout
        config.tie_word_embeddings = False
        if vocab_size is not None:
            config.vocab_size = vocab_size
        DPLM2PreTrainedModel.__init__(self, config)
        self.esm = DPLM2Model(config, add_pooling_layer=False)
        self.lm_head = EsmLMHead(config)
        self.loss_fct = nn.CrossEntropyLoss()
        self.post_init()
        self.pad_id = config.pad_token_id

    def get_input_embeddings(self) -> nn.Module:
        return self.esm.embeddings.word_embeddings

    def get_output_embeddings(self):
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        self.lm_head.decoder = new_embeddings

    def _get_modality_type(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        return _infer_modality_type(input_ids, attention_mask)

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        if attention_mask is None:
            attention_mask = input_ids.ne(self.pad_id)
        type_ids = self._get_modality_type(input_ids, attention_mask)
        outputs = self.esm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            type_ids=type_ids,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )
        return outputs.last_hidden_state

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        type_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.Tensor] = None,
        decoder_attention_mask: Optional[torch.Tensor] = None,
        decoder_inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> Union[Tuple[torch.Tensor], DPLM2MaskedLMOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if attention_mask is None:
            assert input_ids is not None
            attention_mask = input_ids.ne(self.pad_id)

        if type_ids is None:
            assert input_ids is not None
            type_ids = self._get_modality_type(input_ids, attention_mask)

        outputs = self.esm(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            type_ids=type_ids,
        )

        sequence_output = outputs.last_hidden_state
        logits = self.lm_head(sequence_output)
        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # CrossEntropyLoss ignores positions labeled -100 by default.
            loss = self.loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        if return_dict is False:
            output = (logits, sequence_output, outputs.hidden_states, outputs.attentions)
            if loss is not None:
                return (loss,) + output
            return output

        return DPLM2MaskedLMOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=sequence_output,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
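
# Masked-LM sketch (illustrative): masked positions carry the mask token id in
# `masked_ids`, and `labels` holds the true ids there with -100 elsewhere:
#
#     out = model(input_ids=masked_ids, labels=labels)
#     out.loss.backward()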


class DPLM2ForSequenceClassification(DPLM2PreTrainedModel, EmbeddingMixin):
    config_class = DPLM2Config

    def __init__(self, config):
        DPLM2PreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels
        self.esm = DPLM2Model(config, add_pooling_layer=False)
        self.classifier = EsmClassificationHead(config)
        self.mse = nn.MSELoss()
        self.ce = nn.CrossEntropyLoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.esm.embeddings.word_embeddings

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        return self.esm._embed(input_ids, attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        type_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        if type_ids is None and input_ids is not None:
            if attention_mask is None:
                attention_mask = input_ids.ne(self.config.pad_token_id)
            type_ids = _infer_modality_type(input_ids, attention_mask)

        outputs = self.esm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            type_ids=type_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        sequence_output = outputs.last_hidden_state
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Infer the problem type once from num_labels and the label dtype.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                if self.num_labels == 1:
                    loss = self.mse(logits.squeeze(), labels.squeeze())
                else:
                    loss = self.mse(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss = self.bce(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
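
# Loss selection note: num_labels == 1 implies MSE regression; integer labels
# with num_labels > 1 select single-label cross-entropy; anything else falls
# through to multi-label BCEWithLogitsLoss.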


class DPLM2ForTokenClassification(DPLM2PreTrainedModel, EmbeddingMixin):
    config_class = DPLM2Config

    def __init__(self, config):
        DPLM2PreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels
        self.esm = DPLM2Model(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.loss_fct = nn.CrossEntropyLoss()
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.esm.embeddings.word_embeddings

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        return self.esm._embed(input_ids, attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        type_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        if type_ids is None and input_ids is not None:
            if attention_mask is None:
                attention_mask = input_ids.ne(self.config.pad_token_id)
            type_ids = _infer_modality_type(input_ids, attention_mask)

        outputs = self.esm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            type_ids=type_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        sequence_output = self.dropout(outputs.last_hidden_state)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
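

if __name__ == "__main__":
    # Smoke-test sketch, not part of the original module; the tiny
    # hyperparameters below are illustrative only, and running this file
    # requires its repo-local imports (entrypoint_setup, embedding_mixin,
    # base_tokenizer) to resolve.
    config = DPLM2Config(
        vocab_size=64,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=64,
        position_embedding_type="rotary",
        pad_token_id=1,
    )
    model = DPLM2ForMaskedLM(config)
    ids = torch.randint(4, 33, (2, 16))  # ids in the amino-acid range
    out = model(input_ids=ids)
    print(out.logits.shape)  # expected: torch.Size([2, 16, 64])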