GinnM committed on
Commit
b643244
·
verified ·
1 Parent(s): bca29b0

Upload ThermoFormer

Browse files
config.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/rgzn/limc/ThermoFormer/finetune_checkpoint/TM",
3
+ "architectures": [
4
+ "ThermoFormer"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.0,
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_thermoformer.ThermoFormerConfig",
9
+ "AutoModel": "modeling_thermoformer.ThermoFormer"
10
+ },
11
+ "emb_layer_norm_before": false,
12
+ "flash_attention": true,
13
+ "hidden_dropout_prob": 0.0,
14
+ "hidden_size": 1280,
15
+ "initializer_range": 0.02,
16
+ "intermediate_size": 5120,
17
+ "layer_norm_eps": 1e-05,
18
+ "mask_token_id": 32,
19
+ "max_position_embeddings": 1026,
20
+ "mlm": true,
21
+ "model_type": "theromoformer",
22
+ "num_attention_heads": 20,
23
+ "num_hidden_layers": 33,
24
+ "pad_token_id": 1,
25
+ "position_embedding_type": "rotary",
26
+ "structure_vocab_size": 100,
27
+ "token_dropout": true,
28
+ "torch_dtype": "float32",
29
+ "transformers_version": "4.46.2",
30
+ "use_cache": true,
31
+ "vocab_size": 33
32
+ }
configuration_thermoformer.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.configuration_utils import PretrainedConfig
2
+ from transformers.utils import logging
3
+
4
+
5
+ logger = logging.get_logger(__name__)
6
+
7
+
8
class ThermoFormerConfig(PretrainedConfig):
    """Hyper-parameter container for ThermoFormer models.

    Every constructor argument is stored verbatim as an attribute of the
    same name. ``pad_token_id`` and ``mask_token_id`` are forwarded, along
    with any extra keyword arguments, to ``PretrainedConfig.__init__``.

    Args:
        vocab_size: Number of rows of the token embedding table.
        mask_token_id: Id of the mask token.
        pad_token_id: Id of the padding token.
        hidden_size: Width of the hidden representation.
        num_hidden_layers: Number of encoder layers.
        num_attention_heads: Number of attention heads per layer.
        intermediate_size: Width of the feed-forward projection.
        hidden_dropout_prob: Dropout probability of the dense sub-layers.
        attention_probs_dropout_prob: Dropout probability used in attention.
        max_position_embeddings: Maximum number of positions.
        initializer_range: Standard deviation used for weight initialization.
        layer_norm_eps: Epsilon passed to the layer-norm modules.
        position_embedding_type: Position scheme name, e.g. ``"rotary"``.
        use_cache: Whether key/value caching is requested.
        emb_layer_norm_before: Whether embeddings are layer-normalized.
        token_dropout: Whether mask-token embeddings are zeroed and rescaled.
        flash_attention: Whether the fused attention kernel is preferred.
        structure_vocab_size: Stored as-is on the config.
        mlm: Stored as-is on the config.
        **kwargs: Passed through to ``PretrainedConfig``.
    """

    # NOTE(review): the spelling "theromoformer" is also the value serialized
    # in config.json; keep both in sync if it is ever corrected.
    model_type = "theromoformer"

    def __init__(
        self,
        vocab_size=33,
        mask_token_id=32,
        pad_token_id=1,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=1026,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        position_embedding_type="rotary",
        use_cache=True,
        emb_layer_norm_before=None,
        token_dropout=False,
        flash_attention=True,
        structure_vocab_size=100,
        mlm=True,
        **kwargs,
    ):
        # Special-token ids are owned by the base class.
        super().__init__(
            pad_token_id=pad_token_id,
            mask_token_id=mask_token_id,
            **kwargs,
        )

        # Transformer geometry.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size

        # Regularization and initialization.
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps

        # Behavioural switches.
        self.position_embedding_type = position_embedding_type
        self.use_cache = use_cache
        self.emb_layer_norm_before = emb_layer_norm_before
        self.token_dropout = token_dropout
        self.flash_attention = flash_attention

        # Extra model-specific settings.
        self.structure_vocab_size = structure_vocab_size
        self.mlm = mlm
