mohsennp commited on
Commit
4b2b63f
·
verified ·
1 Parent(s): 0ba1d66

Upload DeCodon

Browse files
Files changed (3) hide show
  1. config.json +4 -0
  2. configuration_decodon.py +113 -0
  3. modeling_decodon.py +1640 -0
config.json CHANGED
@@ -5,6 +5,10 @@
5
  ],
6
  "attention_probs_dropout_prob": 0.1,
7
  "attention_type": "self",
 
 
 
 
8
  "classifier_dropout": 0.1,
9
  "dilation_rates": null,
10
  "gamma_init": 1.782709687623856,
 
5
  ],
6
  "attention_probs_dropout_prob": 0.1,
7
  "attention_type": "self",
8
+ "auto_map": {
9
+ "AutoConfig": "configuration_decodon.DeCodonConfig",
10
+ "AutoModelForCausalLM": "modeling_decodon.DeCodon"
11
+ },
12
  "classifier_dropout": 0.1,
13
  "dilation_rates": null,
14
  "gamma_init": 1.782709687623856,
configuration_decodon.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
class DeCodonConfig(PretrainedConfig):
    """Configuration for the DeCodon model.

    Carries the usual transformer hyper-parameters plus the DeCodon-specific
    switches (rotary embeddings, flash attention, MAGNETO ``gamma_init``).
    Every value is forwarded unchanged to ``PretrainedConfig.__init__``;
    ``is_decoder`` defaults to ``True`` unless the caller overrides it
    through ``kwargs``.
    """

    def __init__(
        self,
        vocab_size=70,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        position_embedding_type="rotary",
        use_cache=True,
        classifier_dropout=None,
        gamma_init=1.0,
        use_rotary_emb=True,
        rotary_theta=1e4,
        use_flash_attn=False,
        **kwargs,
    ):
        # Autoregressive by default; callers may still pass is_decoder=False.
        is_decoder = kwargs.pop("is_decoder", True)

        base_settings = dict(
            pad_token_id=pad_token_id,
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            intermediate_size=intermediate_size,
            hidden_act=hidden_act,
            hidden_dropout_prob=hidden_dropout_prob,
            attention_probs_dropout_prob=attention_probs_dropout_prob,
            max_position_embeddings=max_position_embeddings,
            type_vocab_size=type_vocab_size,
            initializer_range=initializer_range,
            layer_norm_eps=layer_norm_eps,
            position_embedding_type=position_embedding_type,
            use_cache=use_cache,
            classifier_dropout=classifier_dropout,
            gamma_init=gamma_init,
            use_rotary_emb=use_rotary_emb,
            rotary_theta=rotary_theta,
            use_flash_attn=use_flash_attn,
            is_decoder=is_decoder,
        )
        super().__init__(**base_settings, **kwargs)
52
+
53
+
54
class DeCodonForSequenceTaskConfig(DeCodonConfig):
    """DeCodon configuration for sequence-level downstream tasks.

    Extends :class:`DeCodonConfig` with head hyper-parameters (MLP /
    attention / convolutional classification heads, pooling, hidden-layer
    selection, loss settings).  The shorthand ``problem_type="classification"``
    is translated to HuggingFace's ``"single_label_classification"`` when
    calling the parent constructor, while the caller-supplied value is kept
    on ``self.problem_type`` afterwards.
    """

    def __init__(
        self,
        task_name="NoName",
        num_labels=2,
        num_tasks=1,
        loss_fn="huber",
        cls_num_hidden_layers=1,
        cls_hidden_size=128,
        cls_dropout_prob=0.1,
        cls_hidden_act="relu",
        cls_type="cls",
        cls_num_attention_heads=8,
        cls_use_rotary_emb=False,
        cls_rotary_theta=1e4,
        num_filters=128,
        kernel_size=3,
        stride=1,
        dilation=1,
        pooling_size=2,
        pooling_type="max",
        layer_indices=-1,
        reduction="mean",
        layer_reduction="none",
        problem_type="classification",
        **kwargs,
    ):
        # HuggingFace expects the canonical name for single-label problems.
        normalized_problem_type = (
            "single_label_classification"
            if problem_type == "classification"
            else problem_type
        )

        super().__init__(
            task_name=task_name,
            num_labels=num_labels,
            num_tasks=num_tasks,
            loss_fn=loss_fn,
            cls_num_hidden_layers=cls_num_hidden_layers,
            cls_hidden_size=cls_hidden_size,
            cls_dropout_prob=cls_dropout_prob,
            cls_hidden_act=cls_hidden_act,
            cls_num_attention_heads=cls_num_attention_heads,
            cls_use_rotary_emb=cls_use_rotary_emb,
            cls_rotary_theta=cls_rotary_theta,
            cls_type=cls_type,
            num_filters=num_filters,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            pooling_size=pooling_size,
            pooling_type=pooling_type,
            layer_indices=layer_indices,
            reduction=reduction,
            layer_reduction=layer_reduction,
            problem_type=normalized_problem_type,
            **kwargs,
        )

        # Re-expose the caller-facing value (possibly the shorthand).
        self.problem_type = problem_type
modeling_decodon.py ADDED
@@ -0,0 +1,1640 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple
2
+ from dataclasses import dataclass
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from transformers.modeling_outputs import (
7
+ SequenceClassifierOutput,
8
+ )
9
+
10
+ from typing import Optional, Tuple
11
+
12
+ import torch
13
+ import torch.utils.checkpoint
14
+ from torch import nn
15
+
16
+ from dataclasses import dataclass
17
+ from transformers.activations import ACT2FN, ACT2CLS
18
+ from transformers.modeling_utils import PreTrainedModel
19
+ from transformers.utils import logging
20
+ from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutput, CausalLMOutputWithPast
21
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
22
+ import xformers.ops as xops
23
+
24
+ from collections import OrderedDict
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+ import torch
29
+ import torch.nn as nn
30
+ from einops import rearrange, einsum
31
+ from transformers.pytorch_utils import Conv1D
32
+
33
+
34
+ import torch
35
+ from torch.amp import autocast
36
+ from torch import nn, einsum, Tensor
37
+
38
+ from einops import rearrange, repeat
39
+ from typing import Optional, Union
40
+
41
+ from .configuration_decodon import DeCodonConfig
42
+
43
+ logger = logging.get_logger(__name__)
44
+
45
+
46
def rotate_half(x):
    """Rotate adjacent channel pairs of the last dim: (x1, x2) -> (-x2, x1)."""
    # View the last dimension as interleaved pairs: (..., d*2) -> (..., d, 2).
    pairs = x.reshape(*x.shape[:-1], -1, 2)
    first, second = pairs.unbind(dim=-1)
    # Complex-style 90-degree rotation of each pair, then flatten back.
    rotated = torch.stack((-second, first), dim=-1)
    return rotated.reshape(*x.shape)
