mohsennp committed on
Commit 82f0d64 · verified · 1 Parent(s): 7814d5e

Upload EnCodon

Files changed (5)
  1. README.md +199 -0
  2. config.json +39 -0
  3. configuration_encodon.py +140 -0
  4. model.safetensors +3 -0
  5. modeling_encodon.py +1661 -0
README.md ADDED
@@ -0,0 +1,199 @@
+ ---
+ library_name: transformers
+ tags: []
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+
+
+ ## Model Details
+
+ ### Model Description
+
+ <!-- Provide a longer summary of what this model is. -->
+
+ This is the model card of a 🤗 transformers model that has been pushed to the Hub. This model card has been automatically generated.
+
+ - **Developed by:** [More Information Needed]
+ - **Funded by [optional]:** [More Information Needed]
+ - **Shared by [optional]:** [More Information Needed]
+ - **Model type:** [More Information Needed]
+ - **Language(s) (NLP):** [More Information Needed]
+ - **License:** [More Information Needed]
+ - **Finetuned from model [optional]:** [More Information Needed]
+
+ ### Model Sources [optional]
+
+ <!-- Provide the basic links for the model. -->
+
+ - **Repository:** [More Information Needed]
+ - **Paper [optional]:** [More Information Needed]
+ - **Demo [optional]:** [More Information Needed]
+
+ ## Uses
+
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+ ### Direct Use
+
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+ [More Information Needed]
+
+ ### Downstream Use [optional]
+
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+ [More Information Needed]
+
+ ### Out-of-Scope Use
+
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+ [More Information Needed]
+
+ ## Bias, Risks, and Limitations
+
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+ [More Information Needed]
+
+ ### Recommendations
+
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+ Users (both direct and downstream) should be made aware of the risks, biases, and limitations of the model. More information is needed for further recommendations.
+
+ ## How to Get Started with the Model
+
+ Use the code below to get started with the model.
+
+ [More Information Needed]
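Although the card leaves this section empty, the repo's config.json maps `AutoModelForMaskedLM` to `modeling_encodon.EnCodon` via `auto_map`, so loading the checkpoint requires `trust_remote_code=True`. A minimal, untested sketch (the repo id is a placeholder, and `split_codons` is an illustrative helper, not part of this repo):

```python
def load_encodon(repo_id: str):
    """Load EnCodon via its auto_map entry (needs `transformers` and network access)."""
    from transformers import AutoModelForMaskedLM  # deferred import: heavy optional dep
    # trust_remote_code is required because the architecture lives in modeling_encodon.py
    return AutoModelForMaskedLM.from_pretrained(repo_id, trust_remote_code=True)


def split_codons(cds: str) -> list:
    """Chunk a coding sequence into the 3-mer codon tokens the model operates on."""
    if len(cds) % 3 != 0:
        raise ValueError("CDS length must be a multiple of 3")
    return [cds[i:i + 3] for i in range(0, len(cds), 3)]


print(split_codons("ATGGCTTAA"))  # ['ATG', 'GCT', 'TAA']
```

The tokenizer shipped with the model would normally handle codon splitting; the helper above only illustrates the unit of input.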
+
+ ## Training Details
+
+ ### Training Data
+
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+ [More Information Needed]
+
+ ### Training Procedure
+
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+ #### Preprocessing [optional]
+
+ [More Information Needed]
+
+
+ #### Training Hyperparameters
+
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+ #### Speeds, Sizes, Times [optional]
+
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+ [More Information Needed]
+
+ ## Evaluation
+
+ <!-- This section describes the evaluation protocols and provides the results. -->
+
+ ### Testing Data, Factors & Metrics
+
+ #### Testing Data
+
+ <!-- This should link to a Dataset Card if possible. -->
+
+ [More Information Needed]
+
+ #### Factors
+
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+ [More Information Needed]
+
+ #### Metrics
+
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+ [More Information Needed]
+
+ ### Results
+
+ [More Information Needed]
+
+ #### Summary
+
+
+
+ ## Model Examination [optional]
+
+ <!-- Relevant interpretability work for the model goes here -->
+
+ [More Information Needed]
+
+ ## Environmental Impact
+
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+ - **Hardware Type:** [More Information Needed]
+ - **Hours used:** [More Information Needed]
+ - **Cloud Provider:** [More Information Needed]
+ - **Compute Region:** [More Information Needed]
+ - **Carbon Emitted:** [More Information Needed]
+
+ ## Technical Specifications [optional]
+
+ ### Model Architecture and Objective
+
+ [More Information Needed]
+
+ ### Compute Infrastructure
+
+ [More Information Needed]
+
+ #### Hardware
+
+ [More Information Needed]
+
+ #### Software
+
+ [More Information Needed]
+
+ ## Citation [optional]
+
+ <!-- If there is a paper or blog post introducing the model, the APA and BibTeX information for that should go in this section. -->
+
+ **BibTeX:**
+
+ [More Information Needed]
+
+ **APA:**
+
+ [More Information Needed]
+
+ ## Glossary [optional]
+
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+ [More Information Needed]
+
+ ## More Information [optional]
+
+ [More Information Needed]
+
+ ## Model Card Authors [optional]
+
+ [More Information Needed]
+
+ ## Model Card Contact
+
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,39 @@
+ {
+   "_name_or_path": "/large_experiments/goodarzilab/mohsen/biofm/saved_models/CodonBERT/CodonBERT_L2048_l6_a8_b32_r0.0001_mlm0.2_wd0.01_g2_euk_adapt-4pxzs58g/checkpoint-90000",
+   "architectures": [
+     "EnCodon"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "attention_type": "self",
+   "auto_map": {
+     "AutoConfig": "configuration_encodon.EnCodonConfig",
+     "AutoModelForMaskedLM": "modeling_encodon.EnCodon"
+   },
+   "classifier_dropout": 0.1,
+   "dilation_rates": null,
+   "gamma_init": 1.5763586678760644,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 1024,
+   "initializer_range": 0.02,
+   "intermediate_size": 4096,
+   "layer_norm_eps": 1e-12,
+   "lm_type": "distilled",
+   "max_position_embeddings": 2048,
+   "num_attention_heads": 8,
+   "num_divs": 0,
+   "num_hidden_layers": 6,
+   "pad_token_id": 3,
+   "pooler_activation": "tanh",
+   "position_embedding_type": "rotary",
+   "rotary_theta": 10000.0,
+   "segment_lengths": null,
+   "torch_dtype": "float32",
+   "transformers_version": "4.44.2",
+   "type_vocab_size": 2,
+   "use_cache": true,
+   "use_flash_attn": true,
+   "use_nsp": false,
+   "use_rotary_emb": true,
+   "vocab_size": 69
+ }
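As a quick sanity check, the shipped hyperparameters are internally consistent: the 1024-dim hidden state splits evenly over 8 heads, and the feed-forward layer uses the usual 4x expansion. (Reading `vocab_size: 69` as 64 codons plus a handful of special tokens is my assumption; the repo does not state it.)

```python
# Key fields copied from the config.json above
config = {
    "hidden_size": 1024,
    "num_attention_heads": 8,
    "intermediate_size": 4096,
    "num_hidden_layers": 6,
    "max_position_embeddings": 2048,
    "vocab_size": 69,
}

# Per-head dimension must divide evenly for multi-head attention
assert config["hidden_size"] % config["num_attention_heads"] == 0
head_dim = config["hidden_size"] // config["num_attention_heads"]

# Standard transformer 4x FFN expansion
assert config["intermediate_size"] == 4 * config["hidden_size"]

# 64 codons would leave 5 ids for special tokens (assumption, not confirmed by the repo)
n_special = config["vocab_size"] - 64
print(head_dim, n_special)  # 128 5
```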
configuration_encodon.py ADDED
@@ -0,0 +1,140 @@
+ from transformers import PretrainedConfig
+
+ class EnCodonConfig(PretrainedConfig):
+     def __init__(
+         self,
+         vocab_size=70,
+         hidden_size=768,
+         num_hidden_layers=12,
+         num_attention_heads=12,
+         intermediate_size=3072,
+         hidden_act="gelu",
+         hidden_dropout_prob=0.1,
+         attention_probs_dropout_prob=0.1,
+         max_position_embeddings=512,
+         type_vocab_size=2,
+         initializer_range=0.02,
+         layer_norm_eps=1e-12,
+         pad_token_id=0,
+         position_embedding_type="absolute",
+         use_cache=True,
+         classifier_dropout=0.1,
+         gamma_init=0.1,
+         use_rotary_emb=False,
+         rotary_theta=5e5,
+         use_flash_attn=False,
+         lm_type="bert",
+         **kwargs,
+     ):
+         super().__init__(
+             vocab_size=vocab_size,
+             hidden_size=hidden_size,
+             num_hidden_layers=num_hidden_layers,
+             num_attention_heads=num_attention_heads,
+             intermediate_size=intermediate_size,
+             hidden_act=hidden_act,
+             hidden_dropout_prob=hidden_dropout_prob,
+             attention_probs_dropout_prob=attention_probs_dropout_prob,
+             max_position_embeddings=max_position_embeddings,
+             type_vocab_size=type_vocab_size,
+             initializer_range=initializer_range,
+             layer_norm_eps=layer_norm_eps,
+             pad_token_id=pad_token_id,
+             position_embedding_type=position_embedding_type,
+             use_cache=use_cache,
+             classifier_dropout=classifier_dropout,
+             gamma_init=gamma_init,
+             use_rotary_emb=use_rotary_emb,
+             rotary_theta=rotary_theta,
+             use_flash_attn=use_flash_attn,
+             lm_type=lm_type,
+             **kwargs,
+         )
+
+
+ class EnCodonForDMSConfig(EnCodonConfig):
+     def __init__(
+         self,
+         loss_fn="huber",
+         num_labels=1,
+         task_name="NoName",
+         problem_type="regression",
+         **kwargs,
+     ):
+
+         if problem_type == "classification":
+             problem_type_ = "single_label_classification"
+         else:
+             problem_type_ = problem_type
+
+         super().__init__(
+             loss_fn=loss_fn,
+             task_name=task_name,
+             num_labels=num_labels,
+             problem_type=problem_type_,
+             **kwargs,
+         )
+
+         self.problem_type = problem_type
+
+
+ class EnCodonForSequenceTaskConfig(EnCodonConfig):
+     def __init__(
+         self,
+         task_name="NoName",
+         loss_fn="huber",
+         num_labels=2,
+         num_tasks=1,
+         cls_num_hidden_layers=1,
+         cls_hidden_size=128,
+         cls_dropout_prob=0.1,
+         cls_hidden_act="relu",
+         cls_type="mlp",
+         cls_num_attention_heads=8,
+         cls_use_rotary_emb=False,
+         cls_rotary_theta=1e4,
+         num_filters=128,
+         kernel_size=3,
+         stride=1,
+         dilation=1,
+         pooling_size=2,
+         pooling_type="max",
+         layer_indices=-1,
+         reduction="mean",
+         layer_reduction="none",
+         problem_type="classification",
+         **kwargs,
+     ):
+
+         if problem_type == "classification":
+             problem_type_ = "single_label_classification"
+         else:
+             problem_type_ = problem_type
+
+         super().__init__(
+             loss_fn=loss_fn,
+             task_name=task_name,
+             num_labels=num_labels,
+             num_tasks=num_tasks,
+             cls_num_hidden_layers=cls_num_hidden_layers,
+             cls_hidden_size=cls_hidden_size,
+             cls_dropout_prob=cls_dropout_prob,
+             cls_hidden_act=cls_hidden_act,
+             cls_num_attention_heads=cls_num_attention_heads,
+             cls_use_rotary_emb=cls_use_rotary_emb,
+             cls_rotary_theta=cls_rotary_theta,
+             cls_type=cls_type,
+             num_filters=num_filters,
+             kernel_size=kernel_size,
+             stride=stride,
+             dilation=dilation,
+             pooling_size=pooling_size,
+             pooling_type=pooling_type,
+             layer_indices=layer_indices,
+             reduction=reduction,
+             layer_reduction=layer_reduction,
+             problem_type=problem_type_,
+             **kwargs,
+         )
+
+         self.problem_type = problem_type
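Both task configs map the user-facing `problem_type` value `"classification"` onto transformers' internal `"single_label_classification"` before calling `super().__init__`, then keep the original string on the instance. That mapping in isolation (the helper name is mine, for illustration only):

```python
def normalize_problem_type(problem_type: str) -> str:
    # Mirrors the branch in EnCodonForDMSConfig / EnCodonForSequenceTaskConfig:
    # transformers expects "single_label_classification", not "classification"
    if problem_type == "classification":
        return "single_label_classification"
    return problem_type


print(normalize_problem_type("classification"))  # single_label_classification
print(normalize_problem_type("regression"))      # regression
```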
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c5555ac1024e6015614550b345aecb749a1744231e47f8e7a9ddb1bd1f5981cc
+ size 319948268
modeling_encodon.py ADDED
@@ -0,0 +1,1661 @@
+ from typing import Optional, Tuple
+ from dataclasses import dataclass
+ import torch
+ import torch.nn as nn
+ from torch.nn.modules import Module
+
+ from transformers.modeling_outputs import (
+     SequenceClassifierOutput,
+     ModelOutput,
+     MaskedLMOutput,
+ )
+
+ from transformers.activations import ACT2FN
+
+ from .configuration_encodon import (
+     EnCodonConfig,
+     EnCodonForDMSConfig,
+ )
+
+ from typing import Optional, Tuple
+
+ import torch
+ import torch.utils.checkpoint
+ from torch import nn
+
+ from dataclasses import dataclass
+ from transformers import (
+     apply_chunking_to_forward,
+ )
+ from transformers.activations import ACT2FN
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.utils import ModelOutput, logging
+ from transformers.modeling_outputs import MaskedLMOutput
+
+ logger = logging.get_logger(__name__)
+
+ import torch
+ import torch.nn as nn
+ import xformers.ops as xops
+ from einops import rearrange, einsum
+ from transformers.pytorch_utils import Conv1D
+
+ """
+ Inspired by https://github.com/lucidrains/rotary-embedding-torch
+ """
+
+ from math import pi
+
+ import torch
+ from torch.amp import autocast
+ from torch import nn, einsum, broadcast_tensors, Tensor
+
+ from einops import rearrange, repeat
+ from typing import Optional, Union, Literal
+
+
+ def rotate_half(x):
+     x = rearrange(x, "... (d r) -> ... d r", r=2)
+     x1, x2 = x.unbind(dim=-1)
+     x = torch.stack((-x2, x1), dim=-1)
+     return rearrange(x, "... d r -> ... (d r)")
+
+
+ @autocast(device_type="cuda", enabled=False)
+ def apply_rotary_emb(freqs, t, start_index=0, scale=1.0):
+     """
+     Applies rotary embeddings to a tensor.
+
+     Parameters
+     ----------
+     freqs : Tensor
+         The frequencies to apply to the tensor: (seq_len, dim)
+     t : Tensor
+         The tensor to apply the rotary embeddings to: (..., seq_len, n_heads, dim)
+     start_index : int
+         The starting index at which to apply the rotary embeddings. (default: 0)
+     scale : float
+         The scale to apply to the rotary embeddings. (default: 1.0)
+
+     Returns
+     -------
+     Tensor
+         The tensor with the rotary embeddings applied: (..., seq_len, n_heads, dim)
+     """
+     # if t.ndim == 3:
+     #     seq_len = t.shape[seq_dim]
+     #     freqs = freqs[-seq_len:].to(t)
+
+     rot_dim = freqs.shape[-1]
+     end_index = start_index + rot_dim
+
+     assert (
+         rot_dim <= t.shape[-1]
+     ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
+
+     t_left, t, t_right = (
+         t[..., :start_index],
+         t[..., start_index:end_index],
+         t[..., end_index:],
+     )
+     if isinstance(scale, float):
+         scale = torch.tensor(scale, device=t.device, dtype=t.dtype)
+
+     t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
+     return torch.cat((t_left, t, t_right), dim=-1)
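The two helpers above implement the standard RoPE rotation on interleaved pairs: `rotate_half` maps each pair (x1, x2) to (-x2, x1), and `apply_rotary_emb` combines `t * cos + rotate_half(t) * sin`. A dependency-free sketch of the same arithmetic on plain lists (function names are mine) makes it easy to see that the operation is a pure 2D rotation and therefore norm-preserving:

```python
import math

def rotate_half_list(x):
    # pairs (x1, x2) -> (-x2, x1), mirroring rotate_half's interleaved layout
    out = []
    for i in range(0, len(x), 2):
        x1, x2 = x[i], x[i + 1]
        out += [-x2, x1]
    return out

def apply_rotary_list(freqs, t):
    # t * cos(f) + rotate_half(t) * sin(f), elementwise
    rh = rotate_half_list(t)
    return [t[i] * math.cos(freqs[i]) + rh[i] * math.sin(freqs[i])
            for i in range(len(t))]

# Each pair shares one frequency, as produced by repeat(..., "... n -> ... (n r)", r=2)
t = [1.0, 0.0, 0.0, 1.0]
freqs = [0.5, 0.5, 1.2, 1.2]
rotated = apply_rotary_list(freqs, t)
```

Each 2D pair is rotated by its own angle, so the vector's norm is unchanged; that is the property that lets RoPE encode position without distorting attention magnitudes.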
+
+
+ # learned rotation helpers
+ def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
+     if freq_ranges is not None:
+         rotations = einsum("..., f -> ... f", rotations, freq_ranges)
+         rotations = rearrange(rotations, "... r f -> ... (r f)")
+
+     rotations = repeat(rotations, "... n -> ... (n r)", r=2)
+     return apply_rotary_emb(rotations, t, start_index=start_index)
+
+
+ class RotaryEmbedding(nn.Module):
+     """
+     Rotary embeddings implementation inspired by https://github.com/lucidrains/rotary-embedding-torch.
+
+     Rotary Positional Embeddings (RoPE) encode position information of tokens with a
+     rotation matrix that naturally incorporates explicit relative position dependency.
+
+     Parameters
+     ----------
+     emb_dim : int
+         Embedding dimension. Usually set to the dim of each head in the attention module.
+     freqs : Optional[Tensor]
+         Custom frequencies to apply to query/key tensors. (default: None)
+     theta : float
+         Base constant used for computing rotation angles.
+     learned_freq : bool (default: False)
+         Whether to learn the frequencies.
+     use_xpos : bool (default: False)
+         Whether to employ the XPos technique for resolving the length-extrapolation issue.
+         NOTE: This can only be enabled for autoregressive models like GPT.
+     xpos_scale_base : int (default: 512)
+         The base for the scale factor used in the XPos technique.
+     interpolate_factor : float (default: 1.0)
+         Length interpolation factor for extending the context length of the pretrained model.
+         Final model's context length = pretrained_model_context_length * interpolate_factor.
+     theta_rescale_factor : float (default: 1.0)
+         The factor by which to rescale theta.
+     cache_if_possible : bool (default: True)
+         Whether to cache the frequencies/scales if possible.
+     """
+
+     def __init__(
+         self,
+         emb_dim,
+         freqs: Optional[Tensor] = None,
+         theta=1e4,
+         learned_freq=False,
+         use_xpos=False,
+         xpos_scale_base=512,
+         interpolate_factor=1.0,
+         theta_rescale_factor=1.0,
+         cache_if_possible=True,
+     ):
+         super().__init__()
+         # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
+         # has some connection to NTK literature
+         # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
+
+         theta *= theta_rescale_factor ** (emb_dim / (emb_dim - 2))
+
+         if freqs is None:
+             freqs = 1.0 / (
+                 theta
+                 ** (torch.arange(0, emb_dim, 2)[: (emb_dim // 2)].float() / emb_dim)
+             )
+             # freqs = torch.ones(num_freqs).float()
+
+         self.cache_if_possible = cache_if_possible
+
+         self.register_buffer("cached_freqs", None, persistent=False)
+         self.register_buffer("cached_scales", None, persistent=False)
+
+         self.freqs = nn.Parameter(freqs, requires_grad=learned_freq)
+
+         self.learned_freq = learned_freq
+
+         # interpolation factors
+
+         assert interpolate_factor >= 1.0
+         self.interpolate_factor = interpolate_factor
+
+         # xpos
+         self.use_xpos = use_xpos
+         if not use_xpos:
+             self.register_buffer("scale", None, persistent=False)
+             return
+
+         scale = (torch.arange(0, emb_dim, 2) + 0.4 * emb_dim) / (1.4 * emb_dim)
+         self.scale_base = xpos_scale_base
+         self.register_buffer("scale", scale, persistent=False)
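The `__init__` above builds the classic geometric inverse-frequency schedule, `freqs_i = theta ** (-2i / emb_dim)`: the first frequency is always 1.0 and each subsequent one decays by a constant ratio, so low dimensions rotate fast (fine-grained position) and high dimensions rotate slowly (coarse position). A small standalone check of the same formula (the function name is mine):

```python
def inv_freqs(emb_dim, theta=10000.0):
    # matches 1.0 / theta ** (torch.arange(0, emb_dim, 2) / emb_dim) above
    return [theta ** (-(2.0 * i) / emb_dim) for i in range(emb_dim // 2)]

fs = inv_freqs(8)  # 4 frequencies for an 8-dim rotary head
print(fs[0])       # 1.0 — the fastest-rotating dimension
```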
+
+     @property
+     def device(self):
+         return self.freqs.device
+
+     def rotate_queries_or_keys(self, t, offset=0, freq_seq_len=None, scale=None):
+         """
+         Parameters
+         ----------
+         t : Tensor
+             tensor to rotate: (batch_size, seq_len, num_heads, head_dim)
+         """
+         seq_len = t.shape[1]
+         assert (
+             not self.use_xpos or scale is not None
+         ), "you must use the `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length-extrapolatable rotary embeddings"
+
+         if freq_seq_len is not None:
+             assert freq_seq_len >= seq_len
+             seq_len = freq_seq_len
+
+         seq = (
+             torch.arange(seq_len, device=t.device, dtype=t.dtype) + offset
+         ) / self.interpolate_factor
+
+         freqs = self.forward(
+             seq,
+             seq_len=seq_len,
+             offset=offset,
+         ).to(t.dtype)
+
+         freqs = rearrange(freqs, "n d -> n 1 d")
+
+         if scale is not None:
+             scale = rearrange(scale, "n d -> n 1 d")
+
+         if scale is None:
+             scale = torch.tensor(1.0, device=t.device, dtype=t.dtype)
+
+         return apply_rotary_emb(freqs, t, scale=scale)
+
+     def rotate_queries_and_keys(self, q, k):
+         """
+         Parameters
+         ----------
+         q : Tensor
+             queries tensor: (batch_size, seq_len, num_heads, head_dim)
+         k : Tensor
+             keys tensor: (batch_size, seq_len, num_heads, head_dim)
+         """
+         assert self.use_xpos
+         seq_len = q.shape[-3]
+
+         seq = (
+             torch.arange(seq_len, device=q.device, dtype=q.dtype)
+         ) / self.interpolate_factor
+
+         freqs = self.forward(seq, seq_len=seq_len)
+         scale = self.get_scale(seq, seq_len=seq_len)
+
+         freqs = rearrange(freqs, "n d -> n 1 d")
+         scale = rearrange(scale, "n d -> n 1 d")
+
+         rotated_q = apply_rotary_emb(freqs, q, scale=scale)
+         rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1)
+
+         rotated_q = rotated_q.type(q.dtype)
+         rotated_k = rotated_k.type(k.dtype)
+
+         return rotated_q, rotated_k
+
+     def get_scale(self, t: Tensor, seq_len: Optional[int] = None, offset=0):
+         assert self.use_xpos
+
+         should_cache = self.cache_if_possible and seq_len is not None
+
+         if (
+             should_cache
+             and self.cached_scales is not None
+             and (seq_len + offset) <= self.cached_scales.shape[0]
+         ):
+             return self.cached_scales[offset : (offset + seq_len)]
+
+         scale = 1.0
+         if self.use_xpos:
+             power = (t - len(t) // 2) / self.scale_base
+             scale = self.scale ** rearrange(power, "n -> n 1")
+             scale = torch.cat((scale, scale), dim=-1)
+
+         if should_cache:
+             self.register_buffer("cached_scales", scale, persistent=False)
+
+         return scale
+
+     def rotate_queries_with_cached_keys(self, q, k, offset=0):
+         q_len, k_len = q.shape[1], k.shape[1]
+         assert q_len <= k_len
+
+         rotated_q, rotated_k = self.rotate_queries_and_keys(q, k)
+
+         rotated_q = rotated_q[:, -1:, ...]
+
+         return rotated_q, rotated_k
+
+         # NOTE: the code below is unreachable in the uploaded file (it follows the return above)
+         seq = (
+             torch.arange(k_len, device=q.device, dtype=q.dtype)
+         ) / self.interpolate_factor
+
+         if self.use_xpos:
+             q_scale = self.get_scale(seq[-q_len:]).to(q.dtype)
+             k_scale = self.get_scale(seq).to(k.dtype)
+         else:
+             k_scale = 1.0
+             q_scale = 1.0
+
+         rotated_q = self.rotate_queries_or_keys(
+             q, scale=q_scale, offset=k_len - q_len + offset
+         )
+         rotated_k = self.rotate_queries_or_keys(k, scale=k_scale**-1)
+
+         return rotated_q, rotated_k
+
+     @autocast(device_type="cuda", enabled=False)
+     def forward(self, t: Tensor, seq_len=None, offset=0):
+         should_cache = (
+             self.cache_if_possible and not self.learned_freq and seq_len is not None
+         )
+
+         if (
+             should_cache
+             and self.cached_freqs is not None
+             and (offset + seq_len) <= self.cached_freqs.shape[0]
+         ):
+             return self.cached_freqs[offset : (offset + seq_len)].detach()
+
+         freqs = self.freqs
+
+         freqs = einsum("..., f -> ... f", t, freqs)
+         freqs = repeat(freqs, "... n -> ... (n r)", r=2)
+
+         if should_cache:
+             self.register_buffer("cached_freqs", freqs.detach(), persistent=False)
+
+         return freqs
+
+
+ class MultiHeadedSelfAttention(nn.Module):
+     """
+     Multi-headed self-attention module with support for Flash Attention and rotary embeddings.
+
+     Parameters
+     ----------
+     q_input_dim: int
+         The input dimension of the query tensor.
+     kv_input_dim: int
+         The input dimension of the key and value tensors.
+     qk_proj_dim: int
+         The projected dimension of the query and key tensors.
+     v_proj_dim: int
+         The projected dimension of the value tensors.
+     num_heads: int
+         Number of attention heads.
+     dropout: float
+         Dropout rate to apply to the attention scores.
+     projection_layer: str
+         The type of projection layer to use, either 'linear' or 'conv'.
+         Both are linear projections, but 'conv' uses the Conv1D layer as proposed in the original GPT-2 paper.
+     use_flash_attn: bool
+         Whether to use Flash Attention. If True, Flash Attention will be used.
+         NOTE: flash_attn is required to be installed.
+     use_rotary_emb: bool
+         Whether to use rotary embeddings.
+     rotary_theta: int
+         The base for the geometric progression used to compute the rotation angles.
+     rotary_use_xpos: bool
+         Whether to use the XPos technique for resolving the length-extrapolation issue.
+         NOTE: This can only be enabled for autoregressive models like GPT.
+     """
+
+     def __init__(
+         self,
+         q_input_dim,
+         kv_input_dim,
+         qk_proj_dim,
+         v_proj_dim,
+         num_heads,
+         dropout: float = 0.0,
+         projection_layer: str = "linear",
+         use_flash_attn: bool = True,
+         use_rotary_emb: bool = False,
+         rotary_theta: int = 1e4,
+         rotary_use_xpos: bool = False,
+         is_cross_attention: bool = False,
+         **kwargs,
+     ):
+         super().__init__()
+         assert (
+             qk_proj_dim % num_heads == 0
+         ), "qk_proj_dim must be divisible by num_heads"
+         assert v_proj_dim % num_heads == 0, "v_proj_dim must be divisible by num_heads"
+
+         self.num_heads = num_heads
+         self.dropout_rate = dropout
+         self.projection_layer = projection_layer
+         self.use_rotary_emb = use_rotary_emb
+         self.is_cross_attention = is_cross_attention
+
+         if use_flash_attn and not is_cross_attention:
+             try:
+                 from flash_attn import flash_attn_qkvpacked_func
+
+                 self.use_flash_attn = True
+                 self.flashattn_fn = flash_attn_qkvpacked_func
+             except ImportError:
+                 print("flash_attn not installed, reverting to default attention")
+                 self.use_flash_attn = False
+                 self.flashattn_fn = None
+         else:
+             self.use_flash_attn = False
+             self.flashattn_fn = None
+
+         if self.projection_layer == "linear":
+             self.query = nn.Linear(q_input_dim, qk_proj_dim)
+             self.key = nn.Linear(kv_input_dim, qk_proj_dim)
+             self.value = nn.Linear(kv_input_dim, v_proj_dim)
+         elif self.projection_layer == "conv":
+             self.query = Conv1D(qk_proj_dim, q_input_dim)
+             self.key = Conv1D(qk_proj_dim, kv_input_dim)
+             self.value = Conv1D(v_proj_dim, kv_input_dim)
+         else:
+             raise ValueError(
+                 f"projection_layer must be either 'linear' or 'conv', got {projection_layer}"
+             )
+
+         if self.use_rotary_emb:
+             self.rotary_emb = RotaryEmbedding(
+                 emb_dim=qk_proj_dim // num_heads // 2,
+                 theta=rotary_theta,
+                 use_xpos=rotary_use_xpos,
+             )
+
+         self.dr_rate = dropout
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(
+         self,
+         x_q,
+         x_kv,
+         is_causal=False,
+         attention_bias=None,
+         attention_mask=None,
+         output_attentions=False,
+         query=None,
+         key=None,
+         value=None,
+         use_cache=False,
+     ):
+         """
+         Applies a classical self-attention operation.
+
+         Parameters
+         ----------
+         x_q: torch.Tensor
+             The query tensor of shape (batch_size, query_seq_len, emb_dim)
+         x_kv: torch.Tensor
+             The key/value tensor of shape (batch_size, kv_seq_len, emb_dim)
+         attention_bias: torch.Tensor
+             The attention bias to apply to the attention scores. (default: None)
+         attention_mask: torch.Tensor
+             The attention mask to apply to the attention scores. Shape: (batch_size, q_len, kv_seq_len)
+         """
+         assert (x_q is not None and x_kv is not None) or (
+             query is not None and key is not None and value is not None
+         ), "Either x_q and x_kv or query, key and value must be provided"
+
+         past_memory_provided = (
+             query is not None and key is not None and value is not None
+         )
+
+         if query is None:
+             q_len = x_q.size(1)
+             k_len = x_kv.size(1)
+
+             query = self.query(x_q)
+             key = self.key(x_kv)
+             value = self.value(x_kv)
+
+         else:
+             q_len = query.size(1)
+             k_len = key.size(1)
+
+         if use_cache:
+             cache = (key.clone(), value.clone(), query.clone())
+
+         q = rearrange(query, "b q (h d) -> b q h d", h=self.num_heads)
+         k = rearrange(key, "b k (h d) -> b k h d", h=self.num_heads)
+         v = rearrange(value, "b v (h d) -> b v h d", h=self.num_heads)
+
+         if self.use_rotary_emb:
+             if use_cache and past_memory_provided:
+                 q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
+             if self.rotary_emb.use_xpos:
+                 q, k = self.rotary_emb.rotate_queries_and_keys(q, k)
+             else:
+                 q = self.rotary_emb.rotate_queries_or_keys(q)
+                 k = self.rotary_emb.rotate_queries_or_keys(k)
+
+         if (
+             self.use_flash_attn
+             and not use_cache
+             and not output_attentions
+             and attention_bias is None
+         ):
+             qkv = torch.stack([q, k, v], dim=2).to(torch.bfloat16)
+             x = self.flashattn_fn(
+                 qkv=qkv,
+                 dropout_p=self.dropout_rate if self.training else 0.0,
+                 causal=is_causal,
+                 deterministic=False,
+                 return_attn_probs=False,
+             )
+
+             x = x.to(x_q.dtype)
+         elif self.use_flash_attn and not output_attentions:
+             attn_bias = xops.LowerTriangularMask() if is_causal else attention_bias
+
+             if attention_mask is not None:
+                 if attn_bias is None:
+                     attn_bias = attention_mask
+                 else:
+                     if isinstance(attn_bias, torch.Tensor):
+                         attn_bias = attn_bias + attention_mask
+                     else:
+                         attn_bias.add_bias(bias=attention_mask)
+
+                         attn_bias = attn_bias.materialize(
+                             shape=(q_len, k_len),
+                             device=q.device,
+                             dtype=q.dtype,
+                         )
+             else:
+                 if isinstance(attn_bias, torch.Tensor) and len(attn_bias.shape) == 3:
+                     attn_bias = (
+                         attn_bias.unsqueeze(1)
+                         .expand(-1, self.num_heads, -1, -1)
+                         .float()
+                     )  # (batch_size, num_heads, q_len, k_len)
+                 else:
+                     attn_bias = attn_bias.materialize(
+                         shape=(q_len, k_len),
+                         device=q.device,
+                         dtype=q.dtype,
+                     )
+
+             if isinstance(attn_bias, xops.LowerTriangularMask):
+                 attn_bias = attn_bias.materialize(
+                     shape=(q_len, k_len),
+                     device=q.device,
+                     dtype=q.dtype,
+                 )
+
+             # print(attention_mask.shape, attn_bias.shape)
+             # print(attn_bias[0, 0, 0, :])
+
+             need_adjustment = False
569
+ if attn_bias.shape[-2] % 8 != 0:
570
+ nearest_multiple_q = 8 * (1 + attn_bias.shape[-2] // 8)
571
+ need_adjustment = True
572
+ else:
573
+ nearest_multiple_q = attn_bias.shape[-2]
574
+
575
+ if attn_bias.shape[-1] % 8 != 0:
576
+ nearest_multiple_k = 8 * (1 + attn_bias.shape[-1] // 8)
577
+ need_adjustment = True
578
+ else:
579
+ nearest_multiple_k = attn_bias.shape[-1]
580
+
581
+ if need_adjustment:
582
+ new_attn_bias = torch.zeros(
583
+ attn_bias.shape[0],
584
+ attn_bias.shape[1],
585
+ nearest_multiple_q,
586
+ nearest_multiple_k,
587
+ ).to(attn_bias.device)
588
+ new_attn_bias[:, :, : attn_bias.shape[-2], : attn_bias.shape[-1]] = (
589
+ attn_bias
590
+ )
591
+
592
+ x = xops.memory_efficient_attention(
593
+ query=q,
594
+ key=k,
595
+ value=v,
596
+ op=None,
597
+ attn_bias=new_attn_bias[:, :, :q_len, :k_len],
598
+ p=self.dr_rate,
599
+ )
600
+ else:
601
+ attn_bias = attn_bias.to(q.dtype)
602
+ attn_bias = attn_bias.repeat(1, self.num_heads, 1, 1)
603
+ x = xops.memory_efficient_attention(
604
+ query=q,
605
+ key=k,
606
+ value=v,
607
+ op=None,
608
+ attn_bias=attn_bias,
609
+ p=self.dr_rate,
610
+ )
611
+ # x: (batch_size, query_seq_len, n_head, head_dim)
612
+ else:
613
+ # if output_attentions:
614
+ attention_scores = einsum(q, k, "b q h d, b k h d -> b h q k")
615
+ attention_scores = attention_scores / (q.size(-1) ** 0.5)
616
+
617
+ if attention_bias is not None:
618
+ attn_bias = attention_bias.unsqueeze(1).expand(
619
+ -1, self.num_heads, -1, -1
620
+ )
621
+ # elif is_causal:
622
+ # attn_bias = xops.LowerTriangularMask().materialize(
623
+ # shape=attention_scores.shape, device=attention_scores.device
624
+ # )
625
+ else:
626
+ attn_bias = None
627
+
628
+ if attention_mask is not None:
629
+ if attn_bias is None:
630
+ attn_bias = attention_mask
631
+ else:
632
+ attn_bias = attn_bias + attention_mask
633
+
634
+ attention_scores = attention_scores + attn_bias
635
+
636
+ attention_probs = attention_scores.softmax(dim=-1)
637
+ attention_probs = self.dropout(attention_probs)
638
+
639
+ x = einsum(attention_probs, v, "b h q k, b v h d -> b q h d")
640
+
641
+ x = rearrange(x, "b q h d -> b q (h d)", h=self.num_heads)
642
+
643
+ if use_cache:
644
+ if output_attentions:
645
+ return x, attention_probs, cache
646
+ else:
647
+ return x, None, cache
648
+ else:
649
+ if output_attentions:
650
+ return x, attention_probs
651
+ else:
652
+ return x, None
653
+
654
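The xformers branch above rounds both trailing attention-bias dimensions up to a multiple of 8 before allocating the padded bias, then slices back to `[:q_len, :k_len]`. A minimal standalone sketch of that rounding rule (pure Python, no torch dependency):

```python
def nearest_multiple_of_8(n: int) -> int:
    """Round n up to the next multiple of 8; n itself when already aligned."""
    if n % 8 != 0:
        return 8 * (1 + n // 8)  # same arithmetic as in the attention code
    return n
```

Because the padded rows and columns are sliced away again before the kernel call, they never influence the attention scores.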
class EnCodonPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface
    for downloading and loading pretrained models.
    """

    base_model_prefix = "encodon"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """MAGNETO-style weight initialization."""
        if isinstance(module, nn.Linear):
            # gain should be 1 for query and key weights
            is_qk = False
            for n, p in module.named_parameters():
                if "query" in n or "key" in n:
                    is_qk = True
                    break
            if is_qk:
                nn.init.xavier_normal_(module.weight, gain=1.0)
            else:
                nn.init.xavier_normal_(module.weight, gain=self.config.gamma_init)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, EnCodonStack):
            module.gradient_checkpointing = value

class EnCodonEmbeddings(nn.Module):
    """
    EnCodon embeddings module, containing word, token-type, and (absolute)
    positional embeddings.

    NOTE: This module is adapted from the original HuggingFace implementation.
    NOTE: Absolute positional embeddings are mutually exclusive with rotary
    embeddings.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
        )
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size
        )
        self.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size
        )

        # self.LayerNorm is not snake-cased to stick with the TensorFlow model
        # variable name and be able to load any TensorFlow checkpoint file.
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.position_embedding_type = getattr(
            config, "position_embedding_type", "absolute"
        )
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).expand((1, -1)),
            persistent=False,
        )
        self.register_buffer(
            "token_type_ids",
            torch.zeros(self.position_ids.size(), dtype=torch.long),
            persistent=False,
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[
                :, past_key_values_length : seq_length + past_key_values_length
            ]

        # Default token_type_ids to the all-zeros buffer registered in the
        # constructor. The registered buffer lets users trace the model without
        # passing token_type_ids (see HuggingFace issue #5664).
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(
                    input_shape[0], seq_length
                )
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(
                    input_shape, dtype=torch.long, device=self.position_ids.device
                )

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

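EnCodonEmbeddings combines its three lookups by element-wise addition before LayerNorm and dropout. A toy sketch of that sum for a single position, with illustrative 2-dimensional vectors standing in for real embedding rows:

```python
word = [0.10, 0.20]        # word_embeddings[input_id]
token_type = [0.00, 0.10]  # token_type_embeddings[token_type_id]
position = [0.05, 0.00]    # position_embeddings[position_id] ("absolute" mode)

# Element-wise sum, mirroring inputs_embeds + token_type_embeddings (+ position_embeddings).
embedding = [w + t + p for w, t, p in zip(word, token_type, position)]
```

When `position_embedding_type` is not `"absolute"` (e.g. when rotary embeddings are used), the positional term is simply omitted from this sum.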
class EnCodonAttention(nn.Module):
    """
    EnCodon attention module. This module supports two types of attention:
    (1) self-attention and (2) dilated attention, as described in the
    Transformer and LongNet papers, respectively.
    """

    def __init__(self, config, layer_idx=0):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
            config, "embedding_size"
        ):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.layer_idx = layer_idx
        self.pre_layer_norm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )
        self.post_attn_dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.post_layer_norm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.self_attention = MultiHeadedSelfAttention(
            q_input_dim=config.hidden_size,
            kv_input_dim=config.hidden_size,
            qk_proj_dim=config.hidden_size,
            v_proj_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_probs_dropout_prob,
            projection_layer="linear",
            use_flash_attn=config.use_flash_attn,
            use_rotary_emb=config.use_rotary_emb,
            rotary_theta=config.rotary_theta,
            rotary_use_xpos=False,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        attention_bias: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        attn_input = self.pre_layer_norm(hidden_states)
        attn_outputs = self.self_attention(
            attn_input,
            attn_input,
            is_causal=False,
            attention_bias=attention_bias,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )

        attn_output = attn_outputs[0]
        attn_output = self.post_layer_norm(attn_output)
        attn_output = self.post_attn_dense(attn_output)
        attn_output = self.dropout(attn_output)
        attn_output = hidden_states + attn_output
        return (attn_output,) + attn_outputs[1:]  # add attentions if we output them

class EnCodonFFN(nn.Module):
    """
    EnCodon position-wise feed-forward network module.
    """

    def __init__(self, config):
        super().__init__()
        self.pre_layer_norm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )
        self.intermediate_dense = nn.Linear(
            config.hidden_size, config.intermediate_size
        )
        self.post_layer_norm = nn.LayerNorm(
            config.intermediate_size, eps=config.layer_norm_eps
        )
        self.post_dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, input_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.pre_layer_norm(input_states)
        hidden_states = self.intermediate_dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        hidden_states = self.post_layer_norm(hidden_states)
        hidden_states = self.post_dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states + input_states

class EnCodonLayer(nn.Module):
    """
    EnCodon encoder layer module.

    This module contains an attention layer followed by a position-wise
    feed-forward layer.
    """

    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = EnCodonAttention(config)
        self.output = EnCodonFFN(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        attention_bias: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        self_attention_outputs = self.attention(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            attention_bias=attention_bias,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]

        # add self attentions if we output attention weights
        outputs = self_attention_outputs[1:]

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk,
            self.chunk_size_feed_forward,
            self.seq_len_dim,
            attention_output,
        )
        outputs = (layer_output,) + outputs

        return outputs

    def feed_forward_chunk(self, attention_output):
        layer_output = self.output(attention_output)
        return layer_output

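`apply_chunking_to_forward` (from `transformers`) trades memory for compute by running the feed-forward chunk over slices of the sequence dimension instead of the whole sequence at once. A simplified pure-Python sketch of the idea (not the HuggingFace implementation):

```python
def chunked_forward(fn, chunk_size, xs):
    """Apply fn to chunk_size-sized slices of xs and concatenate the results.

    Because fn acts position-wise, the result matches fn(xs) exactly;
    chunk_size == 0 means "process everything in one pass".
    """
    if chunk_size == 0:
        return fn(xs)
    out = []
    for start in range(0, len(xs), chunk_size):
        out.extend(fn(xs[start:start + chunk_size]))
    return out
```

This is why `chunk_size_feed_forward` only affects peak memory, never the numerical output of the layer.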
@dataclass
class BERTEncoderOutput(ModelOutput):
    last_hidden_state: torch.Tensor
    hidden_states: Optional[Tuple[torch.Tensor]] = None
    attentions: Optional[Tuple[torch.Tensor]] = None

class EnCodonStack(nn.Module):
    """
    EnCodon stack module, containing multiple EnCodon layers.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [EnCodonLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        attention_bias: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                attention_bias=attention_bias,
                output_attentions=output_attentions,
            )

            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    all_hidden_states,
                    all_self_attentions,
                ]
                if v is not None
            )
        return BERTEncoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

class BERTPooler(nn.Module):
    """
    BERT pooler module. This module pools the desired token from the hidden
    states, which is typically used for sequence-level classification or
    regression tasks.
    """

    def __init__(self, config, pooled_token_position=0):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh() if config.pooler_activation == "tanh" else nn.ReLU()
        self.pooled_token_position = pooled_token_position

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # We "pool" the model by simply taking the hidden state corresponding
        # to the configured token position (the first token by default).
        pooled_token_tensor = hidden_states[:, self.pooled_token_position]
        pooled_output = self.dense(pooled_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

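The pooler's slicing step reduces to indexing one position out of a (batch, seq, hidden) tensor before the dense projection and activation. A sketch with nested lists standing in for the tensor (toy values):

```python
# (batch=2, seq=3, hidden=2) stand-in for hidden_states.
hidden_states = [
    [[0.1, 0.2], [0.9, 0.8], [0.5, 0.5]],  # sequence for sample 0
    [[0.3, 0.4], [0.7, 0.6], [0.2, 0.1]],  # sequence for sample 1
]

pooled_token_position = 0  # the default: the first ([CLS]-style) token
pooled = [sequence[pooled_token_position] for sequence in hidden_states]
```

Allowing `pooled_token_position` to vary means any token (not just the first) can serve as the sequence-level summary.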
class BERTPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class BERTLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transform = BERTPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Link the two variables so that the bias is correctly resized with
        # `resize_token_embeddings`.
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states

@dataclass
class BERTModelOutput(ModelOutput):
    last_hidden_state: torch.Tensor
    pooled_output: Optional[torch.Tensor] = None
    hidden_states: Optional[Tuple[torch.Tensor]] = None
    attentions: Optional[Tuple[torch.Tensor]] = None

class EnCodonModule(EnCodonPreTrainedModel):
    """
    EnCodon module.

    Parameters
    ----------
    config : EnCodonConfig
        Configuration class for the EnCodon model.
    add_pooling_layer : bool (default: True)
        Whether to add a pooling layer to the model.
    pooled_token_position : int (default: 0)
        The position of the token to be pooled from the hidden states.
    """

    def __init__(self, config, add_pooling_layer=True, pooled_token_position=0):
        super().__init__(config)
        self.config = config

        self.embeddings = EnCodonEmbeddings(config)
        self.encoder = EnCodonStack(config)

        self.pooler = (
            BERTPooler(config, pooled_token_position=pooled_token_position)
            if add_pooling_layer
            else None
        )

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        attention_bias: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ):
        """
        Forward pass for the BERT model.

        Parameters
        ----------
        input_ids : Optional[torch.Tensor]
            The input IDs for the model. Expected shape: (batch_size, seq_length)
        attention_mask : Optional[torch.Tensor]
            The attention mask for the model. Expected shape: (batch_size, seq_length)
            - 1 for tokens that are NOT MASKED
            - 0 for tokens that are MASKED
        attention_bias : Optional[torch.Tensor]
            An additive bias applied to the attention scores.
        token_type_ids : Optional[torch.Tensor]
            The token type IDs for the model. Expected shape: (batch_size, seq_length)
        position_ids : Optional[torch.Tensor]
            The position IDs for the model. Expected shape: (batch_size, seq_length)
        inputs_embeds : Optional[torch.Tensor]
            The input embeddings for the model. Expected shape: (batch_size, seq_length, hidden_size)
        output_attentions : Optional[bool]
            Whether to output attentions or not.
        output_hidden_states : Optional[bool]
            Whether to output hidden states or not.
        return_dict : Optional[bool]
            Whether to return a dictionary or not.

        Returns
        -------
        BERTModelOutput
            The output of the BERT model.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length), device=device)

        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(
                    batch_size, seq_length
                )
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(
                    input_shape, dtype=torch.long, device=device
                )

        # We can provide a self-attention mask of dimensions
        # [batch_size, from_seq_length, to_seq_length] ourselves, in which case
        # we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
            attention_mask, input_shape
        )

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            attention_bias=attention_bias,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = encoder_outputs[0]
        pooled_output = (
            self.pooler(sequence_output) if self.pooler is not None else None
        )

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BERTModelOutput(
            last_hidden_state=sequence_output,
            pooled_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

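`get_extended_attention_mask` (inherited from `PreTrainedModel`) converts the (batch, seq_len) keep/pad mask into an additive bias that broadcasts over heads: 0 where attention is allowed and a very large negative value where it is not, so the softmax drives masked scores to ~0. A sketch of the per-position conversion, using -1e9 as a stand-in for the dtype minimum that the real implementation uses:

```python
NEG_INF = -1.0e9  # stand-in for torch.finfo(dtype).min

def extend_mask_row(mask_row):
    """Map keep/pad flags (1/0) to additive attention biases (0 / large negative)."""
    return [0.0 if keep == 1 else NEG_INF for keep in mask_row]
```

Adding this bias to the raw attention scores (rather than multiplying probabilities) is what lets the same masking mechanism coexist with user-supplied `attention_bias` tensors.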
@dataclass
class EnCodonOutput(MaskedLMOutput):
    """
    Base class for EnCodon outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Masked language modeling (MLM) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

class EnCodon(EnCodonPreTrainedModel):
    config_class = EnCodonConfig

    def __init__(self, config):
        super().__init__(config)

        self.bert = EnCodonModule(config)
        if self.config.lm_type == "bert":
            self.cls = BERTLMPredictionHead(config)
        else:
            self.cls = nn.Sequential(
                nn.Linear(config.hidden_size, config.hidden_size),
                nn.GELU(),
                nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps),
                nn.Linear(config.hidden_size, config.vocab_size),
            )

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        special_tokens_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        nsp_labels: Optional[torch.Tensor] = None,
        div_labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        return_pooled_output: Optional[bool] = False,
        **kwargs,
    ):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see the `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()  # -100 index = padding token
            loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
            )

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return EnCodonOutput(
            loss=loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def get_codon_embeddings(self) -> Module:
        return self.bert.embeddings.word_embeddings

    def freeze_bert(self, layer_indices: Optional[list] = None):
        if isinstance(layer_indices, int):
            layer_indices = [layer_indices]

        if layer_indices is None or len(layer_indices) == 0:
            for param in self.bert.parameters():
                param.requires_grad = False
        else:
            for param in self.bert.embeddings.parameters():
                param.requires_grad = False

            # Negative indices wrap around, as in regular Python indexing.
            layer_indices = [i % len(self.bert.encoder.layer) for i in layer_indices]

            for i in range(len(self.bert.encoder.layer)):
                if i not in layer_indices:
                    for param in self.bert.encoder.layer[i].parameters():
                        param.requires_grad = False

            for param in self.bert.pooler.parameters():
                param.requires_grad = False

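`nn.CrossEntropyLoss` defaults to `ignore_index=-100`, which is why the `labels` docstring above marks `-100` tokens as ignored: only positions whose label is a real vocabulary index contribute to the MLM loss. A sketch of that position selection:

```python
IGNORE_INDEX = -100  # CrossEntropyLoss's default ignore_index

labels = [-100, 7, -100, 42]  # toy labels for a 4-token sequence
scored_positions = [i for i, y in enumerate(labels) if y != IGNORE_INDEX]
```

In MLM training, only the ~15% of positions that were masked out carry real label indices; everything else is set to `-100` so it drops out of the loss.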
+ @dataclass
1358
+ class EnCodonForDMSOutput(ModelOutput):
1359
+ loss: Optional[torch.FloatTensor] = None
1360
+ logits: torch.FloatTensor = None
1361
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
1362
+
1363
+
1364
+ class EnCodonForDMS(EnCodonPreTrainedModel):
1365
+ config_class = EnCodonForDMSConfig
1366
+ _tied_weights_keys = ["cls.layer.3.weight", "cls.layer.3.bias"]
1367
+
1368
+ def __init__(self, config):
1369
+ super().__init__(config)
1370
+
1371
+ self.bert = EnCodonModule(config)
1372
+ if self.config.lm_type == "bert":
1373
+ self.cls = BERTLMPredictionHead(config)
1374
+ else:
1375
+ self.cls = nn.Sequential(
1376
+ nn.Linear(config.hidden_size, config.hidden_size),
1377
+ nn.GELU(),
1378
+ nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps),
1379
+ nn.Linear(config.hidden_size, config.vocab_size),
1380
+ )
1381
+
1382
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         alt_ids: Optional[torch.Tensor] = None,
+         var_positions: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         target: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         return_pooled_output: Optional[bool] = False,
+         return_all_logits: Optional[bool] = False,
+         **kwargs,
+     ):
+         r"""
+         target (`torch.FloatTensor`, *optional*):
+             Regression targets for the DMS scores. Entries equal to `-500.0` are
+             treated as missing and are excluded from the loss.
+         """
+
1406
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         outputs = self.bert(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         sequence_output = outputs[0]
+         prediction_scores = self.cls(
+             sequence_output
+         )  # (batch_size, seq_len, vocab_size)
+
+         # Select the alt-codon logits at the variant positions.
+         # alt_ids: (batch_size,)
+         # var_positions: (batch_size,)
+
+         if return_all_logits:
+             return EnCodonForDMSOutput(
+                 loss=None,
+                 logits=prediction_scores,
+                 attentions=outputs.attentions,
+             )
+         bs, seq_len, vocab_size = prediction_scores.shape
+
+         loss = None
+         if var_positions is None and alt_ids is None:
+             # No variant specified: score every input token against its own logit.
+             alt_prediction_scores = prediction_scores.gather(
+                 2, input_ids.unsqueeze(-1)
+             ).squeeze(2)  # (batch_size, seq_len)
+
+             if target is not None:
+                 expanded_target = target
+
+                 if self.config.loss_fn == "mse":
+                     loss_fct = nn.MSELoss()
+                     loss = loss_fct(alt_prediction_scores, expanded_target)
+                 elif self.config.loss_fn == "mae":
+                     loss_fct = nn.L1Loss()
+                     loss = loss_fct(alt_prediction_scores, expanded_target)
+                 elif self.config.loss_fn == "huber":
+                     loss_fct = nn.SmoothL1Loss()
+                     loss = loss_fct(alt_prediction_scores, expanded_target)
+                 else:
+                     raise ValueError(f"Invalid loss_fn: {self.config.loss_fn}.")
+
+             alt_prediction_scores = alt_prediction_scores.mean(dim=1)  # (batch_size,)
+         else:
+             alt_prediction_scores = prediction_scores[
+                 torch.arange(bs), var_positions, alt_ids
+             ]  # (batch_size,)
+
+             if target is not None:
+                 # Exclude entries whose target is the -500.0 "missing" sentinel.
+                 mask = target != -500.0
+
+                 target = target[mask]
+                 alt_prediction_scores = alt_prediction_scores[mask]
+
+                 if self.config.loss_fn == "mse":
+                     loss_fct = nn.MSELoss()
+                     loss = loss_fct(alt_prediction_scores, target)
+                 elif self.config.loss_fn == "mae":
+                     loss_fct = nn.L1Loss()
+                     loss = loss_fct(alt_prediction_scores, target)
+                 elif self.config.loss_fn == "huber":
+                     loss_fct = nn.SmoothL1Loss()
+                     loss = loss_fct(alt_prediction_scores, target)
+                 else:
+                     raise ValueError(f"Invalid loss_fn: {self.config.loss_fn}.")
+
+         if not return_dict:
+             output = (alt_prediction_scores,) + outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+
+         return EnCodonForDMSOutput(
+             loss=loss,
+             logits=alt_prediction_scores,
+             attentions=outputs.attentions,
+         )
+
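The DMS head above drops any target equal to the `-500.0` sentinel before computing the regression loss. A minimal, torch-free sketch of that sentinel masking (the helper names here are hypothetical, not part of the model):

```python
SENTINEL = -500.0  # missing-value marker used by the model's target tensors


def masked_mse(predictions, targets):
    """Mean squared error over pairs whose target is not the sentinel."""
    pairs = [(p, t) for p, t in zip(predictions, targets) if t != SENTINEL]
    if not pairs:  # every target was missing; nothing to average
        return 0.0
    return sum((p - t) ** 2 for p, t in pairs) / len(pairs)


print(masked_mse([1.0, 2.0, 3.0], [1.0, -500.0, 5.0]))  # 2.0
```

The middle pair is skipped, so the loss averages only the two observed targets, just as `loss_fct(alt_prediction_scores[mask], target[mask])` does above.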
1494
+     def get_codon_embeddings(self) -> Module:
+         return self.bert.embeddings.word_embeddings
+
+     def freeze_bert(self, layers_idx: Optional[list] = None):
+         # Wrap a bare int before testing for emptiness, so len() is never
+         # called on an integer.
+         if isinstance(layers_idx, int):
+             layers_idx = [layers_idx]
+
+         if layers_idx is None or len(layers_idx) == 0:
+             for param in self.bert.parameters():
+                 param.requires_grad = False
+         else:
+             for param in self.bert.embeddings.parameters():
+                 param.requires_grad = False
+
+             # Normalize indices so negative values select layers from the end.
+             layers_idx = [i % len(self.bert.encoder.layer) for i in layers_idx]
+
+             for i in range(len(self.bert.encoder.layer)):
+                 if i not in layers_idx:
+                     for param in self.bert.encoder.layer[i].parameters():
+                         param.requires_grad = False
+
+         for param in self.bert.pooler.parameters():
+             param.requires_grad = False
+
+
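The `freeze_bert` methods normalize layer indices with `i % num_layers`, so negative indices select layers from the end, and `None` or an empty list freezes the whole encoder. A standalone sketch of just that index logic (a hypothetical helper, independent of torch):

```python
def trainable_layer_indices(num_layers, layers_idx=None):
    """Mirror freeze_bert's index handling: None/empty freezes every layer,
    a bare int is wrapped in a list, and negative indices wrap via modulo."""
    if isinstance(layers_idx, int):
        layers_idx = [layers_idx]
    if not layers_idx:  # None or [] -> nothing stays trainable
        return set()
    return {i % num_layers for i in layers_idx}


# Keep only the last two of 12 encoder layers trainable.
print(sorted(trainable_layer_indices(12, [-1, -2])))  # [10, 11]
```

This makes `freeze_bert(layers_idx=[-1])` a convenient way to fine-tune only the top encoder layer.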
1520
+ class EnCodonForSequenceTask(EnCodonPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.config = config
+
+         self.bert = EnCodonModule(config)
+
+         if config.cls_type.lower() == "cls":
+             self.classifier = nn.Linear(
+                 config.hidden_size, config.num_labels * config.num_tasks
+             )
+         else:
+             raise ValueError(f"Invalid cls_type: {config.cls_type}.")
+
+         self.init_weights()
+
1536
+     def freeze_bert(self, layers_idx: Optional[list] = None):
+         # Wrap a bare int before testing for emptiness, so len() is never
+         # called on an integer.
+         if isinstance(layers_idx, int):
+             layers_idx = [layers_idx]
+
+         if layers_idx is None or len(layers_idx) == 0:
+             for param in self.bert.parameters():
+                 param.requires_grad = False
+         else:
+             for param in self.bert.embeddings.parameters():
+                 param.requires_grad = False
+
+             # Normalize indices so negative values select layers from the end.
+             layers_idx = [i % len(self.bert.encoder.layer) for i in layers_idx]
+
+             for i in range(len(self.bert.encoder.layer)):
+                 if i not in layers_idx:
+                     for param in self.bert.encoder.layer[i].parameters():
+                         param.requires_grad = False
+
+         if self.config.cls_type.lower() != "cls":
+             for param in self.bert.pooler.parameters():
+                 param.requires_grad = False
+
1558
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         target: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         **kwargs,
+     ):
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         outputs = self.bert(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=True,
+             return_dict=return_dict,
+         )
+
+         all_hidden_states = outputs[2]
+
+         if self.config.cls_type.lower() not in ["crossattention", "ca", "cls"]:
+             logits, _ = self.classifier(all_hidden_states, attention_mask)
+             ca = None
+         elif self.config.cls_type.lower() in ["crossattention", "ca"]:
+             bs, seq_len = input_ids.shape
+
+             query_tasks = self.task_embeddings.weight  # (num_tasks, hidden_size)
+             query_tasks = query_tasks.unsqueeze(0).expand(
+                 bs, -1, -1
+             )  # (batch_size, num_tasks, hidden_size)
+
+             cls_outputs = self.classifier(
+                 query_tasks,
+                 all_hidden_states,
+                 attention_mask,
+                 output_attentions=output_attentions,
+             )  # (batch_size, num_tasks, num_labels)
+
+             logits, ca = cls_outputs
+
+             logits = logits.squeeze()
+         elif self.config.cls_type.lower() == "cls":
+             pooled_output = outputs[1]
+             logits = self.classifier(pooled_output)
+             ca = None
+
+         loss = None
+         if target is not None:
+             if self.config.problem_type == "regression":
+                 if self.config.loss_fn == "mse":
+                     loss_fct = nn.MSELoss()
+                 elif self.config.loss_fn == "mae":
+                     loss_fct = nn.L1Loss()
+                 elif self.config.loss_fn == "huber":
+                     loss_fct = nn.SmoothL1Loss()
+                 else:
+                     raise ValueError(f"Invalid loss_fn: {self.config.loss_fn}.")
+
+                 logits = logits.view(-1, self.config.num_labels * self.config.num_tasks)
+                 target = target.view(-1, self.config.num_labels * self.config.num_tasks)
+
+                 # Exclude the -500.0 "missing" sentinel from the regression loss.
+                 mask = target != -500.0
+
+                 loss = loss_fct(logits[mask], target[mask])
+             else:
+                 loss_fct = nn.CrossEntropyLoss()
+
+                 logits = logits.view(-1, self.config.num_labels * self.config.num_tasks)
+                 target = target.view(-1)
+
+                 loss = loss_fct(logits, target)
+
+         if not return_dict:
+             output = (logits,) + outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+
+         if output_attentions:
+             if ca is not None:
+                 attentions = outputs.attentions + [ca]
+             else:
+                 attentions = outputs.attentions
+         else:
+             attentions = None
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=attentions,
+         )
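Both heads dispatch on `config.loss_fn` between MSE, MAE, and Huber. As a reference for the `"huber"` option, `nn.SmoothL1Loss` with its default `beta=1.0` computes, per element, a quadratic penalty for small errors and a linear one for large errors. A scalar sketch of that formula (a hypothetical standalone function, not the model's code):

```python
def smooth_l1(pred, target, beta=1.0):
    """Elementwise SmoothL1 (Huber) loss, matching nn.SmoothL1Loss defaults:
    quadratic for |pred - target| < beta, linear beyond that."""
    d = abs(pred - target)
    if d < beta:
        return 0.5 * d * d / beta
    return d - 0.5 * beta


print(smooth_l1(0.5, 0.0))  # 0.125  (quadratic region)
print(smooth_l1(3.0, 0.0))  # 2.5    (linear region)
```

The linear tail is what makes the `"huber"` choice less sensitive to outlier DMS targets than `"mse"`.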