55
+
56
+ ThermoFormerConfig.register_for_auto_class()
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:47aee4428d0878bcdc479135be9efe6e052caee188aa936a2ebe34c8a75e3895
3
+ size 2669826476
modeling_thermoformer.py ADDED
@@ -0,0 +1,1199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List, Optional, Tuple, Union
3
+ import torch.nn.functional as F
4
+ import torch
5
+ import torch.utils.checkpoint
6
+ from torch import nn
7
+ from torch.nn import CrossEntropyLoss
8
+ from dataclasses import dataclass
9
+ from transformers.modeling_outputs import (
10
+ BaseModelOutputWithPastAndCrossAttentions,
11
+ BaseModelOutputWithPoolingAndCrossAttentions,
12
+ MaskedLMOutput,
13
+ ModelOutput,
14
+ )
15
+ from transformers.modeling_utils import (
16
+ PreTrainedModel,
17
+ find_pruneable_heads_and_indices,
18
+ prune_linear_layer,
19
+ )
20
+ from transformers.utils import logging
21
+ from .configuration_thermoformer import ThermoFormerConfig
22
+ from torch.nn.functional import scaled_dot_product_attention
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    For a last dimension ``[a, b]`` (split at ``size // 2``) the result is
    ``[-b, a]``.
    """
    half = x.shape[-1] // 2
    front = x[..., :half]
    back = x[..., half:]
    return torch.cat((-back, front), dim=-1)
29
+
30
+
31
def apply_rotary_pos_emb(x, cos, sin):
    """Rotate ``x`` using the given cosine/sine tables.

    The tables are cropped along their third dimension to ``x.shape[-2]``
    before being combined with ``x`` and ``rotate_half(x)``.
    """
    length = x.shape[-2]
    cos_table = cos[:, :, :length, :]
    sin_table = sin[:, :, :length, :]
    return (x * cos_table) + (rotate_half(x) * sin_table)
35
+
36
+
37
def gelu(x):
    """Exact (erf-based) GELU activation: ``x * 0.5 * (1 + erf(x / sqrt(2)))``."""
    erf_term = 1.0 + torch.erf(x / math.sqrt(2.0))
    return x * 0.5 * erf_term
39
+
40
+
41
class RotaryEmbedding(torch.nn.Module):
    """Rotary position embedding for query/key tensors.

    Cosine and sine tables are built lazily from the key tensor's sequence
    length and cached until the length or the device changes.
    """

    def __init__(self, dim: int):
        super().__init__()
        # Inverse frequencies (one per channel pair), kept as a
        # non-trainable buffer.
        exponents = torch.arange(0, dim, 2, dtype=torch.int64).float() / dim
        self.register_buffer("inv_freq", 1.0 / (10000**exponents))

        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x, seq_dimension=2):
        """Return ``(cos, sin)`` tables for ``x``, rebuilding them if stale."""
        seq_len = x.shape[seq_dimension]

        # Rebuild when the sequence length changed or the tensor lives on
        # another device (possibly due to tracing, for instance).
        stale = (
            seq_len != self._seq_len_cached or self._cos_cached.device != x.device
        )
        if stale:
            self._seq_len_cached = seq_len
            positions = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            angles = torch.outer(positions, self.inv_freq)
            table = torch.cat((angles, angles), dim=-1).to(x.device)

            self._cos_cached = table.cos()[None, None, :, :]
            self._sin_cached = table.sin()[None, None, :, :]

        return self._cos_cached, self._sin_cached

    def forward(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the rotary embedding to ``q`` and ``k`` and return both."""
        cos, sin = self._update_cos_sin_tables(k, seq_dimension=-2)
        self._cos_cached, self._sin_cached = cos, sin

        rotated_q = apply_rotary_pos_emb(q, cos, sin)
        rotated_k = apply_rotary_pos_emb(k, cos, sin)
        return (rotated_q, rotated_k)
84
+
85
+
86
class ThermoFormerEmbeddings(nn.Module):
    """Token embedding layer with optional absolute positions and token dropout.

    Compared with the previous revision, ``forward`` no longer crashes in the
    token-dropout branch when only ``inputs_embeds`` is supplied (no
    ``input_ids`` to locate mask tokens) or when ``attention_mask`` is omitted.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
        )

        if config.emb_layer_norm_before:
            self.layer_norm = nn.LayerNorm(
                config.hidden_size, eps=config.layer_norm_eps
            )
        else:
            self.layer_norm = None
        # Kept as an attribute although forward() does not apply it.
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.position_embedding_type = getattr(
            config, "position_embedding_type", "absolute"
        )
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).expand((1, -1)),
            persistent=False,
        )

        self.padding_idx = config.pad_token_id
        # A learned position table only exists for the "absolute" scheme.
        if self.position_embedding_type == "absolute":
            self.position_embeddings = nn.Embedding(
                config.max_position_embeddings,
                config.hidden_size,
                padding_idx=self.padding_idx,
            )
        self.token_dropout = config.token_dropout
        self.mask_token_id = config.mask_token_id

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        inputs_embeds=None,
        past_key_values_length=0,
    ):
        """Embed tokens (or pass through ``inputs_embeds``).

        Args:
            input_ids: Token ids; required unless ``inputs_embeds`` is given.
            attention_mask: Optional mask multiplied into the embeddings;
                also used to compute per-sequence lengths for token dropout.
            position_ids: Optional explicit position ids.
            inputs_embeds: Optional pre-computed embeddings.
            past_key_values_length: Offset used when deriving position ids
                from ``input_ids``.

        Returns:
            The embedding tensor.
        """
        if position_ids is None:
            if input_ids is not None:
                position_ids = create_position_ids_from_input_ids(
                    input_ids, self.padding_idx, past_key_values_length
                )
            else:
                position_ids = self.create_position_ids_from_inputs_embeds(
                    inputs_embeds
                )

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds

        # Token dropout needs the ids to locate mask tokens; without them
        # (inputs_embeds-only call) it is skipped instead of failing.
        if self.token_dropout and input_ids is not None:
            is_mask_token = input_ids == self.mask_token_id
            embeddings = embeddings.masked_fill(is_mask_token.unsqueeze(-1), 0.0)
            # NOTE(review): presumably the mask ratio assumed at pre-training
            # time (0.15 selected * 0.8 replaced by the mask token) — confirm.
            mask_ratio_train = 0.15 * 0.8
            if attention_mask is not None:
                src_lengths = attention_mask.sum(-1)
            else:
                # No mask supplied: every position counts as a real token.
                src_lengths = input_ids.shape[1]
            mask_ratio_observed = is_mask_token.sum(-1).float() / src_lengths
            # NOTE: a sequence made only of mask tokens gives a zero
            # denominator here, exactly as before.
            embeddings = (
                embeddings
                * (1 - mask_ratio_train)
                / (1 - mask_ratio_observed)[:, None, None]
            ).to(embeddings.dtype)

        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = embeddings + position_embeddings

        if self.layer_norm is not None:
            embeddings = self.layer_norm(embeddings)
        if attention_mask is not None:
            embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(
                embeddings.dtype
            )
        # Dropout on the embeddings is deliberately not applied here.
        return embeddings

    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
        """Build sequential position ids starting at ``padding_idx + 1``.

        Used when only embeddings are available, so padding positions cannot
        be detected.
        """
        input_shape = inputs_embeds.size()[:-1]
        sequence_length = input_shape[1]

        position_ids = torch.arange(
            self.padding_idx + 1,
            sequence_length + self.padding_idx + 1,
            dtype=torch.long,
            device=inputs_embeds.device,
        )
        return position_ids.unsqueeze(0).expand(input_shape)
183
+
184
+
185
+ class ThermoFormerSelfAttention(nn.Module):
186
+ def __init__(self, config, position_embedding_type=None):
187
+ super().__init__()
188
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
189
+ config, "embedding_size"
190
+ ):
191
+ raise ValueError(
192
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
193
+ f"heads ({config.num_attention_heads})"
194
+ )
195
+
196
+ self.num_attention_heads = config.num_attention_heads
197
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
198
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
199
+
200
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
201
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
202
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
203
+
204
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
205
+ self.position_embedding_type = position_embedding_type or getattr(
206
+ config, "position_embedding_type", "absolute"
207
+ )
208
+ self.rotary_embeddings = None
209
+ if (
210
+ self.position_embedding_type == "relative_key"
211
+ or self.position_embedding_type == "relative_key_query"
212
+ ):
213
+ self.max_position_embeddings = config.max_position_embeddings
214
+ self.distance_embedding = nn.Embedding(
215
+ 2 * config.max_position_embeddings - 1, self.attention_head_size
216
+ )
217
+ elif self.position_embedding_type == "rotary":
218
+ self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)
219
+ self.flash_attention = config.flash_attention
220
+ self.is_decoder = config.is_decoder
221
+ self.config = config
222
+
223
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
224
+ new_x_shape = x.size()[:-1] + (
225
+ self.num_attention_heads,
226
+ self.attention_head_size,
227
+ )
228
+ x = x.view(new_x_shape)
229
+ return x.permute(0, 2, 1, 3)
230
+
231
+ def forward(
232
+ self,
233
+ hidden_states: torch.Tensor,
234
+ attention_mask: Optional[torch.FloatTensor] = None,
235
+ head_mask: Optional[torch.FloatTensor] = None,
236
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
237
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
238
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
239
+ output_attentions: Optional[bool] = False,
240
+ ) -> Tuple[torch.Tensor]:
241
+ mixed_query_layer = self.query(hidden_states)
242
+
243
+ # If this is instantiated as a cross-attention module, the keys
244
+ # and values come from an encoder; the attention mask needs to be
245
+ # such that the encoder's padding tokens are not attended to.
246
+ is_cross_attention = encoder_hidden_states is not None
247
+
248
+ if is_cross_attention and past_key_value is not None:
249
+ # reuse k,v, cross_attentions
250
+ key_layer = past_key_value[0]
251
+ value_layer = past_key_value[1]
252
+ attention_mask = encoder_attention_mask
253
+ elif is_cross_attention:
254
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
255
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
256
+ attention_mask = encoder_attention_mask
257
+ elif past_key_value is not None:
258
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
259
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
260
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
261
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
262
+ else:
263
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
264
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
265
+
266
+ query_layer = self.transpose_for_scores(mixed_query_layer)
267
+
268
+ query_layer = query_layer * self.attention_head_size**-0.5
269
+
270
+ if self.is_decoder:
271
+ past_key_value = (key_layer, value_layer)
272
+
273
+ if self.position_embedding_type == "rotary":
274
+ query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
275
+
276
+ if not self.flash_attention:
277
+ # Take the dot product between "query" and "key" to get the raw attention scores.
278
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
279
+
280
+ if (
281
+ self.position_embedding_type == "relative_key"
282
+ or self.position_embedding_type == "relative_key_query"
283
+ ):
284
+ seq_length = hidden_states.size()[1]
285
+ position_ids_l = torch.arange(
286
+ seq_length, dtype=torch.long, device=hidden_states.device
287
+ ).view(-1, 1)
288
+ position_ids_r = torch.arange(
289
+ seq_length, dtype=torch.long, device=hidden_states.device
290
+ ).view(1, -1)
291
+ distance = position_ids_l - position_ids_r
292
+ positional_embedding = self.distance_embedding(
293
+ distance + self.max_position_embeddings - 1
294
+ )
295
+ positional_embedding = positional_embedding.to(
296
+ dtype=query_layer.dtype
297
+ ) # fp16 compatibility
298
+
299
+ if self.position_embedding_type == "relative_key":
300
+ relative_position_scores = torch.einsum(
301
+ "bhld,lrd->bhlr", query_layer, positional_embedding
302
+ )
303
+ attention_scores = attention_scores + relative_position_scores
304
+ elif self.position_embedding_type == "relative_key_query":
305
+ relative_position_scores_query = torch.einsum(
306
+ "bhld,lrd->bhlr", query_layer, positional_embedding
307
+ )
308
+ relative_position_scores_key = torch.einsum(
309
+ "bhrd,lrd->bhlr", key_layer, positional_embedding
310
+ )
311
+ attention_scores = (
312
+ attention_scores
313
+ + relative_position_scores_query
314
+ + relative_position_scores_key
315
+ )
316
+
317
+ if attention_mask is not None:
318
+ attention_scores = attention_scores + attention_mask
319
+
320
+ # Normalize the attention scores to probabilities.
321
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
322
+
323
+ # This is actually dropping out entire tokens to attend to, which might
324
+ # seem a bit unusual, but is taken from the original Transformer paper.
325
+ attention_probs = self.dropout(attention_probs)
326
+
327
+ # Mask heads if we want to
328
+ if head_mask is not None:
329
+ attention_probs = attention_probs * head_mask
330
+
331
+ context_layer = torch.matmul(attention_probs, value_layer)
332
+ else:
333
+ if self.training:
334
+ context_layer = scaled_dot_product_attention(
335
+ query_layer,
336
+ key_layer,
337
+ value_layer,
338
+ attn_mask=attention_mask,
339
+ dropout_p=self.config.attention_probs_dropout_prob,
340
+ scale=1, # we have query_layer = query_layer * self.attention_head_size**-0.5
341
+ )
342
+ else:
343
+ context_layer = scaled_dot_product_attention(
344
+ query_layer,
345
+ key_layer,
346
+ value_layer,
347
+ attn_mask=attention_mask,
348
+ scale=1, # we have query_layer = query_layer * self.attention_head_size**-0.5
349
+ )
350
+
351
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
352
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
353
+ context_layer = context_layer.view(new_context_layer_shape)
354
+
355
+ outputs = (
356
+ (context_layer, attention_probs) if output_attentions else (context_layer,)
357
+ )
358
+
359
+ if self.is_decoder:
360
+ outputs = outputs + (past_key_value,)
361
+ return outputs
362
+
363
+
364
class ThermoFormerSelfOutput(nn.Module):
    """Output projection of the attention sub-layer: dense, dropout, residual add."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Project ``hidden_states`` and add ``input_tensor`` as the residual."""
        projected = self.dropout(self.dense(hidden_states))
        return projected + input_tensor
375
+
376
+
377
class ThermoFormerAttention(nn.Module):
    """Pre-layer-norm attention block: LayerNorm, self-attention, then the
    output projection with a residual connection to the un-normalized input."""

    def __init__(self, config):
        super().__init__()
        self.self = ThermoFormerSelfAttention(config)
        self.output = ThermoFormerSelfOutput(config)
        self.pruned_heads = set()
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def prune_heads(self, heads):
        """Remove the given attention heads from the projections."""
        if len(heads) == 0:
            return

        attn = self.self
        heads, index = find_pruneable_heads_and_indices(
            heads,
            attn.num_attention_heads,
            attn.attention_head_size,
            self.pruned_heads,
        )

        # Shrink the q/k/v projections and the input side of the output dense.
        attn.query = prune_linear_layer(attn.query, index)
        attn.key = prune_linear_layer(attn.key, index)
        attn.value = prune_linear_layer(attn.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Keep the bookkeeping in line with the new head count.
        attn.num_attention_heads = attn.num_attention_heads - len(heads)
        attn.all_head_size = attn.attention_head_size * attn.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Run attention on the normalized input.

        Returns a tuple whose first item is the block output; any further
        items produced by the inner attention module are appended unchanged.
        """
        normalized = self.LayerNorm(hidden_states)
        inner_outputs = self.self(
            normalized,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        # The residual uses the input from before the layer norm.
        attention_output = self.output(inner_outputs[0], hidden_states)
        return (attention_output,) + inner_outputs[1:]
433
+
434
+
435
class ThermoFormerIntermediate(nn.Module):
    """First half of the feed-forward sub-layer: expansion followed by GELU."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project to ``intermediate_size`` and apply ``gelu``."""
        expanded = self.dense(hidden_states)
        return gelu(expanded)
444
+
445
+
446
class ThermoFormerOutput(nn.Module):
    """Second half of the feed-forward sub-layer: contraction, dropout, residual add."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Project back to ``hidden_size`` and add ``input_tensor``."""
        contracted = self.dropout(self.dense(hidden_states))
        return contracted + input_tensor
457
+
458
+
459
class ThermoFormerLayer(nn.Module):
    """One transformer layer: self-attention, optional cross-attention
    (decoder only) and a pre-layer-norm feed-forward block."""

    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = ThermoFormerAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            # Cross-attention only makes sense for a decoder.
            if not self.is_decoder:
                raise RuntimeError(
                    f"{self} should be used as a decoder model if cross attention is added"
                )
            self.crossattention = ThermoFormerAttention(config)
        self.intermediate = ThermoFormerIntermediate(config)
        self.output = ThermoFormerOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Apply the layer.

        Returns a tuple ``(layer_output, *attention_extras)``; when the layer
        is a decoder the key/value cache is appended as the final element.
        """
        # The self-attention cache occupies the first two slots of past_key_value.
        cached_self_kv = None if past_key_value is None else past_key_value[:2]
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=cached_self_kv,
        )
        attention_output = self_attention_outputs[0]

        if self.is_decoder:
            # The last item is the self-attention cache, not an attention map.
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            # Everything after the first item is attention weights (if any).
            outputs = self_attention_outputs[1:]

        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise AttributeError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated"
                    " with cross-attention layers by setting `config.add_cross_attention=True`"
                )

            # The cross-attention cache occupies the last two slots of past_key_value.
            cached_cross_kv = None if past_key_value is None else past_key_value[-2:]
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cached_cross_kv,
                output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            # Append cross-attention weights when they were produced.
            outputs = outputs + cross_attention_outputs[1:-1]

            # Extend the returned cache with the cross-attention keys/values.
            present_key_value = present_key_value + cross_attention_outputs[-1]

        layer_output = self.feed_forward_chunk(attention_output)
        outputs = (layer_output,) + outputs

        if self.is_decoder:
            outputs = outputs + (present_key_value,)
        return outputs

    def feed_forward_chunk(self, attention_output):
        """LayerNorm, expansion, then contraction with a residual to the
        un-normalized ``attention_output``."""
        normalized = self.LayerNorm(attention_output)
        expanded = self.intermediate(normalized)
        return self.output(expanded, attention_output)
553
+
554
+
555
class ThermoFormerEncoder(nn.Module):
    """Stack of ``ThermoFormerLayer`` modules followed by a final LayerNorm.

    Gradient checkpointing is enabled by default while training. The
    ``_gradient_checkpointing_func`` hook is only present after the hosting
    model's ``gradient_checkpointing_enable()`` has run; previously a plain
    ``train()`` forward therefore raised ``AttributeError``. The encoder now
    falls back to ``torch.utils.checkpoint.checkpoint`` when the hook is
    missing.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [ThermoFormerLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.emb_layer_norm_after = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )
        self.gradient_checkpointing = True

    def _run_checkpointed(self, layer_module, *layer_args):
        """Run one layer under activation checkpointing."""
        checkpoint_fn = getattr(self, "_gradient_checkpointing_func", None)
        if checkpoint_fn is not None:
            return checkpoint_fn(layer_module.__call__, *layer_args)
        # Hook not installed: use PyTorch's checkpointing directly. The
        # non-reentrant variant also works when the inputs themselves do not
        # require gradients.
        return torch.utils.checkpoint.checkpoint(
            layer_module.__call__, *layer_args, use_reentrant=False
        )

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        """Run every layer in order and normalize the result.

        Args:
            hidden_states: Input to the first layer.
            attention_mask: Passed unchanged to every layer.
            head_mask: Optional per-layer head masks, indexed by layer.
            encoder_hidden_states: Passed to every layer (cross-attention).
            encoder_attention_mask: Passed to every layer (cross-attention).
            past_key_values: Optional per-layer caches, indexed by layer.
            use_cache: Whether to collect each layer's last output as cache.
            output_attentions: Whether to collect attention weights.
            output_hidden_states: Whether to collect per-layer hidden states.
            return_dict: Return a ``BaseModelOutputWithPastAndCrossAttentions``
                instead of a tuple of the non-``None`` results.
        """
        checkpointing = self.gradient_checkpointing and self.training
        if checkpointing:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
                    "`use_cache=False`..."
                )
                use_cache = False
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = (
            () if output_attentions and self.config.add_cross_attention else None
        )

        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            layer_args = (
                hidden_states,
                attention_mask,
                layer_head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                past_key_value,
                output_attentions,
            )
            if checkpointing:
                layer_outputs = self._run_checkpointed(layer_module, *layer_args)
            else:
                layer_outputs = layer_module(*layer_args)

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache = next_decoder_cache + (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if self.emb_layer_norm_after is not None:
            hidden_states = self.emb_layer_norm_after(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
656
+
657
+
658
class ThermoFormerPreTrainedModel(PreTrainedModel):
    """
    Shared base class for ThermoFormer models: declares the config class and
    loading-related class attributes, and provides weight initialization.
    """

    config_class = ThermoFormerConfig
    base_model_prefix = "ThermoFormer"
    supports_gradient_checkpointing = True
    _no_split_modules = [
        "ThermoFormerLayer",
        "ThermoFormerEmbeddings",
    ]

    def _init_weights(self, module):
        """Initialize the weights of a single sub-module.

        Linear and embedding weights are drawn from a normal distribution
        with standard deviation ``config.initializer_range``; linear biases
        and the embedding row at ``padding_idx`` are zeroed. LayerNorm is
        reset to weight 1 and bias 0. Other module types are left untouched.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                # The padding row must stay a zero vector.
                module.weight.data[module.padding_idx].zero_()
            return

        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
688
+
689
+
690
class ThermoFormerModel(ThermoFormerPreTrainedModel):
    """Backbone of ThermoFormer: an embedding layer followed by the encoder stack.

    No pooling layer is built here, so the ``pooler_output`` field of the
    returned output object is never populated.
    """

    base_model_prefix = "ThermoFormer"

    def __init__(self, config, add_pooling_layer=True):
        """Build the embeddings and the encoder from ``config``.

        Args:
            config: model configuration, handed to both sub-modules.
            add_pooling_layer: accepted but not used by this constructor.
        """
        super().__init__(config)
        self.config = config
        self.embeddings = ThermoFormerEmbeddings(config)
        self.encoder = ThermoFormerEncoder(config)
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the token embedding module with ``value``."""
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        """Embed the inputs and run them through the encoder.

        Exactly one of ``input_ids`` (batch, seq) or ``inputs_embeds``
        (batch, seq, hidden) must be supplied. When ``attention_mask`` is
        omitted, an all-ones mask covering the current and cached positions
        is used. Flags left as ``None`` fall back to the values in the config.

        Returns:
            A ``BaseModelOutputWithPoolingAndCrossAttentions`` when
            ``return_dict`` is true, otherwise the result of calling
            ``to_tuple()`` on that object.

        Raises:
            ValueError: if both, or neither, of ``input_ids`` and
                ``inputs_embeds`` are given.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Caching is only honoured when the model is configured as a decoder.
        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # Number of positions already held in the key/value cache.
        past_key_values_length = (
            past_key_values[0][0].shape[2] if past_key_values is not None else 0
        )

        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length + past_key_values_length), device=device
            )

        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
            attention_mask, input_shape
        )

        # Cross-attention mask is only prepared for the decoder configuration.
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = (
                encoder_hidden_states.size()
            )
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(
                encoder_attention_mask
            )
        else:
            encoder_extended_attention_mask = None

        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
        # The encoder is always asked for a structured output because its
        # named fields are read below; the caller's `return_dict` choice is
        # applied to the final result instead of being forwarded.
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        sequence_output = encoder_outputs[0]

        model_output = BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
        if not return_dict:
            return model_output.to_tuple()
        return model_output
820
+
821
+
822
class ThermoFormerForMaskedLM(ThermoFormerPreTrainedModel):
    """ThermoFormer backbone with a masked-language-modeling head on top."""

    _tied_weights_keys = ["lm_head.decoder.weight"]

    def __init__(self, config):
        """Build the backbone and the LM head; warn if configured as a decoder."""
        super().__init__(config)

        if config.is_decoder:
            logger.warning(
                "If you want to use `ThermoFormerForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        self.model = ThermoFormerModel(config, add_pooling_layer=False)
        self.lm_head = ThermoFormerLMHead(config)
        self.init_weights()

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        return self.model.embeddings.word_embeddings

    def get_output_embeddings(self):
        """Return the LM head's output projection."""
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head's output projection."""
        self.lm_head.decoder = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        """Score every position over the vocabulary and optionally compute the loss.

        Args:
            labels: target token ids; when given, a cross-entropy loss is
                computed between the flattened logits and the flattened labels.
            return_dict: when false, a plain tuple is returned instead of a
                ``MaskedLMOutput``; defaults to ``config.use_return_dict``.
            The remaining arguments are forwarded unchanged to the backbone.

        Returns:
            ``MaskedLMOutput`` with ``loss``, ``logits``, ``hidden_states`` and
            ``attentions``; or, in tuple mode, ``(loss?, logits,
            hidden_states?, attentions?)`` where absent entries are omitted.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # The backbone is always asked for a structured output so the fields
        # can be picked by name; it has no pooled output, which makes
        # positional slicing of its tuple form error-prone.
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()

            labels = labels.to(prediction_scores.device)
            masked_lm_loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
            )

        if not return_dict:
            extras = tuple(
                item
                for item in (outputs.hidden_states, outputs.attentions)
                if item is not None
            )
            output = (prediction_scores,) + extras
            return (
                ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
            )

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
901
+
902
+
903
class ThermoFormerLMHead(nn.Module):
    """Masked-LM head: dense -> GELU -> layer norm -> vocabulary projection."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.layer_norm = nn.LayerNorm(width, eps=config.layer_norm_eps)

        # The output projection has no bias of its own; a separate bias
        # parameter is added in `forward`.
        self.decoder = nn.Linear(width, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

    def forward(self, features, **kwargs):
        """Map hidden states to vocabulary scores (extra kwargs are ignored)."""
        hidden = self.layer_norm(gelu(self.dense(features)))
        return self.decoder(hidden) + self.bias
921
+
922
+
923
class ThermoFormerStructureHead(nn.Module):
    """Head projecting hidden states onto the structure vocabulary.

    Pipeline: bias-free dense -> GELU -> layer norm -> bias-free projection
    to ``config.structure_vocab_size`` plus a separate bias parameter.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        n_structure_tokens = config.structure_vocab_size
        self.dense = nn.Linear(width, width, bias=False)
        self.layer_norm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.decoder = nn.Linear(width, n_structure_tokens, bias=False)
        self.bias = nn.Parameter(torch.zeros(n_structure_tokens))

    def forward(self, features, **kwargs):
        """Return structure-vocabulary scores (extra kwargs are ignored)."""
        hidden = self.layer_norm(gelu(self.dense(features)))
        return self.decoder(hidden) + self.bias
942
+
943
+
944
def create_position_ids_from_input_ids(
    input_ids, padding_idx, past_key_values_length=0
):
    """
    Build position ids from token ids, skipping padding.

    Non-padding tokens are numbered consecutively along dim 1, starting at
    ``padding_idx + 1`` (shifted by ``past_key_values_length``); padding
    tokens receive ``padding_idx``. Adapted from fairseq's
    `utils.make_positions`.

    Args:
        input_ids: integer tensor of token ids, cumulated along dim 1.
        padding_idx: id of the padding token.
        past_key_values_length: offset added to every non-padding position.

    Returns: torch.Tensor of dtype long with the same shape as `input_ids`.
    """
    # The int -> cumsum -> type_as -> long cast chain is kept as-is: it is
    # balanced to work with both ONNX export and XLA.
    not_padding = input_ids.ne(padding_idx).int()
    running_count = torch.cumsum(not_padding, dim=1).type_as(not_padding)
    positions = (running_count + past_key_values_length) * not_padding
    return positions.long() + padding_idx
962
+
963
+
964
+ # POOLING_HEAD
965
class MaskedConv1d(nn.Conv1d):
    """1-D convolution over channel-last input with optional input masking.

    Accepts the same arguments as ``torch.nn.Conv1d`` except ``padding``,
    which is derived from the kernel size and dilation.

    Shape:
       Input: (N, L, in_channels)
       input_mask: (N, L, 1), optional
       Output: (N, L, out_channels)
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
    ):
        """
        :param in_channels: input channels
        :param out_channels: output channels
        :param kernel_size: the kernel width
        :param stride: filter shift
        :param dilation: dilation factor
        :param groups: perform depth-wise convolutions
        :param bias: adds learnable bias to output
        """
        # Half of the dilated kernel extent, rounded down.
        auto_padding = (dilation * (kernel_size - 1)) // 2
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=auto_padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x, input_mask=None):
        """Multiply ``x`` by the mask (if any) and convolve along the L axis."""
        masked = x if input_mask is None else x * input_mask
        # nn.Conv1d works on (N, C, L); swap axes on the way in and out.
        channels_first = masked.transpose(1, 2)
        convolved = super().forward(channels_first)
        return convolved.transpose(1, 2)
1011
+
1012
+
1013
class Attention1d(nn.Module):
    """Attention pooling along the sequence axis.

    A kernel-size-1 convolution yields one score per position; the scores are
    softmax-normalised over the sequence and used to take a weighted sum of
    the inputs, which is then passed through a linear layer.
    """

    def __init__(self, config):
        super().__init__()
        self.layer = MaskedConv1d(config.hidden_size, 1, 1)
        self.out = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, x, input_mask=None, return_weights=False):
        """Pool ``x`` over dim 1.

        Positions where ``input_mask`` is zero get a score of ``-inf`` before
        the softmax. With ``return_weights=True`` the (batch, seq, 1)
        attention weights are returned alongside the pooled output.
        """
        n_batch = x.shape[0]
        scores = self.layer(x).view(n_batch, -1)
        if input_mask is not None:
            excluded = ~input_mask.view(n_batch, -1).bool()
            scores = scores.masked_fill(excluded, float("-inf"))
        weights = F.softmax(scores, dim=-1).view(n_batch, -1, 1)
        pooled = self.out((weights * x).sum(dim=1))
        return (pooled, weights) if return_weights else pooled
1033
+
1034
+
1035
class FFN1d(nn.Module):
    """Two-layer feed-forward block: hidden -> intermediate -> hidden with GELU."""

    def __init__(self, config):
        super().__init__()
        inner_width = config.intermediate_size
        outer_width = config.hidden_size
        self.fc1 = nn.Linear(outer_width, inner_width)
        self.fc2 = nn.Linear(inner_width, outer_width)
        self.act = nn.GELU()

    def forward(self, x):
        """Apply ``fc2(act(fc1(x)))``."""
        return self.fc2(self.act(self.fc1(x)))
1047
+
1048
+
1049
class Attention1dPooling(nn.Module):
    """Sequence pooling: attention pooling followed by a residual feed-forward block."""

    def __init__(self, config):
        super().__init__()
        self.attention1d = Attention1d(config)
        self.ffn = FFN1d(config)
        self.dropout1 = nn.Dropout(config.hidden_dropout_prob)
        self.dropout2 = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x, input_mask=None, return_weights=False):
        """Pool ``x`` over its sequence axis.

        Args:
            x: token representations to pool.
            input_mask: optional per-position mask; a trailing singleton axis
                is appended before it is handed to the attention module.
                ``None`` means no positions are masked out.
            return_weights: also return the pooling attention weights.

        Returns:
            The pooled representation, or ``(pooled, weights)`` when
            ``return_weights`` is true.
        """
        # Guard the unsqueeze so a missing mask no longer raises; the
        # attention module itself already accepts `None`.
        mask = None if input_mask is None else input_mask.unsqueeze(-1)
        attn_out, weights = self.attention1d(
            x, input_mask=mask, return_weights=True
        )
        x = self.dropout1(attn_out)
        # Residual connection around the feed-forward block.
        x = x + self.dropout2(self.ffn(x))
        if return_weights:
            return x, weights
        return x