51
+
52
+
53
@autocast(device_type="cuda", enabled=False)
def apply_rotary_emb(freqs, t, start_index=0, scale=1.0):
    """
    Apply rotary position embeddings to a slice of the last dimension of ``t``.

    Parameters
    ----------
    freqs : Tensor
        Rotation angles per position: (seq_len, rot_dim).
    t : Tensor
        Tensor to rotate: (..., seq_len, n_heads, dim).
    start_index : int
        First feature channel to rotate. (default: 0)
    scale : float or Tensor
        Multiplier applied to the rotated slice (used by XPos). (default: 1.0)

    Returns
    -------
    Tensor
        ``t`` with channels [start_index, start_index + rot_dim) rotated;
        channels outside that window are passed through untouched.
    """
    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim

    assert (
        rot_dim <= t.shape[-1]
    ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"

    # Split into untouched left part, rotated middle, untouched right part.
    left = t[..., :start_index]
    middle = t[..., start_index:end_index]
    right = t[..., end_index:]

    if isinstance(scale, float):
        scale = torch.tensor(scale, device=t.device, dtype=t.dtype)

    # Standard RoPE rotation: x*cos + rotate_half(x)*sin, with XPos scaling.
    rotated = (middle * freqs.cos() * scale) + (rotate_half(middle) * freqs.sin() * scale)
    return torch.cat((left, rotated, right), dim=-1)
96
+
97
+
98
+ # learned rotation helpers
99
def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
    """Apply per-position learned rotation angles to ``t`` via rotary embedding."""
    if freq_ranges is not None:
        # Outer product of each angle with the frequency ranges,
        # then flatten the trailing (r, f) pair into a single axis.
        rotations = rotations.unsqueeze(-1) * freq_ranges
        rotations = rotations.flatten(-2)

    # Duplicate each angle so it covers both halves of a rotated channel pair.
    rotations = rotations.repeat_interleave(2, dim=-1)
    return apply_rotary_emb(rotations, t, start_index=start_index)
106
+
107
+
108
+ """
109
+ Inspired from https://github.com/lucidrains/rotary-embedding-torch
110
+ """
111
+
112
class RotaryEmbedding(nn.Module):
    """
    Rotary Embeddings implementation inspired by https://github.com/lucidrains/rotary-embedding-torch.

    Rotary Positional Embeddings (RoPE) encode position information of tokens with a
    rotation matrix that naturally incorporates explicit relative position dependency.

    Parameters
    ----------
    emb_dim : int
        Embedding dimension. Usually set to the dim of each head in the attention module.
    freqs : Optional[Tensor]
        Custom frequencies to apply to query/key tensors. (default: None)
    theta : float
        Base constant used for computing rotation angles.
    learned_freq : bool (default: False)
        Whether to learn the frequencies.
    use_xpos : bool (default: False)
        Whether to employ XPos technique for resolving length extrapolation issue.
        NOTE: This can only be enabled for autoregressive models like GPT.
    xpos_scale_base : int (default: 512)
        The base for the scale factor used in XPos technique.
    interpolate_factor : float (default: 1.0)
        Length interpolation factor for extending context length of the pretrained model.
        Final model's context length = pretrained_model_context_length * interpolate_factor.
    theta_rescale_factor : float (default: 1.0)
        The factor to rescale the theta.
    cache_if_possible : bool (default: True)
        Whether to cache the frequencies/scales if possible.
    """

    def __init__(
        self,
        emb_dim,
        freqs: Optional[Tensor] = None,
        theta=1e4,
        learned_freq=False,
        use_xpos=False,
        xpos_scale_base=512,
        interpolate_factor=1.0,
        theta_rescale_factor=1.0,
        cache_if_possible=True,
    ):
        super().__init__()
        # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
        # has some connection to NTK literature
        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
        theta *= theta_rescale_factor ** (emb_dim / (emb_dim - 2))

        if freqs is None:
            # Standard RoPE geometric progression: theta^(-2i / emb_dim).
            freqs = 1.0 / (
                theta
                ** (torch.arange(0, emb_dim, 2)[: (emb_dim // 2)].float() / emb_dim)
            )
            # freqs = torch.ones(num_freqs).float()

        self.cache_if_possible = cache_if_possible

        # Lazily-populated, non-persistent caches (filled in forward()/get_scale()).
        self.register_buffer("cached_freqs", None, persistent=False)
        self.register_buffer("cached_scales", None, persistent=False)

        self.freqs = nn.Parameter(freqs, requires_grad=learned_freq)

        self.learned_freq = learned_freq

        # interpolation factors
        assert interpolate_factor >= 1.0
        self.interpolate_factor = interpolate_factor

        # xpos
        self.use_xpos = use_xpos
        if not use_xpos:
            self.register_buffer("scale", None, persistent=False)
            return

        # XPos per-channel decay scale.
        scale = (torch.arange(0, emb_dim, 2) + 0.4 * emb_dim) / (1.4 * emb_dim)
        self.scale_base = xpos_scale_base
        self.register_buffer("scale", scale, persistent=False)

    @property
    def device(self):
        # Tracks .to()/.cuda() moves via the frequency parameter.
        return self.freqs.device

    def rotate_queries_or_keys(self, t, offset=0, freq_seq_len=None, scale=None):
        """
        Rotate a single query or key tensor (non-XPos path unless a scale is given).

        Parameters
        ----------
        t : Tensor
            tensor to rotate: (batch_size, seq_len, num_heads, head_dim)
        offset : int
            Position offset added to sequence indices (for cached decoding).
        freq_seq_len : Optional[int]
            If given (must be >= seq_len), compute frequencies for this length.
        scale : Optional[Tensor]
            XPos scale; required when use_xpos is True.
        """
        seq_len = t.shape[1]
        assert (
            not self.use_xpos or scale is not None
        ), "you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings"

        if freq_seq_len is not None:
            assert freq_seq_len >= seq_len
            seq_len = freq_seq_len

        # Position indices, shifted by offset and shrunk by the interpolation factor.
        seq = (
            torch.arange(seq_len, device=t.device, dtype=t.dtype) + offset
        ) / self.interpolate_factor

        freqs = self.forward(
            seq,
            seq_len=seq_len,
            offset=offset,
        ).to(t.dtype)

        # Insert a singleton heads axis so freqs broadcast over heads.
        freqs = rearrange(freqs, "n d -> n 1 d")

        if scale is not None:
            scale = rearrange(scale, "n d -> n 1 d")

        if scale is None:
            scale = torch.tensor(1.0, device=t.device, dtype=t.dtype)

        return apply_rotary_emb(freqs, t, scale=scale)

    def rotate_queries_and_keys(self, q, k):
        """
        Rotate queries and keys together using XPos (scale and inverse scale).

        Parameters
        ----------
        q : Tensor
            queries tensor: (batch_size, seq_len, num_heads, head_dim)
        k : Tensor
            keys tensor: (batch_size, seq_len, num_heads, head_dim)
        """
        assert self.use_xpos
        seq_len = q.shape[-3]

        seq = (
            torch.arange(seq_len, device=q.device, dtype=q.dtype)
        ) / self.interpolate_factor

        freqs = self.forward(seq, seq_len=seq_len)
        scale = self.get_scale(seq, seq_len=seq_len)

        freqs = rearrange(freqs, "n d -> n 1 d")
        scale = rearrange(scale, "n d -> n 1 d")

        # Queries get scale, keys get inverse scale (XPos decay).
        rotated_q = apply_rotary_emb(freqs, q, scale=scale)
        rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1)

        rotated_q = rotated_q.type(q.dtype)
        rotated_k = rotated_k.type(k.dtype)

        return rotated_q, rotated_k

    def get_scale(self, t: Tensor, seq_len: Optional[int] = None, offset=0):
        """Compute (and optionally cache) the XPos scale factors for positions ``t``."""
        assert self.use_xpos

        should_cache = self.cache_if_possible and seq_len is not None

        if (
            should_cache
            and self.cached_scales is not None
            and (seq_len + offset) <= self.cached_scales.shape[0]
        ):
            return self.cached_scales[offset : (offset + seq_len)]

        scale = 1.0
        if self.use_xpos:
            # Power grows with distance from the sequence midpoint.
            power = (t - len(t) // 2) / self.scale_base
            scale = self.scale ** rearrange(power, "n -> n 1")
            # Duplicate so the scale covers both halves of each channel pair.
            scale = torch.cat((scale, scale), dim=-1)

        if should_cache:
            self.register_buffer("cached_scales", scale, persistent=False)

        return scale

    def rotate_queries_with_cached_keys(self, q, k, offset=0):
        """
        Rotate a (shorter) query against a longer cached key tensor.

        NOTE(review): the early return below makes the remainder of this
        method unreachable and leaves ``offset`` unused; it also delegates to
        rotate_queries_and_keys, which asserts use_xpos. Confirm whether the
        unreachable offset-based path was meant to be restored.
        """
        q_len, k_len = q.shape[1], k.shape[1]
        assert q_len <= k_len

        rotated_q, rotated_k = self.rotate_queries_and_keys(q, k)

        # Keep only the newest query position.
        rotated_q = rotated_q[:, -1:, ...]

        return rotated_q, rotated_k

        # --- unreachable below this point (see NOTE above) ---
        seq = (
            torch.arange(k_len, device=q.device, dtype=q.dtype)
        ) / self.interpolate_factor

        if self.use_xpos:
            q_scale = self.get_scale(seq[-q_len:]).to(q.dtype)
            k_scale = self.get_scale(seq).to(k.dtype)

        else:
            k_scale = 1.0
            q_scale = 1.0

        rotated_q = self.rotate_queries_or_keys(
            q, scale=q_scale, offset=k_len - q_len + offset
        )
        rotated_k = self.rotate_queries_or_keys(k, scale=k_scale**-1)

        return rotated_q, rotated_k

    @autocast(device_type="cuda", enabled=False)
    def forward(self, t: Tensor, seq_len=None, offset=0):
        """Compute (and optionally cache) the rotation angles for positions ``t``."""
        # Caching is only valid when freqs are fixed and a length is known.
        should_cache = (
            self.cache_if_possible and not self.learned_freq and seq_len is not None
        )

        if (
            should_cache
            and self.cached_freqs is not None
            and (offset + seq_len) <= self.cached_freqs.shape[0]
        ):
            return self.cached_freqs[offset : (offset + seq_len)].detach()

        freqs = self.freqs

        # Outer product of positions and base frequencies, then duplicate each
        # angle for the (even, odd) channel pair: (..., dim/2) -> (..., dim).
        freqs = einsum("..., f -> ... f", t, freqs)
        freqs = repeat(freqs, "... n -> ... (n r)", r=2)

        if should_cache:
            self.register_buffer("cached_freqs", freqs.detach(), persistent=False)

        return freqs
340
+
341
+
342
+
343
class MultiHeadedSelfAttention(nn.Module):
    """
    Multi-Headed Self Attention module supported with Flash Attention and Rotary Embeddings.

    Parameters
    ----------
    q_input_dim: int
        The input dimension of the query tensor.
    kv_input_dim: int
        The input dimension of the key and value tensors.
    qk_proj_dim: int
        The projected dimension of the query and key tensors.
    v_proj_dim: int
        The projected dimension of the value tensors.
    num_heads: int
        Number of attention heads.
    dropout: float
        Dropout rate to apply to the attention scores.
    projection_layer: str
        The type of projection layer to use. Either 'linear' or 'conv'.
        Basically both are linear projections, but 'conv' uses Conv1D layer as proposed in the original GPT2 paper.
    use_flash_attn: bool
        Whether to use Flash Attention or not. If True, Flash Attention will be used.
        NOTE: Flash Attention is required to be installed.
    use_rotary_emb: bool
        Whether to use Rotary Embeddings or not.
    rotary_theta: int
        The base for the geometric progression used to compute the rotation angles.
    rotary_use_xpos: bool
        Whether to use XPos technique for resolving length extrapolation issue.
        NOTE: This can only be enabled for autoregressive models like GPT.
    is_cross_attention: bool
        When True, the packed flash-attention path is disabled (it requires
        query/key/value of equal length from the same input).
    """

    def __init__(
        self,
        q_input_dim,
        kv_input_dim,
        qk_proj_dim,
        v_proj_dim,
        num_heads,
        dropout: float = 0.0,
        projection_layer: str = "linear",
        use_flash_attn: bool = True,
        use_rotary_emb: bool = False,
        rotary_theta: int = 1e4,
        rotary_use_xpos: bool = False,
        is_cross_attention: bool = False,
        **kwargs,
    ):
        super().__init__()
        assert (
            qk_proj_dim % num_heads == 0
        ), "qk_proj_dim must be divisible by num_heads"
        assert v_proj_dim % num_heads == 0, "v_proj_dim must be divisible by num_heads"

        self.num_heads = num_heads
        self.dropout_rate = dropout
        self.projection_layer = projection_layer
        self.use_rotary_emb = use_rotary_emb
        self.is_cross_attention = is_cross_attention

        # flash_attn is optional: fall back to the default paths if missing.
        if use_flash_attn and not is_cross_attention:
            try:
                from flash_attn import flash_attn_qkvpacked_func

                self.use_flash_attn = True
                self.flashattn_fn = flash_attn_qkvpacked_func
            except ImportError:
                print("flash_attn not installed, reverting to default attention")
                self.use_flash_attn = False
                self.flashattn_fn = None
        else:
            self.use_flash_attn = False
            self.flashattn_fn = None

        if self.projection_layer == "linear":
            self.query = nn.Linear(q_input_dim, qk_proj_dim)
            self.key = nn.Linear(kv_input_dim, qk_proj_dim)
            self.value = nn.Linear(kv_input_dim, v_proj_dim)
        elif self.projection_layer == "conv":
            # GPT-2 style Conv1D: a linear layer with transposed weight layout.
            self.query = Conv1D(qk_proj_dim, q_input_dim)
            self.key = Conv1D(qk_proj_dim, kv_input_dim)
            self.value = Conv1D(v_proj_dim, kv_input_dim)
        else:
            raise ValueError(
                f"projection_layer must be either 'linear' or 'conv', got {projection_layer}"
            )

        if self.use_rotary_emb:
            # Rotate half of each head's channels (emb_dim = head_dim // 2).
            self.rotary_emb = RotaryEmbedding(
                emb_dim=qk_proj_dim // num_heads // 2,
                theta=rotary_theta,
                use_xpos=rotary_use_xpos,
            )

        # dr_rate duplicates dropout_rate; it is used by the xformers path below.
        self.dr_rate = dropout
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x_q,
        x_kv,
        is_causal=False,
        attention_bias=None,
        attention_mask=None,
        output_attentions=False,
        query=None,
        key=None,
        value=None,
        use_cache=False,
    ):
        """
        Applies a classical self attention operation.

        Parameters
        ----------
        x_q: torch.Tensor
            The query tensor of shape (batch_size, query_seq_len, emb_dim)
        x_kv: torch.Tensor
            The key/value tensor of shape (batch_size, kv_seq_len, emb_dim)
        attention_bias: torch.Tensor
            The attention bias to apply to the attention scores. (default: None)
        attention_mask: torch.Tensor
            The attention mask to apply to the attention scores. Shape: (batch_size, q_len, kv_seq_len)
        query, key, value: Optional[torch.Tensor]
            Pre-projected tensors; when all three are given, the x_q/x_kv
            projections are skipped (cached-decoding path).
        use_cache: bool
            When True, also return a (key, value, query) cache tuple.
        """
        assert (x_q is not None and x_kv is not None) or (
            query is not None and key is not None and value is not None
        ), "Either x_q and x_kv or query, key and value must be provided"

        past_memory_provided = (
            query is not None and key is not None and value is not None
        )

        if query is None:
            q_len = x_q.size(1)
            k_len = x_kv.size(1)

            query = self.query(x_q)
            key = self.key(x_kv)
            value = self.value(x_kv)

        else:
            q_len = query.size(1)
            k_len = key.size(1)

        if use_cache:
            # Cache is built from the pre-rotation projections.
            cache = (key.clone(), value.clone(), query.clone())

        # Split channels into heads: (b, seq, h*d) -> (b, seq, h, d).
        q = rearrange(query, "b q (h d) -> b q h d", h=self.num_heads)
        k = rearrange(key, "b k (h d) -> b k h d", h=self.num_heads)
        v = rearrange(value, "b v (h d) -> b v h d", h=self.num_heads)

        if self.use_rotary_emb:
            if use_cache and past_memory_provided:
                q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
            # NOTE(review): on the cached path above, q/k appear to be rotated a
            # second time here when use_xpos is True — confirm this is intended.
            if self.rotary_emb.use_xpos:
                q, k = self.rotary_emb.rotate_queries_and_keys(q, k)
            else:
                q = self.rotary_emb.rotate_queries_or_keys(q)
                k = self.rotary_emb.rotate_queries_or_keys(k)

        if (
            self.use_flash_attn
            and not use_cache
            and not output_attentions
            and attention_bias is None
        ):
            # Packed flash-attention path; flash_attn requires (b)f16 inputs.
            qkv = torch.stack([q, k, v], dim=2).to(torch.bfloat16)
            x = self.flashattn_fn(
                qkv=qkv,
                dropout_p=self.dropout_rate if self.training else 0.0,
                causal=is_causal,
                deterministic=False,
                return_attn_probs=False,
            )

            x = x.to(x_q.dtype)
        elif self.use_flash_attn and not output_attentions:
            # xformers memory-efficient path with an explicit additive bias.
            attn_bias = xops.LowerTriangularMask() if is_causal else attention_bias

            # NOTE(review): if attention_bias, attention_mask are all None and
            # is_causal is False, attn_bias stays None and the shape/materialize
            # accesses below would fail — presumably callers always provide one.
            if attention_mask is not None:
                if attn_bias is None:
                    attn_bias = attention_mask
                else:
                    if isinstance(attn_bias, torch.Tensor):
                        attn_bias = attn_bias + attention_mask
                    else:
                        # xops mask object: fold the mask in, then materialize.
                        attn_bias.add_bias(bias=attention_mask)

                        attn_bias = attn_bias.materialize(
                            shape=(q_len, k_len),
                            device=q.device,
                            dtype=q.dtype,
                        )
            else:
                if isinstance(attn_bias, torch.Tensor) and len(attn_bias.shape) == 3:
                    attn_bias = (
                        attn_bias.unsqueeze(1)
                        .expand(-1, self.num_heads, -1, -1)
                        .float()
                    )  # (batch_size, num_heads, q_len, k_len)
                else:
                    attn_bias = attn_bias.materialize(
                        shape=(q_len, k_len),
                        device=q.device,
                        dtype=q.dtype,
                    )

            if isinstance(attn_bias, xops.LowerTriangularMask):
                attn_bias = attn_bias.materialize(
                    shape=(q_len, k_len),
                    device=q.device,
                    dtype=q.dtype,
                )

            # print(attention_mask.shape, attn_bias.shape)
            # print(attn_bias[0, 0, 0, :])

            # xformers prefers bias dims padded to multiples of 8.
            need_adjustment = False
            if attn_bias.shape[-2] % 8 != 0:
                nearest_multiple_q = 8 * (1 + attn_bias.shape[-2] // 8)
                need_adjustment = True
            else:
                nearest_multiple_q = attn_bias.shape[-2]

            if attn_bias.shape[-1] % 8 != 0:
                nearest_multiple_k = 8 * (1 + attn_bias.shape[-1] // 8)
                need_adjustment = True
            else:
                nearest_multiple_k = attn_bias.shape[-1]

            if need_adjustment:
                new_attn_bias = torch.zeros(
                    attn_bias.shape[0],
                    attn_bias.shape[1],
                    nearest_multiple_q,
                    nearest_multiple_k,
                ).to(attn_bias.device)
                new_attn_bias[:, :, : attn_bias.shape[-2], : attn_bias.shape[-1]] = (
                    attn_bias
                )

                # NOTE(review): the padded bias is sliced back to (q_len, k_len)
                # here, which appears to undo the padding above — the padded
                # allocation may only matter for memory alignment; confirm.
                x = xops.memory_efficient_attention(
                    query=q,
                    key=k,
                    value=v,
                    op=None,
                    attn_bias=new_attn_bias[:, :, :q_len, :k_len],
                    p=self.dr_rate,
                )
            else:
                attn_bias = attn_bias.to(q.dtype)
                # NOTE(review): a bias already expanded to num_heads above would
                # be repeated again here ((b, h*h, ...)) — verify input shapes.
                attn_bias = attn_bias.repeat(1, self.num_heads, 1, 1)
                x = xops.memory_efficient_attention(
                    query=q,
                    key=k,
                    value=v,
                    op=None,
                    attn_bias=attn_bias,
                    p=self.dr_rate,
                )
            # x: (batch_size, query_seq_len, n_head, head_dim)
        else:
            # Eager fallback path (also used when attention probs are requested).
            # NOTE(review): these einsum calls use the einops calling convention
            # (tensors first, pattern last), but the module-level
            # `from torch import einsum` import appears to shadow the earlier
            # einops import — confirm which einsum is in effect here.
            # if output_attentions:
            attention_scores = einsum(q, k, "b q h d, b k h d -> b h q k")
            attention_scores = attention_scores / (q.size(-1) ** 0.5)

            if attention_bias is not None:
                attn_bias = attention_bias.unsqueeze(1).expand(
                    -1, self.num_heads, -1, -1
                )
            # elif is_causal:
            #     attn_bias = xops.LowerTriangularMask().materialize(
            #         shape=attention_scores.shape, device=attention_scores.device
            #     )
            else:
                attn_bias = None

            if attention_mask is not None:
                if attn_bias is None:
                    attn_bias = attention_mask
                else:
                    attn_bias = attn_bias + attention_mask

            # NOTE(review): attn_bias can still be None here (no bias, no mask),
            # which would make this addition fail — confirm callers always pass one.
            attention_scores = attention_scores + attn_bias

            attention_probs = attention_scores.softmax(dim=-1)
            attention_probs = self.dropout(attention_probs)

            x = einsum(attention_probs, v, "b h q k, b v h d -> b q h d")

        # Merge heads back: (b, q, h, d) -> (b, q, h*d).
        x = rearrange(x, "b q h d -> b q (h d)", h=self.num_heads)

        if use_cache:
            if output_attentions:
                return x, attention_probs, cache
            else:
                return x, None, cache
        else:
            if output_attentions:
                return x, attention_probs
            else:
                return x, None
646
+
647
class DeCodonPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    base_model_prefix = "decodon"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """MAGNETO-style weight initialization (Xavier scaled by gamma_init)."""
        if isinstance(module, nn.Linear):
            # Xavier init with the MAGNETO gamma gain; zero the bias if present.
            nn.init.xavier_normal_(module.weight, gain=self.config.gamma_init)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            # Keep the padding row at zero so pad tokens contribute nothing.
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
            return

        if isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        # Checkpointing is toggled per transformer layer only.
        if isinstance(module, DeCodonLayer):
            module.gradient_checkpointing = value
674
+
675
+
676
class DeCodonEmbeddings(nn.Module):
    """
    DeCodon Embeddings

    Sums word and token-type embeddings, plus learned position embeddings
    when ``config.position_embedding_type == "absolute"``, then applies
    dropout.  The LayerNorm is constructed for checkpoint compatibility but
    is currently not applied in ``forward`` (the call is commented out).
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
        )
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size
        )
        self.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size
        )

        # Kept but unused in forward (see class docstring).
        self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.position_embedding_type = getattr(
            config, "position_embedding_type", "absolute"
        )

        # Pre-built position ids / zero token-type ids so tracing works without
        # the caller supplying them explicitly.
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).expand((1, -1)),
            persistent=False,
        )
        self.register_buffer(
            "token_type_ids",
            torch.zeros(self.position_ids.size(), dtype=torch.long),
            persistent=False,
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        # Derive (batch, seq) from whichever input was provided.
        input_shape = (
            input_ids.size() if input_ids is not None else inputs_embeds.size()[:-1]
        )
        seq_length = input_shape[1]

        if position_ids is None:
            # Offset by the cache length during incremental decoding.
            position_ids = self.position_ids[
                :, past_key_values_length : seq_length + past_key_values_length
            ]

        if token_type_ids is None:
            # Fall back to the registered all-zeros buffer (helps tracing).
            if hasattr(self, "token_type_ids"):
                token_type_ids = self.token_type_ids[:, :seq_length].expand(
                    input_shape[0], seq_length
                )
            else:
                token_type_ids = torch.zeros(
                    input_shape, dtype=torch.long, device=self.position_ids.device
                )

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds + self.token_type_embeddings(token_type_ids)
        if self.position_embedding_type == "absolute":
            embeddings = embeddings + self.position_embeddings(position_ids)

        # embeddings = self.ln(embeddings)
        return self.dropout(embeddings)
764
+
765
+
766
class DeCodonAttention(nn.Module):
    """
    DeCodon Attention Layer

    Pre-LayerNorm residual attention block built on ``MultiHeadedSelfAttention``
    (self-attention with Rotary Positional Embeddings, xPos variant).

    Supports incremental decoding via an explicit ``past_key_values`` tuple of
    cached ``(key, value, query)`` *projections* (note: query is cached too,
    unlike the usual HF (key, value) cache, because attention is re-run over
    the full sequence).
    """

    def __init__(self, config):
        super().__init__()

        self.pre_layer_norm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )
        self.post_attn_dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.post_layer_norm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.self_attention = MultiHeadedSelfAttention(
            q_input_dim=config.hidden_size,
            kv_input_dim=config.hidden_size,
            qk_proj_dim=config.hidden_size,
            v_proj_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_probs_dropout_prob,
            projection_layer="conv",
            use_flash_attn=config.use_flash_attn,
            use_rotary_emb=config.use_rotary_emb,
            rotary_theta=config.rotary_theta,
            rotary_use_xpos=True,
        )

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        """Run causal self-attention with a residual connection.

        Args:
            hidden_states: Input of shape ``(batch, seq_len, hidden)``.
            attention_mask: Optional (already expanded) attention mask.
            output_attentions: Whether the attention module should return
                attention probabilities.
            past_key_values: Optional cached ``(key, value, query)``
                projections from previous steps.
            use_cache: Whether the attention module should return a cache.

        Returns:
            Tuple whose first element is the residual attention output; the
            remaining elements are whatever the attention module appends
            (attention weights and/or cache).
        """

        attn_input = self.pre_layer_norm(hidden_states)

        if past_key_values is not None:
            # Incremental decoding: project only the new position(s), then
            # concatenate with the cached projections so attention is
            # evaluated over the full sequence.
            query = self.self_attention.query(attn_input)
            key = self.self_attention.key(attn_input)
            value = self.self_attention.value(attn_input)

            past_key, past_value, past_query = past_key_values

            key = torch.cat((past_key, key), dim=1)  # (batch, seq_len, hidden)
            value = torch.cat((past_value, value), dim=1)
            query = torch.cat((past_query, query), dim=1)

            attn_outputs = self.self_attention(
                x_q=None,
                x_kv=None,
                query=query,
                key=key,
                value=value,
                is_causal=True,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                use_cache=use_cache,
                attention_bias=None,
            )
        else:
            # No cache: let the attention module do its own projections.
            attn_outputs = self.self_attention(
                x_q=attn_input,
                x_kv=attn_input,
                is_causal=True,
                attention_bias=None,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        # Post-attention: LN -> dense -> dropout -> residual add.
        attn_output = attn_outputs[0]
        attn_output = self.post_layer_norm(attn_output)
        attn_output = self.post_attn_dense(attn_output)
        attn_output = self.dropout(attn_output)
        attn_output = hidden_states + attn_output

        return (attn_output,) + attn_outputs[1:]
892
+
893
+
894
class DeCodonFFN(nn.Module):
    """
    DeCodon Position-wise Feed-Forward Network

    Pipeline: pre-LayerNorm -> expand to ``intermediate_size`` -> activation
    -> LayerNorm -> project back to ``hidden_size`` -> dropout.
    Note: no residual connection is applied here; the caller owns that.
    """

    def __init__(self, config):
        super().__init__()
        hidden_dim = config.hidden_size
        self.pre_layer_norm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )
        self.intermediate_dense = Conv1D(config.intermediate_size, hidden_dim)
        self.post_layer_norm = nn.LayerNorm(
            config.intermediate_size, eps=config.layer_norm_eps
        )
        self.post_dense = Conv1D(hidden_dim, config.intermediate_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Resolve string activation names through the shared ACT2FN registry;
        # a callable is used as-is.
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(
        self, hidden_states: Optional[Tuple[torch.FloatTensor]]
    ) -> torch.FloatTensor:
        """Apply the position-wise feed-forward transform."""
        expanded = self.intermediate_dense(self.pre_layer_norm(hidden_states))
        normalized = self.post_layer_norm(self.intermediate_act_fn(expanded))
        return self.dropout(self.post_dense(normalized))
927
+
928
+
929
class DeCodonLayer(nn.Module):
    """
    DeCodon (Decoder) Layer: an attention block followed by a position-wise
    feed-forward network.
    """

    def __init__(self, config):
        super().__init__()
        self.attention = DeCodonAttention(config)
        self.output = DeCodonFFN(config)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = False,
    ) -> Union[
        Tuple[torch.Tensor],
        Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]],
    ]:
        """Run attention then the FFN.

        Returns a tuple whose first element is the layer output; any extra
        elements from the attention block (attention weights and/or cache)
        are forwarded unchanged.
        """
        attn_results = self.attention(
            hidden_states,
            attention_mask,
            output_attentions=output_attentions,
            past_key_values=past_key_values,
            use_cache=use_cache,
        )
        # attn_results[0] is the residual attention output; the FFN output
        # replaces it at position 0 while the trailing elements pass through.
        ffn_output = self.output(attn_results[0])
        return (ffn_output,) + attn_results[1:]
967
+
968
+
969
class DeCodonStack(nn.Module):
    """
    DeCodon Stack consists of multiple DeCodon layers.

    Runs ``config.num_hidden_layers`` `DeCodonLayer` blocks in sequence,
    optionally collecting per-layer hidden states, attention maps and
    key/value caches.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.blocks = nn.ModuleList(
            [DeCodonLayer(config) for _ in range(config.num_hidden_layers)]
        )
        # Flag only; no gradient-checkpointing logic exists in this module.
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
        use_cache: Optional[bool] = False,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
        """Run the decoder stack.

        Args:
            hidden_states: Embedded input, shape ``(batch, seq_len, hidden)``.
            attention_mask: Pre-expanded attention mask passed to each block.
            past_key_values: One cache entry per block (or ``None``).
            output_attentions: Collect per-layer attention maps.
            output_hidden_states: Collect the hidden states before each block
                and after the last one.
            return_dict: Return a ``BaseModelOutputWithPast`` instead of a tuple.
            use_cache: Collect the per-layer caches ("presents").
        """

        if past_key_values is None:
            past_key_values = [None] * len(self.blocks)
            past_length = 0
        else:
            # NOTE(review): past_length is computed but never used below —
            # presumably kept for parity with other call sites; confirm.
            past_length = past_key_values[0][0].size(-2)

        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        presents = () if use_cache else None
        for i, (block, past_key_value) in enumerate(zip(self.blocks, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            block_outputs = block(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                past_key_values=past_key_value,
                use_cache=use_cache,
            )

            hidden_states = block_outputs[0]

            # NOTE(review): fixed indices 1 and 2 assume the block always
            # returns (output, attn_weights, present) in that order, even when
            # output_attentions is False — confirm against
            # MultiHeadedSelfAttention's return contract.
            if use_cache:
                presents = presents + (block_outputs[2],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (block_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            # Tuple output drops any None entries, matching HF convention.
            return tuple(
                v
                for v in [
                    hidden_states,
                    presents,
                    all_hidden_states,
                    all_self_attentions,
                ]
                if v is not None
            )

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
1043
+
1044
+
1045
class DeCodonModule(DeCodonPreTrainedModel):
    """
    The DeCodon Module (Decoder only) without any task-specific head on top.

    Wraps `DeCodonEmbeddings` + `DeCodonStack` and builds the 4D causal
    attention mask before dispatching to the decoder.
    """

    def __init__(self, config):
        super().__init__(config)

        self.embeddings = DeCodonEmbeddings(config)
        self.decoder = DeCodonStack(config)
        # NOTE(review): ln_f is created but not applied in forward() — it
        # still participates in the state dict; confirm whether the final
        # normalization is intentionally skipped.
        self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Added for symmetry with set_input_embeddings: the HF base class
        # raises NotImplementedError when no override is provided.
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, new_embeddings):
        self.embeddings.word_embeddings = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = False,
    ) -> Union[Tuple, BaseModelOutput]:
        """Embed inputs, build the causal mask, and run the decoder stack.

        Raises:
            ValueError: If both or neither of ``input_ids`` / ``inputs_embeds``
                are given.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if past_key_values is not None:
            past_length = past_key_values[0][0].size(-2)
        else:
            past_length = 0

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            # NOTE(review): with a non-empty cache this default mask covers
            # only the new tokens, not past_length + seq_length — callers
            # using the cache appear to supply their own mask; confirm.
            attention_mask = torch.ones(((batch_size, seq_length)), device=device)

        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(
                    batch_size, seq_length
                )
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(
                    input_shape, dtype=torch.long, device=device
                )

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )

        # Expand the 2D padding mask into the 4D additive causal mask
        # expected by the attention layers.
        extended_attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask=attention_mask,
            input_shape=(batch_size, input_shape[-1]),
            inputs_embeds=embedding_output,
            past_key_values_length=past_length,
        )

        decoder_outputs = self.decoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            past_key_values=past_key_values,
            return_dict=return_dict,
            use_cache=use_cache,
        )

        sequence_output = decoder_outputs[0]

        if not return_dict:
            return (sequence_output,) + decoder_outputs[1:]

        return BaseModelOutputWithPast(
            last_hidden_state=sequence_output,
            past_key_values=decoder_outputs.past_key_values,
            hidden_states=decoder_outputs.hidden_states,
            attentions=decoder_outputs.attentions,
        )
1168
+
1169
+
1170
@dataclass
class DeCodonForPreTrainingOutput(CausalLMOutputWithPast):
    """
    Output type of [`DeCodon`].

    Args:
        loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
            Causal language modeling loss (next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(torch.FloatTensor)`, *optional*, returned when `use_cache=True` is passed):
            Per-layer cached projections used to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
1201
+
1202
+
1203
class DeCodon(DeCodonPreTrainedModel):
    """DeCodon causal language model: `DeCodonModule` decoder plus an LM head.

    The LM head is either a plain linear projection (``config.lm_type ==
    "gpt"``) or a small MLP head ending in a linear projection.
    """

    config_class = DeCodonConfig
    # Class-level default only; each instance sets its own list in __init__.
    _tied_weights_keys = []

    def __init__(self, config):
        super().__init__(config)

        self.gpt = DeCodonModule(config)

        # causal language modeling head
        if config.lm_type == "gpt":
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
            # Fix: the original appended to the *class* attribute
            # (DeCodon._tied_weights_keys.append(...)), so every new
            # instantiation grew a shared list with duplicate entries.
            # Use an instance attribute instead.
            self._tied_weights_keys = ["lm_head.weight"]
        else:
            self.lm_head = nn.Sequential(
                OrderedDict(
                    [
                        ("dropout", nn.Dropout(config.hidden_dropout_prob)),
                        (
                            "transform",
                            nn.Linear(config.hidden_size, config.hidden_size),
                        ),
                        ("act", nn.ReLU()),
                        (
                            "norm",
                            nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps),
                        ),
                        (
                            "pred",
                            nn.Linear(
                                config.hidden_size, config.vocab_size, bias=False
                            ),
                        ),
                    ]
                )
            )
            self._tied_weights_keys = ["lm_head.pred.weight"]

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.gpt.embeddings.word_embeddings

    def get_output_embeddings(self):
        # NOTE(review): this returns the weight *tensor* rather than the
        # module; HF's tying machinery normally expects the module — confirm
        # before relying on automatic weight tying.
        return (
            self.lm_head.pred.weight
            if isinstance(self.lm_head, nn.Sequential)
            else self.lm_head.weight if self.config.lm_type == "gpt" else None
        )

    def set_output_embeddings(self, new_embeddings):
        if isinstance(self.lm_head, nn.Sequential):
            self.lm_head.pred.weight = new_embeddings
        else:
            self.lm_head.weight = new_embeddings

    def prepare_inputs_for_generation(
        self, input_ids, inputs_embeds=None, past_key_values=None, **kwargs
    ):
        """Trim already-cached tokens and assemble generation inputs."""
        token_type_ids = kwargs.get("token_type_ids", None)
        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)
        use_cache = kwargs.get("use_cache", True)

        if past_key_values is not None and use_cache:
            past_length = past_key_values[0][0].shape[1]

            # Keep only the tokens not yet covered by the cache (or the last
            # token if the cache already spans the whole prompt).
            if input_ids.shape[1] > past_length:
                remove_prefix_len = past_length
            else:
                remove_prefix_len = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_len:]

            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, remove_prefix_len:]

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
        else:
            # NOTE(review): caller-supplied position_ids are discarded here
            # (same shape as GPT-2's implementation); positions are rotary so
            # this is presumably unused — confirm.
            position_ids = None

        if inputs_embeds is not None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "attention_mask": attention_mask,
                "token_type_ids": token_type_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache", True),
            }
        )

        return model_inputs

    @staticmethod
    def _reorder_cache(
        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.
        """
        return tuple(
            tuple(
                past_state.index_select(0, beam_idx.to(past_state.device))
                for past_state in layer_past
            )
            for layer_past in past_key_values
        )

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        organism: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], DeCodonForPreTrainingOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the causal language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked),
            the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        organism (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Organism labels (currently unused by this head).
        kwargs (`Dict[str, any]`, optional, defaults to *{}*):
            Used to hide legacy arguments that have been deprecated.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("decodon-base")
        >>> model = AutoModelForCausalLM.from_pretrained("decodon-base", trust_remote_code=True)

        >>> inputs = tokenizer("AAAAGGGGGGCCCCCCTTTTT", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> prediction_logits = outputs.logits
        ```
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if input_ids is not None:
            batch_size, sequence_length = input_ids.shape[:2]
        else:
            batch_size, sequence_length = inputs_embeds.shape[:2]

        # sequence_lengths is not used by the LM loss below; it is retained
        # (with its warning side effect) for parity with the sequence-task head.
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                sequence_lengths = (
                    torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1
                ).to(input_ids.device)
            else:
                sequence_lengths = -1
                logger.warning(
                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                )

        gpt_outputs = self.gpt(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=use_cache,
        )

        hidden_states = gpt_outputs[0]  # (batch_size, sequence_length, hidden_size)
        lm_logits = self.lm_head(
            hidden_states
        )  # (batch_size, sequence_length, vocab_size)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(lm_logits.device)
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )

        if not return_dict:
            output = (lm_logits,) + gpt_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return DeCodonForPreTrainingOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=gpt_outputs.past_key_values,
            hidden_states=gpt_outputs.hidden_states,
            attentions=gpt_outputs.attentions,
        )

    def freeze(self, layer_indices: Optional[list] = None):
        """Freeze the decoder, except the blocks listed in ``layer_indices``.

        With no indices, the whole ``gpt`` backbone is frozen. Negative
        indices are normalized modulo the number of blocks.
        """
        if layer_indices is None or len(layer_indices) == 0:
            for param in self.gpt.parameters():
                param.requires_grad = False
        else:
            for param in self.gpt.embeddings.parameters():
                param.requires_grad = False

            if isinstance(layer_indices, int):
                layer_indices = [layer_indices]

            layer_indices = [i % len(self.gpt.decoder.blocks) for i in layer_indices]

            for i in range(len(self.gpt.decoder.blocks)):
                if i not in layer_indices:
                    for param in self.gpt.decoder.blocks[i].parameters():
                        param.requires_grad = False
1453
+
1454
+
1455
+
1456
class DeCodonForSequenceTask(DeCodonPreTrainedModel):
    """DeCodon backbone with a sequence-level regression/classification head.

    Pools the last non-pad token's hidden state from the layers listed in
    ``config.layer_indices`` and feeds the concatenation through an MLP
    classifier. Only ``config.cls_type == "cls"`` is constructible.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.gpt = DeCodonModule(config)

        if config.cls_type.lower() == "cls":
            layer_indices = config.layer_indices
            layer_indices = (
                []
                if layer_indices is None
                else (
                    [layer_indices] if isinstance(layer_indices, int) else layer_indices
                )
            )
            # Normalize negative indices against the number of decoder blocks.
            layer_indices = [i % len(self.gpt.decoder.blocks) for i in layer_indices]

            n_layers = len(layer_indices)
            self.layer_indices = layer_indices
            self.classifier = nn.Sequential(
                nn.LayerNorm(config.hidden_size * n_layers),
                nn.Linear(config.hidden_size * n_layers, config.hidden_size),
                ACT2CLS[config.cls_hidden_act](),
                nn.Dropout(config.cls_dropout_prob),
                nn.Linear(
                    config.hidden_size,
                    config.num_labels * config.num_tasks,
                ),
            )
        else:
            raise ValueError(f"Invalid cls_type: {config.cls_type}.")

        self.init_weights()

    def freeze(self, layers_idx: Optional[list] = None):
        """Freeze the backbone, except the blocks listed in ``layers_idx``."""
        if layers_idx is None or len(layers_idx) == 0:
            for param in self.gpt.parameters():
                param.requires_grad = False
        else:
            for param in self.gpt.embeddings.parameters():
                param.requires_grad = False

            if isinstance(layers_idx, int):
                layers_idx = [layers_idx]

            layers_idx = [i % self.config.num_hidden_layers for i in layers_idx]

            for i in range(self.config.num_hidden_layers):
                if i not in layers_idx:
                    for param in self.gpt.decoder.blocks[i].parameters():
                        param.requires_grad = False

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        target: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ):
        """Run the backbone, pool, classify and (optionally) compute the loss.

        Args:
            target: Regression targets (entries equal to -500.0 are masked
                out) or classification labels, depending on
                ``config.problem_type``.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if input_ids is not None:
            batch_size, sequence_length = input_ids.shape[:2]
        else:
            batch_size, sequence_length = inputs_embeds.shape[:2]

        # Index of the last non-pad token per sequence (or -1 when padding
        # cannot be detected).
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                sequence_lengths = (
                    torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1
                ).to(
                    input_ids.device
                )  # (batch_size,)
            else:
                sequence_lengths = -1
                logger.warning(
                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                )

        gpt_outputs = self.gpt(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=return_dict,
        )

        all_hidden_states = gpt_outputs.hidden_states

        # Fix: `ca` and `pooled_output` were previously only assigned inside
        # some branches, raising NameError on the "cls" path when
        # output_attentions=True (and in the return below for other paths).
        ca = None
        pooled_output = None

        if self.config.cls_type.lower() not in ["crossattention", "ca", "cls"]:
            logits, _ = self.classifier(all_hidden_states, attention_mask)
        elif self.config.cls_type.lower() in ["crossattention", "ca"]:
            bs, seq_len = input_ids.shape

            query_tasks = self.task_embeddings.weight  # (num_tasks, hidden_size)
            query_tasks = query_tasks.unsqueeze(0).expand(
                bs, -1, -1
            )  # (batch_size, num_tasks, hidden_size)

            cls_outputs = self.classifier(
                query_tasks,
                all_hidden_states,
                attention_mask,
                output_attentions=output_attentions,
            )  # (batch_size, num_tasks, num_labels)

            logits, ca = cls_outputs

            logits = logits.squeeze()
        elif self.config.cls_type.lower() == "cls":
            bs, seq_len = input_ids.shape
            # here we select latest token's hidden states as pooled output
            pooled_hidden_states = [
                h[torch.arange(bs, device=h.device), sequence_lengths - 1, :]
                for i, h in enumerate(all_hidden_states)
                if i in self.layer_indices
            ]
            pooled_output = torch.cat(
                pooled_hidden_states, dim=-1
            )  # (batch_size, hidden_size * n_layers)

            logits = self.classifier(pooled_output)

        loss = None
        if target is not None:
            if self.config.problem_type == "regression":
                logits = logits.view(-1, self.config.num_labels * self.config.num_tasks)
                target = target.view(-1, self.config.num_labels * self.config.num_tasks)

                # -500.0 is the sentinel for missing regression targets.
                mask = target != -500.0

                if self.config.loss_fn == "mse":
                    loss_fct = nn.MSELoss()
                    loss = loss_fct(logits[mask], target[mask])
                elif self.config.loss_fn == "mae":
                    loss_fct = nn.L1Loss()
                    loss = loss_fct(logits[mask], target[mask])
                elif self.config.loss_fn == "huber":
                    loss_fct = nn.SmoothL1Loss()
                    loss = loss_fct(logits[mask], target[mask])
                else:
                    raise ValueError(f"Invalid loss_fn: {self.config.loss_fn}.")
            else:
                loss_fct = nn.CrossEntropyLoss()

                logits = logits.view(-1, self.config.num_labels * self.config.num_tasks)
                target = target.view(
                    -1,
                )

                loss = loss_fct(logits, target)

        if not return_dict:
            output = (logits,) + gpt_outputs[2:]
            return ((loss,) + output) if loss is not None else output

        if output_attentions:
            if ca is not None:
                # Fix: tuple + list raised TypeError; concatenate tuples.
                attentions = gpt_outputs.attentions + (ca,)
            else:
                attentions = gpt_outputs.attentions
        else:
            attentions = None

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=pooled_output,
            attentions=attentions,
        )