1072
+
1073
+
1074
@dataclass
class MaskedLMOutput(ModelOutput):
    """Output container returned by the masked-LM and ThermoFormer models in this file.

    Every field defaults to ``None``; which ones are filled depends on the
    model and on the arguments passed to its ``forward``.
    """

    # Loss to optimise. ThermoFormer sets it to the MLM loss plus the value
    # loss scaled by 0.01, using whichever of the two is available.
    loss: Optional[torch.FloatTensor] = None
    # Cross-entropy loss of the LM head (ThermoFormer only, when labels are given).
    mlm_loss: Optional[torch.FloatTensor] = None
    # MSE loss between predicted and target values (ThermoFormer only, when values are given).
    value_loss: Optional[torch.FloatTensor] = None
    # Flattened output of ThermoFormer's value projection.
    predicted_values: Optional[torch.FloatTensor] = None
    # Vocabulary scores produced by the LM head.
    logits: torch.FloatTensor = None
    # Pooled sequence representation from ThermoFormer's sequence pooling.
    # NOTE(review): ThermoFormer stores the pooling module's output here, which
    # looks like a single tensor rather than a tuple - confirm the annotation.
    sequence_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Hidden states passed through from the backbone output.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Attention maps passed through from the backbone output.
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Attention weights returned by ThermoFormer's sequence pooling.
    # NOTE(review): same tuple-vs-tensor question as `sequence_hidden_states`.
    pooling_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
1085
+
1086
+
1087
class ThermoFormer(ThermoFormerPreTrainedModel):
    """ThermoFormer backbone with a scalar value-regression head and an optional LM head.

    The value head pools the token representations with attention pooling and
    projects the pooled vector to one number per sequence. The LM head is
    only built when ``config.mlm`` is true.
    """

    _tied_weights_keys = ["lm_head.decoder.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = ThermoFormerModel(config, add_pooling_layer=False)
        self.mlm = config.mlm
        if self.mlm:
            self.lm_head = ThermoFormerLMHead(config)
        else:
            self.lm_head = None
        self.sequence_pooling = Attention1dPooling(config)
        self.value_projection = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.Tanh(),
            nn.Linear(config.hidden_size, 1),
        )
        self.init_weights()

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        return self.model.embeddings.word_embeddings

    def get_output_embeddings(self):
        """Return the LM head's output projection, or ``None`` without an LM head.

        Returning ``None`` (instead of dereferencing a missing head) lets the
        base class's weight tying skip models built with ``config.mlm=False``.
        """
        if self.lm_head is None:
            return None
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head's output projection."""
        self.lm_head.decoder = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        """Predict one value per sequence and, when labels are given, the MLM loss.

        Args:
            labels: target token ids for the masked-LM loss; requires the LM
                head (``config.mlm=True``). Logits are only computed when
                labels are supplied.
            values: regression targets, one per sequence; when given, an MSE
                loss against the predicted values is computed.
            return_dict: when false, the output is returned via ``to_tuple()``;
                defaults to ``config.use_return_dict``.
            The remaining arguments are forwarded unchanged to the backbone.

        Returns:
            ``MaskedLMOutput`` (or its tuple form). ``predicted_values`` is
            always filled; ``loss`` is ``mlm_loss + 0.01 * value_loss`` using
            whichever terms are available, or ``None`` if neither is.

        Raises:
            ValueError: if ``labels`` are passed but the model has no LM head.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Always request a structured backbone output: its fields are read by
        # name below. `return_dict` is applied to the final result instead.
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        sequence_output = outputs[0]

        # Masked LM
        lm_prediction_scores = None
        masked_lm_loss = None
        if labels is not None:
            if self.lm_head is None:
                raise ValueError(
                    "`labels` were provided but the model has no LM head "
                    "(`config.mlm` is False)."
                )
            lm_prediction_scores = self.lm_head(sequence_output)
            loss_fct = CrossEntropyLoss()
            labels = labels.to(lm_prediction_scores.device)
            masked_lm_loss = loss_fct(
                lm_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
            )

        # Value prediction. The pooling step needs a concrete mask; without
        # one, every position is treated as valid.
        pooling_mask = attention_mask
        if pooling_mask is None:
            pooling_mask = torch.ones(
                sequence_output.shape[:2], device=sequence_output.device
            )
        sequence_states, weights = self.sequence_pooling(
            sequence_output, pooling_mask, return_weights=True
        )
        predicted_values = self.value_projection(sequence_states)

        value_loss = None
        if values is not None:
            # Match both device and dtype of the predictions, mirroring how
            # `labels` are moved above.
            values = values.to(
                device=predicted_values.device, dtype=predicted_values.dtype
            )
            values = values.reshape(-1, 1)
            value_loss = nn.MSELoss()(predicted_values, values)

        # Combine the available loss terms; the value loss is down-weighted.
        if value_loss is None:
            loss = masked_lm_loss
        elif masked_lm_loss is None:
            loss = 0.01 * value_loss
        else:
            loss = masked_lm_loss + 0.01 * value_loss

        result = MaskedLMOutput(
            loss=loss,
            mlm_loss=masked_lm_loss,
            value_loss=value_loss,
            logits=lm_prediction_scores,
            predicted_values=predicted_values.reshape(-1),
            hidden_states=outputs.hidden_states,
            sequence_hidden_states=sequence_states,
            attentions=outputs.attentions,
            pooling_attentions=weights,
        )
        if not return_dict:
            return result.to_tuple()
        return result
1197
+
1198
+
1199
+ ThermoFormer.register_for_auto_class("AutoModel")