mohsennp committed on
Commit 2743698 · verified · 1 Parent(s): deae159

Upload EnCodon

Files changed (5)
  1. README.md +199 -0
  2. config.json +39 -0
  3. configuration_encodon.py +140 -0
  4. model.safetensors +3 -0
  5. modeling_encodon.py +1659 -0
README.md ADDED
@@ -0,0 +1,199 @@
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for EnCodon
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed to the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
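Until the card is filled in, the sketch below shows one way to load this checkpoint. It is an assumption-laden example, not documented usage: the repository id is a placeholder, and it relies on the `auto_map` entries in the uploaded `config.json`, which route `AutoConfig`/`AutoModelForMaskedLM` to the custom `EnCodonConfig` and `EnCodon` classes (so `trust_remote_code=True` is required). No tokenizer is included in this commit, so codon tokenization is not shown.

```python
from transformers import AutoConfig, AutoModelForMaskedLM

# Placeholder repo id -- replace with the actual Hub repository for this model.
repo_id = "<namespace>/<encodon-repo>"

# auto_map in config.json points AutoConfig / AutoModelForMaskedLM at the custom
# EnCodonConfig / EnCodon classes shipped with this commit, hence trust_remote_code.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(repo_id, trust_remote_code=True)

print(config.hidden_size, config.num_hidden_layers)  # 2048, 12 for this checkpoint
model.eval()
```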
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,39 @@
1
+ {
2
+ "_name_or_path": "/large_experiments/goodarzilab/mohsen/biofm/saved_models/CodonBERT/CodonBERT_L2048_l12_a16_b16_r1e-05_mlm0.2_wd0.01_g2_euk_adapt-w8cn6xy0/checkpoint-150000",
3
+ "architectures": [
4
+ "EnCodon"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "attention_type": "self",
8
+ "auto_map": {
9
+ "AutoConfig": "configuration_encodon.EnCodonConfig",
10
+ "AutoModelForMaskedLM": "modeling_encodon.EnCodon"
11
+ },
12
+ "classifier_dropout": 0.1,
13
+ "dilation_rates": null,
14
+ "gamma_init": 1.782709687623856,
15
+ "hidden_act": "gelu",
16
+ "hidden_dropout_prob": 0.1,
17
+ "hidden_size": 2048,
18
+ "initializer_range": 0.02,
19
+ "intermediate_size": 8192,
20
+ "layer_norm_eps": 1e-12,
21
+ "lm_type": "distilled",
22
+ "max_position_embeddings": 2048,
23
+ "num_attention_heads": 16,
24
+ "num_divs": 0,
25
+ "num_hidden_layers": 12,
26
+ "pad_token_id": 3,
27
+ "pooler_activation": "tanh",
28
+ "position_embedding_type": "rotary",
29
+ "rotary_theta": 10000.0,
30
+ "segment_lengths": null,
31
+ "torch_dtype": "float32",
32
+ "transformers_version": "4.44.2",
33
+ "type_vocab_size": 2,
34
+ "use_cache": true,
35
+ "use_flash_attn": true,
36
+ "use_nsp": false,
37
+ "use_rotary_emb": true,
38
+ "vocab_size": 69
39
+ }
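A few quantities implied by this config are worth noting: each of the 16 attention heads works on 2048 / 16 = 128 dimensions, and `modeling_encodon.py` constructs its rotary embedding over half of that (64 dimensions). A small sanity-check sketch, assuming a local clone so `config.json` is on disk:

```python
import json

# Inspect the uploaded config and derive per-head sizes (local file path assumed).
with open("config.json") as f:
    cfg = json.load(f)

head_dim = cfg["hidden_size"] // cfg["num_attention_heads"]   # 2048 // 16 = 128
rotary_dim = head_dim // 2                                     # 64, per MultiHeadedSelfAttention
print(cfg["architectures"], head_dim, rotary_dim, cfg["max_position_embeddings"])
```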
configuration_encodon.py ADDED
@@ -0,0 +1,140 @@
1
+ from transformers import PretrainedConfig
2
+
3
+ class EnCodonConfig(PretrainedConfig):
4
+ def __init__(
5
+ self,
6
+ vocab_size=70,
7
+ hidden_size=768,
8
+ num_hidden_layers=12,
9
+ num_attention_heads=12,
10
+ intermediate_size=3072,
11
+ hidden_act="gelu",
12
+ hidden_dropout_prob=0.1,
13
+ attention_probs_dropout_prob=0.1,
14
+ max_position_embeddings=512,
15
+ type_vocab_size=2,
16
+ initializer_range=0.02,
17
+ layer_norm_eps=1e-12,
18
+ pad_token_id=0,
19
+ position_embedding_type="absolute",
20
+ use_cache=True,
21
+ classifier_dropout=0.1,
22
+ gamma_init=0.1,
23
+ use_rotary_emb=False,
24
+ rotary_theta=5e5,
25
+ use_flash_attn=False,
26
+ lm_type="bert",
27
+ **kwargs,
28
+ ):
29
+ super().__init__(
30
+ vocab_size=vocab_size,
31
+ hidden_size=hidden_size,
32
+ num_hidden_layers=num_hidden_layers,
33
+ num_attention_heads=num_attention_heads,
34
+ intermediate_size=intermediate_size,
35
+ hidden_act=hidden_act,
36
+ hidden_dropout_prob=hidden_dropout_prob,
37
+ attention_probs_dropout_prob=attention_probs_dropout_prob,
38
+ max_position_embeddings=max_position_embeddings,
39
+ type_vocab_size=type_vocab_size,
40
+ initializer_range=initializer_range,
41
+ layer_norm_eps=layer_norm_eps,
42
+ pad_token_id=pad_token_id,
43
+ position_embedding_type=position_embedding_type,
44
+ use_cache=use_cache,
45
+ classifier_dropout=classifier_dropout,
46
+ gamma_init=gamma_init,
47
+ use_rotary_emb=use_rotary_emb,
48
+ rotary_theta=rotary_theta,
49
+ use_flash_attn=use_flash_attn,
50
+ lm_type=lm_type,
51
+ **kwargs,
52
+ )
53
+
54
+
55
+ class EnCodonForDMSConfig(EnCodonConfig):
56
+ def __init__(
57
+ self,
58
+ loss_fn="huber",
59
+ num_labels=1,
60
+ task_name="NoName",
61
+ problem_type="regression",
62
+ **kwargs,
63
+ ):
64
+
65
+ if problem_type == "classification":
66
+ problem_type_ = "single_label_classification"
67
+ else:
68
+ problem_type_ = problem_type
69
+
70
+ super().__init__(
71
+ loss_fn=loss_fn,
72
+ task_name=task_name,
73
+ num_labels=num_labels,
74
+ problem_type=problem_type_,
75
+ **kwargs,
76
+ )
77
+
78
+ self.problem_type = problem_type
79
+
80
+
81
+ class EnCodonForSequenceTaskConfig(EnCodonConfig):
82
+ def __init__(
83
+ self,
84
+ task_name="NoName",
85
+ loss_fn="huber",
86
+ num_labels=2,
87
+ num_tasks=1,
88
+ cls_num_hidden_layers=1,
89
+ cls_hidden_size=128,
90
+ cls_dropout_prob=0.1,
91
+ cls_hidden_act="relu",
92
+ cls_type="mlp",
93
+ cls_num_attention_heads=8,
94
+ cls_use_rotary_emb=False,
95
+ cls_rotary_theta=1e4,
96
+ num_filters=128,
97
+ kernel_size=3,
98
+ stride=1,
99
+ dilation=1,
100
+ pooling_size=2,
101
+ pooling_type="max",
102
+ layer_indices=-1,
103
+ reduction="mean",
104
+ layer_reduction="none",
105
+ problem_type="classification",
106
+ **kwargs,
107
+ ):
108
+
109
+ if problem_type == "classification":
110
+ problem_type_ = "single_label_classification"
111
+ else:
112
+ problem_type_ = problem_type
113
+
114
+ super().__init__(
115
+ loss_fn=loss_fn,
116
+ task_name=task_name,
117
+ num_labels=num_labels,
118
+ num_tasks=num_tasks,
119
+ cls_num_hidden_layers=cls_num_hidden_layers,
120
+ cls_hidden_size=cls_hidden_size,
121
+ cls_dropout_prob=cls_dropout_prob,
122
+ cls_hidden_act=cls_hidden_act,
123
+ cls_num_attention_heads=cls_num_attention_heads,
124
+ cls_use_rotary_emb=cls_use_rotary_emb,
125
+ cls_rotary_theta=cls_rotary_theta,
126
+ cls_type=cls_type,
127
+ num_filters=num_filters,
128
+ kernel_size=kernel_size,
129
+ stride=stride,
130
+ dilation=dilation,
131
+ pooling_size=pooling_size,
132
+ pooling_type=pooling_type,
133
+ layer_indices=layer_indices,
134
+ reduction=reduction,
135
+ layer_reduction=layer_reduction,
136
+ problem_type=problem_type_,
137
+ **kwargs,
138
+ )
139
+
140
+ self.problem_type = problem_type
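Both task configs translate the shorthand `problem_type="classification"` into the canonical `"single_label_classification"` before calling `PretrainedConfig.__init__` (which only accepts the canonical problem types), then restore the shorthand on the instance. A minimal usage sketch, assuming `configuration_encodon.py` is importable from the working directory; the specific values are illustrative only:

```python
from configuration_encodon import EnCodonConfig, EnCodonForSequenceTaskConfig

# Base config mirroring a few of the uploaded config.json values (illustrative only).
base = EnCodonConfig(hidden_size=2048, num_attention_heads=16,
                     intermediate_size=8192, use_rotary_emb=True)
print(base.hidden_size, base.use_rotary_emb)        # 2048 True

# "classification" is mapped to "single_label_classification" for PretrainedConfig,
# then self.problem_type is reset to the shorthand afterwards.
task = EnCodonForSequenceTaskConfig(problem_type="classification", num_labels=3)
print(task.problem_type, task.num_labels)           # classification 3
```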
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aafbc4910ea63f31116be6ac5f6f14aa3e4909b1b92dfc17240153dadb2fa040
3
+ size 2469742556
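For reference, the LFS pointer above implies an approximate parameter count: at the `float32` dtype declared in `config.json`, 2,469,742,556 bytes corresponds to roughly 617M parameters, which lines up with a 12-layer, 2048-wide, 8192-FFN encoder once embeddings, LayerNorms, and the LM head are included. A rough back-of-the-envelope check:

```python
# Approximate parameter-count check for the uploaded float32 checkpoint.
size_bytes = 2_469_742_556
params_from_file = size_bytes / 4            # float32 = 4 bytes/param -> ~617M

h, i, layers = 2048, 8192, 12
per_layer = 4 * h * h + 2 * h * i            # attention projections + FFN; biases/LayerNorms ignored
approx = layers * per_layer                  # ~604M before embeddings and the LM head
print(round(params_from_file / 1e6), round(approx / 1e6))   # 617 604
```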
modeling_encodon.py ADDED
@@ -0,0 +1,1659 @@
1
+ from typing import Optional, Tuple
2
+ from dataclasses import dataclass
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn.modules import Module
6
+
7
+ from transformers.modeling_outputs import (
8
+ SequenceClassifierOutput,
9
+ ModelOutput,
10
+ MaskedLMOutput,
11
+ )
12
+
13
+ from transformers.activations import ACT2FN
14
+
15
+ from .configuration_encodon import (
16
+ EnCodonConfig,
17
+ EnCodonForDMSConfig,
18
+ )
19
+ from typing import Optional, Tuple
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from torch import nn
24
+
25
+ from dataclasses import dataclass
26
+ from transformers import (
27
+ apply_chunking_to_forward,
28
+ )
29
+ from transformers.activations import ACT2FN
30
+ from transformers.modeling_utils import PreTrainedModel
31
+ from transformers.utils import ModelOutput, logging
32
+ from transformers.modeling_outputs import MaskedLMOutput
33
+ import xformers.ops as xops
34
+
35
+ logger = logging.get_logger(__name__)
36
+
37
+ import torch
38
+ import torch.nn as nn
39
+ from einops import rearrange, einsum
40
+ from transformers.pytorch_utils import Conv1D
41
+
42
+ from math import pi
43
+
44
+ import torch
45
+ from torch.amp import autocast
46
+ from torch import nn, einsum, broadcast_tensors, Tensor
47
+
48
+ from einops import rearrange, repeat
49
+ from typing import Optional, Union, Literal
50
+
51
+
52
+ def rotate_half(x):
53
+ x = rearrange(x, "... (d r) -> ... d r", r=2)
54
+ x1, x2 = x.unbind(dim=-1)
55
+ x = torch.stack((-x2, x1), dim=-1)
56
+ return rearrange(x, "... d r -> ... (d r)")
57
+
58
+
59
+ @autocast(device_type="cuda", enabled=False)
60
+ def apply_rotary_emb(freqs, t, start_index=0, scale=1.0):
61
+ """
62
+ Applies rotary embeddings to a tensor.
63
+
64
+ Parameters
65
+ ----------
66
+ freqs : Tensor
67
+ The frequencies to apply to the tensor: (seq_len, dim)
68
+ t : Tensor
69
+ The tensor to apply the rotary embeddings to: (..., seq_len, n_heads, dim)
70
+ start_index : int
71
+ The starting index to apply the rotary embeddings. (default: 0)
72
+ scale : float
73
+ The scale to apply to the rotary embeddings. (default: 1.0)
74
+
75
+ Returns
76
+ -------
77
+ Tensor
78
+ The tensor with the rotary embeddings applied.: (..., seq_len, n_heads, dim)
79
+
80
+ """
81
+ # if t.ndim == 3:
82
+ # seq_len = t.shape[seq_dim]
83
+ # freqs = freqs[-seq_len:].to(t)
84
+
85
+ rot_dim = freqs.shape[-1]
86
+ end_index = start_index + rot_dim
87
+
88
+ assert (
89
+ rot_dim <= t.shape[-1]
90
+ ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
91
+
92
+ t_left, t, t_right = (
93
+ t[..., :start_index],
94
+ t[..., start_index:end_index],
95
+ t[..., end_index:],
96
+ )
97
+ if isinstance(scale, float):
98
+ scale = torch.tensor(scale, device=t.device, dtype=t.dtype)
99
+
100
+ t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
101
+ return torch.cat((t_left, t, t_right), dim=-1)
102
+
103
+
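To make the rotation concrete, here is a standalone sketch in plain PyTorch (it does not call the helpers above, but uses the same consecutive-pair layout as `rotate_half`). It checks the defining property of rotary embeddings: after rotation, the query–key dot product depends only on the relative offset between positions.

```python
import torch

def rope(x: torch.Tensor, pos: float, theta: float = 10000.0) -> torch.Tensor:
    """Rotate consecutive pairs (x0,x1), (x2,x3), ... by pos * freq_i."""
    dim = x.shape[-1]
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    angles = pos * freqs
    x_even, x_odd = x[0::2], x[1::2]
    out = torch.empty_like(x)
    out[0::2] = x_even * angles.cos() - x_odd * angles.sin()
    out[1::2] = x_even * angles.sin() + x_odd * angles.cos()
    return out

q, k = torch.randn(8), torch.randn(8)
# Both position pairs are 3 apart, so the rotated dot products agree.
a = torch.dot(rope(q, 5.0), rope(k, 2.0))
b = torch.dot(rope(q, 13.0), rope(k, 10.0))
print(torch.allclose(a, b, atol=1e-5))  # True
```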
104
+ # learned rotation helpers
105
+ def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
106
+ if freq_ranges is not None:
107
+ rotations = einsum("..., f -> ... f", rotations, freq_ranges)
108
+ rotations = rearrange(rotations, "... r f -> ... (r f)")
109
+
110
+ rotations = repeat(rotations, "... n -> ... (n r)", r=2)
111
+ return apply_rotary_emb(rotations, t, start_index=start_index)
112
+
113
+
114
+ """
115
+ Inspired by https://github.com/lucidrains/rotary-embedding-torch
116
+ """
117
+ class RotaryEmbedding(nn.Module):
118
+ """
119
+ Rotary Embeddings implementation inspired by https://github.com/lucidrains/rotary-embedding-torch.
120
+
121
+ Rotary Positional Embeddings (RoPE) encode position information of tokens with a
122
+ rotation matrix that naturally incorporates explicit relative position dependency.
123
+
124
+ Parameters
125
+ ----------
126
+ emb_dim : int
127
+ Embedding dimension. Usually set to the dim of each head in the attention module.
128
+ freqs : Optional[Tensor]
129
+ Custom frequencies to apply to query/key tensors. (default: None)
130
+ theta : float
131
+ Base constant used for computing rotation angles.
132
+ learned_freq : bool (default: False)
133
+ Whether to learn the frequencies.
134
+ use_xpos : bool (default: False)
135
+ Whether to employ XPos technique for resolving length extrapolation issue.
136
+ NOTE: This can only be enabled for autoregressive models like GPT.
137
+ xpos_scale_base : int (default: 512)
138
+ The base for the scale factor used in XPos technique.
139
+ interpolate_factor : float (default: 1.0)
140
+ Length interpolation factor for extending context length of the pretrained model.
141
+ Final model's context length = pretrained_model_context_length * interpolate_factor.
142
+
143
+ theta_rescale_factor : float (default: 1.0)
144
+ The factor to rescale the theta.
145
+
146
+ cache_if_possible : bool (default: True)
147
+ Whether to cache the frequencies/scales if possible.
148
+
149
+ """
150
+
151
+ def __init__(
152
+ self,
153
+ emb_dim,
154
+ freqs: Optional[Tensor] = None,
155
+ theta=1e4,
156
+ learned_freq=False,
157
+ use_xpos=False,
158
+ xpos_scale_base=512,
159
+ interpolate_factor=1.0,
160
+ theta_rescale_factor=1.0,
161
+ cache_if_possible=True,
162
+ ):
163
+ super().__init__()
164
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
165
+ # has some connection to NTK literature
166
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
167
+
168
+ theta *= theta_rescale_factor ** (emb_dim / (emb_dim - 2))
169
+
170
+ if freqs is None:
171
+ freqs = 1.0 / (
172
+ theta
173
+ ** (torch.arange(0, emb_dim, 2)[: (emb_dim // 2)].float() / emb_dim)
174
+ )
175
+ # freqs = torch.ones(num_freqs).float()
176
+
177
+ self.cache_if_possible = cache_if_possible
178
+
179
+ self.register_buffer("cached_freqs", None, persistent=False)
180
+ self.register_buffer("cached_scales", None, persistent=False)
181
+
182
+ self.freqs = nn.Parameter(freqs, requires_grad=learned_freq)
183
+
184
+ self.learned_freq = learned_freq
185
+
186
+ # interpolation factors
187
+
188
+ assert interpolate_factor >= 1.0
189
+ self.interpolate_factor = interpolate_factor
190
+
191
+ # xpos
192
+ self.use_xpos = use_xpos
193
+ if not use_xpos:
194
+ self.register_buffer("scale", None, persistent=False)
195
+ return
196
+
197
+ scale = (torch.arange(0, emb_dim, 2) + 0.4 * emb_dim) / (1.4 * emb_dim)
198
+ self.scale_base = xpos_scale_base
199
+ self.register_buffer("scale", scale, persistent=False)
200
+
201
+ @property
202
+ def device(self):
203
+ return self.freqs.device
204
+
205
+ def rotate_queries_or_keys(self, t, offset=0, freq_seq_len=None, scale=None):
206
+ """
207
+ Parameters
208
+ ----------
209
+ t : Tensor
210
+ tensor to rotate: (batch_size, seq_len, num_heads, head_dim)
211
+ """
212
+ seq_len = t.shape[1]
213
+ assert (
214
+ not self.use_xpos or scale is not None
215
+ ), "you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings"
216
+
217
+ if freq_seq_len is not None:
218
+ assert freq_seq_len >= seq_len
219
+ seq_len = freq_seq_len
220
+
221
+ seq = (
222
+ torch.arange(seq_len, device=t.device, dtype=t.dtype) + offset
223
+ ) / self.interpolate_factor
224
+
225
+ freqs = self.forward(
226
+ seq,
227
+ seq_len=seq_len,
228
+ offset=offset,
229
+ ).to(t.dtype)
230
+
231
+ freqs = rearrange(freqs, "n d -> n 1 d")
232
+
233
+ if scale is not None:
234
+ scale = rearrange(scale, "n d -> n 1 d")
235
+
236
+ if scale is None:
237
+ scale = torch.tensor(1.0, device=t.device, dtype=t.dtype)
238
+
239
+ return apply_rotary_emb(freqs, t, scale=scale)
240
+
241
+ def rotate_queries_and_keys(self, q, k):
242
+ """
243
+ Parameters
244
+ ----------
245
+ q : Tensor
246
+ queries tensor: (batch_size, seq_len, num_heads, head_dim)
247
+ k : Tensor
248
+ keys tensor: (batch_size, seq_len, num_heads, head_dim)
249
+ """
250
+ assert self.use_xpos
251
+ seq_len = q.shape[-3]
252
+
253
+ seq = (
254
+ torch.arange(seq_len, device=q.device, dtype=q.dtype)
255
+ ) / self.interpolate_factor
256
+
257
+ freqs = self.forward(seq, seq_len=seq_len)
258
+ scale = self.get_scale(seq, seq_len=seq_len)
259
+
260
+ freqs = rearrange(freqs, "n d -> n 1 d")
261
+ scale = rearrange(scale, "n d -> n 1 d")
262
+
263
+ rotated_q = apply_rotary_emb(freqs, q, scale=scale)
264
+ rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1)
265
+
266
+ rotated_q = rotated_q.type(q.dtype)
267
+ rotated_k = rotated_k.type(k.dtype)
268
+
269
+ return rotated_q, rotated_k
270
+
271
+ def get_scale(self, t: Tensor, seq_len: Optional[int] = None, offset=0):
272
+ assert self.use_xpos
273
+
274
+ should_cache = self.cache_if_possible and seq_len is not None
275
+
276
+ if (
277
+ should_cache
278
+ and self.cached_scales is not None
279
+ and (seq_len + offset) <= self.cached_scales.shape[0]
280
+ ):
281
+ return self.cached_scales[offset : (offset + seq_len)]
282
+
283
+ scale = 1.0
284
+ if self.use_xpos:
285
+ power = (t - len(t) // 2) / self.scale_base
286
+ scale = self.scale ** rearrange(power, "n -> n 1")
287
+ scale = torch.cat((scale, scale), dim=-1)
288
+
289
+ if should_cache:
290
+ self.register_buffer("cached_scales", scale, persistent=False)
291
+
292
+ return scale
293
+
294
+ def rotate_queries_with_cached_keys(self, q, k, offset=0):
295
+ q_len, k_len = q.shape[1], k.shape[1]
296
+ assert q_len <= k_len
297
+
298
+ rotated_q, rotated_k = self.rotate_queries_and_keys(q, k)
299
+
300
+ rotated_q = rotated_q[:, -1:, ...]
301
+
302
+ return rotated_q, rotated_k
303
+
304
+ seq = (
305
+ torch.arange(k_len, device=q.device, dtype=q.dtype)
306
+ ) / self.interpolate_factor
307
+
308
+ if self.use_xpos:
309
+ q_scale = self.get_scale(seq[-q_len:]).to(q.dtype)
310
+ k_scale = self.get_scale(seq).to(k.dtype)
311
+
312
+ else:
313
+ k_scale = 1.0
314
+ q_scale = 1.0
315
+
316
+ rotated_q = self.rotate_queries_or_keys(
317
+ q, scale=q_scale, offset=k_len - q_len + offset
318
+ )
319
+ rotated_k = self.rotate_queries_or_keys(k, scale=k_scale**-1)
320
+
321
+ return rotated_q, rotated_k
322
+
323
+ @autocast(device_type="cuda", enabled=False)
324
+ def forward(self, t: Tensor, seq_len=None, offset=0):
325
+ should_cache = (
326
+ self.cache_if_possible and not self.learned_freq and seq_len is not None
327
+ )
328
+
329
+ if (
330
+ should_cache
331
+ and self.cached_freqs is not None
332
+ and (offset + seq_len) <= self.cached_freqs.shape[0]
333
+ ):
334
+ return self.cached_freqs[offset : (offset + seq_len)].detach()
335
+
336
+ freqs = self.freqs
337
+
338
+ freqs = einsum("..., f -> ... f", t, freqs)
339
+ freqs = repeat(freqs, "... n -> ... (n r)", r=2)
340
+
341
+ if should_cache:
342
+ self.register_buffer("cached_freqs", freqs.detach(), persistent=False)
343
+
344
+ return freqs
345
+
346
+
347
+
348
+ class MultiHeadedSelfAttention(nn.Module):
349
+ """
350
+ Multi-Headed Self-Attention module with support for Flash Attention and Rotary Embeddings.
351
+
352
+ Parameters
353
+ ----------
354
+ q_input_dim: int
355
+ The input dimension of the query tensor.
356
+ kv_input_dim: int
357
+ The input dimension of the key and value tensors.
358
+ qk_proj_dim: int
359
+ The projected dimension of the query and key tensors.
360
+ v_proj_dim: int
361
+ The projected dimension of the value tensors.
362
+ num_heads: int
363
+ Number of attention heads.
364
+ dropout: float
365
+ Dropout rate to apply to the attention scores.
366
+ projection_layer: str
367
+ The type of projection layer to use. Either 'linear' or 'conv'.
368
+ Both are linear projections, but 'conv' uses the Conv1D layer from the original GPT-2 implementation.
369
+ use_flash_attn: bool
370
+ Whether to use Flash Attention or not. If True, Flash Attention will be used.
371
+ NOTE: Requires the flash-attn package to be installed.
372
+ use_rotary_emb: bool
373
+ Whether to use Rotary Embeddings or not.
374
+ rotary_theta: int
375
+ The base for the geometric progression used to compute the rotation angles.
376
+ rotary_use_xpos: bool
377
+ Whether to use XPos technique for resolving length extrapolation issue.
378
+ NOTE: This can only be enabled for autoregressive models like GPT.
379
+ """
380
+
381
+ def __init__(
382
+ self,
383
+ q_input_dim,
384
+ kv_input_dim,
385
+ qk_proj_dim,
386
+ v_proj_dim,
387
+ num_heads,
388
+ dropout: float = 0.0,
389
+ projection_layer: str = "linear",
390
+ use_flash_attn: bool = True,
391
+ use_rotary_emb: bool = False,
392
+ rotary_theta: int = 1e4,
393
+ rotary_use_xpos: bool = False,
394
+ is_cross_attention: bool = False,
395
+ **kwargs,
396
+ ):
397
+ super().__init__()
398
+ assert (
399
+ qk_proj_dim % num_heads == 0
400
+ ), "qk_proj_dim must be divisible by num_heads"
401
+ assert v_proj_dim % num_heads == 0, "v_proj_dim must be divisible by num_heads"
402
+
403
+ self.num_heads = num_heads
404
+ self.dropout_rate = dropout
405
+ self.projection_layer = projection_layer
406
+ self.use_rotary_emb = use_rotary_emb
407
+ self.is_cross_attention = is_cross_attention
408
+
409
+ if use_flash_attn and not is_cross_attention:
410
+ try:
411
+ from flash_attn import flash_attn_qkvpacked_func
412
+
413
+ self.use_flash_attn = True
414
+ self.flashattn_fn = flash_attn_qkvpacked_func
415
+ except ImportError:
416
+ print("flash_attn not installed, reverting to default attention")
417
+ self.use_flash_attn = False
418
+ self.flashattn_fn = None
419
+ else:
420
+ self.use_flash_attn = False
421
+ self.flashattn_fn = None
422
+
423
+ if self.projection_layer == "linear":
424
+ self.query = nn.Linear(q_input_dim, qk_proj_dim)
425
+ self.key = nn.Linear(kv_input_dim, qk_proj_dim)
426
+ self.value = nn.Linear(kv_input_dim, v_proj_dim)
427
+ elif self.projection_layer == "conv":
428
+ self.query = Conv1D(qk_proj_dim, q_input_dim)
429
+ self.key = Conv1D(qk_proj_dim, kv_input_dim)
430
+ self.value = Conv1D(v_proj_dim, kv_input_dim)
431
+ else:
432
+ raise ValueError(
433
+ f"projection_layer must be either 'linear' or 'conv', got {projection_layer}"
434
+ )
435
+
436
+ if self.use_rotary_emb:
437
+ self.rotary_emb = RotaryEmbedding(
438
+ emb_dim=qk_proj_dim // num_heads // 2,
439
+ theta=rotary_theta,
440
+ use_xpos=rotary_use_xpos,
441
+ )
442
+
443
+ self.dr_rate = dropout
444
+ self.dropout = nn.Dropout(dropout)
445
+
446
+ def forward(
447
+ self,
448
+ x_q,
449
+ x_kv,
450
+ is_causal=False,
451
+ attention_bias=None,
452
+ attention_mask=None,
453
+ output_attentions=False,
454
+ query=None,
455
+ key=None,
456
+ value=None,
457
+ use_cache=False,
458
+ ):
459
+ """
460
+ Applies a classical self attention operation.
461
+
462
+ Parameters
463
+ ----------
464
+ x_q: torch.Tensor
465
+ The query tensor of shape (batch_size, query_seq_len, emb_dim)
466
+ x_kv: torch.Tensor
467
+ The key/value tensor of shape (batch_size, kv_seq_len, emb_dim)
468
+ attention_bias: torch.Tensor
469
+ The attention bias to apply to the attention scores. (default: None)
470
+ attention_mask: torch.Tensor
471
+ The attention mask to apply to the attention scores. Shape: (batch_size, q_len, kv_seq_len)
472
+ """
473
+ assert (x_q is not None and x_kv is not None) or (
474
+ query is not None and key is not None and value is not None
475
+ ), "Either x_q and x_kv or query, key and value must be provided"
476
+
477
+ past_memory_provided = (
478
+ query is not None and key is not None and value is not None
479
+ )
480
+
481
+ if query is None:
482
+ q_len = x_q.size(1)
483
+ k_len = x_kv.size(1)
484
+
485
+ query = self.query(x_q)
486
+ key = self.key(x_kv)
487
+ value = self.value(x_kv)
488
+
489
+ else:
490
+ q_len = query.size(1)
491
+ k_len = key.size(1)
492
+
493
+ if use_cache:
494
+ cache = (key.clone(), value.clone(), query.clone())
495
+
496
+ q = rearrange(query, "b q (h d) -> b q h d", h=self.num_heads)
497
+ k = rearrange(key, "b k (h d) -> b k h d", h=self.num_heads)
498
+ v = rearrange(value, "b v (h d) -> b v h d", h=self.num_heads)
499
+
500
+ if self.use_rotary_emb:
501
+ if use_cache and past_memory_provided:
502
+ q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
503
+ if self.rotary_emb.use_xpos:
504
+ q, k = self.rotary_emb.rotate_queries_and_keys(q, k)
505
+ else:
506
+ q = self.rotary_emb.rotate_queries_or_keys(q)
507
+ k = self.rotary_emb.rotate_queries_or_keys(k)
508
+
509
+ if (
510
+ self.use_flash_attn
511
+ and not use_cache
512
+ and not output_attentions
513
+ and attention_bias is None
514
+ ):
515
+ qkv = torch.stack([q, k, v], dim=2).to(torch.bfloat16)
516
+ x = self.flashattn_fn(
517
+ qkv=qkv,
518
+ dropout_p=self.dropout_rate if self.training else 0.0,
519
+ causal=is_causal,
520
+ deterministic=False,
521
+ return_attn_probs=False,
522
+ )
523
+
524
+ x = x.to(x_q.dtype)
525
+ elif self.use_flash_attn and not output_attentions:
526
+ attn_bias = xops.LowerTriangularMask() if is_causal else attention_bias
527
+
528
+ if attention_mask is not None:
529
+ if attn_bias is None:
530
+ attn_bias = attention_mask
531
+ else:
532
+ if isinstance(attn_bias, torch.Tensor):
533
+ attn_bias = attn_bias + attention_mask
534
+ else:
535
+ attn_bias.add_bias(bias=attention_mask)
536
+
537
+ attn_bias = attn_bias.materialize(
538
+ shape=(q_len, k_len),
539
+ device=q.device,
540
+ dtype=q.dtype,
541
+ )
542
+ else:
543
+ if isinstance(attn_bias, torch.Tensor) and len(attn_bias.shape) == 3:
544
+ attn_bias = (
545
+ attn_bias.unsqueeze(1)
546
+ .expand(-1, self.num_heads, -1, -1)
547
+ .float()
548
+ ) # (batch_size, num_heads, q_len, k_len)
549
+ else:
550
+ attn_bias = attn_bias.materialize(
551
+ shape=(q_len, k_len),
552
+ device=q.device,
553
+ dtype=q.dtype,
554
+ )
555
+
556
+ if isinstance(attn_bias, xops.LowerTriangularMask):
557
+ attn_bias = attn_bias.materialize(
558
+ shape=(q_len, k_len),
559
+ device=q.device,
560
+ dtype=q.dtype,
561
+ )
562
+
563
+ # print(attention_mask.shape, attn_bias.shape)
564
+ # print(attn_bias[0, 0, 0, :])
565
+
566
+ need_adjustment = False
567
+ if attn_bias.shape[-2] % 8 != 0:
568
+ nearest_multiple_q = 8 * (1 + attn_bias.shape[-2] // 8)
569
+ need_adjustment = True
570
+ else:
571
+ nearest_multiple_q = attn_bias.shape[-2]
572
+
573
+ if attn_bias.shape[-1] % 8 != 0:
574
+ nearest_multiple_k = 8 * (1 + attn_bias.shape[-1] // 8)
575
+ need_adjustment = True
576
+ else:
577
+ nearest_multiple_k = attn_bias.shape[-1]
578
+
579
+ if need_adjustment:
580
+ new_attn_bias = torch.zeros(
581
+ attn_bias.shape[0],
582
+ attn_bias.shape[1],
583
+ nearest_multiple_q,
584
+ nearest_multiple_k,
585
+ ).to(attn_bias.device)
586
+ new_attn_bias[:, :, : attn_bias.shape[-2], : attn_bias.shape[-1]] = (
587
+ attn_bias
588
+ )
589
+
590
+ x = xops.memory_efficient_attention(
591
+ query=q,
592
+ key=k,
593
+ value=v,
594
+ op=None,
595
+ attn_bias=new_attn_bias[:, :, :q_len, :k_len],
596
+ p=self.dr_rate,
597
+ )
598
+ else:
599
+ attn_bias = attn_bias.to(q.dtype)
600
+ attn_bias = attn_bias.repeat(1, self.num_heads, 1, 1)
601
+ x = xops.memory_efficient_attention(
602
+ query=q,
603
+ key=k,
604
+ value=v,
605
+ op=None,
606
+ attn_bias=attn_bias,
607
+ p=self.dr_rate,
608
+ )
609
+ # x: (batch_size, query_seq_len, n_head, head_dim)
610
+ else:
611
+ # if output_attentions:
612
+ attention_scores = einsum(q, k, "b q h d, b k h d -> b h q k")
613
+ attention_scores = attention_scores / (q.size(-1) ** 0.5)
614
+
615
+ if attention_bias is not None:
616
+ attn_bias = attention_bias.unsqueeze(1).expand(
617
+ -1, self.num_heads, -1, -1
618
+ )
619
+ # elif is_causal:
620
+ # attn_bias = xops.LowerTriangularMask().materialize(
621
+ # shape=attention_scores.shape, device=attention_scores.device
622
+ # )
623
+ else:
624
+ attn_bias = None
625
+
626
+ if attention_mask is not None:
627
+ if attn_bias is None:
628
+ attn_bias = attention_mask
629
+ else:
630
+ attn_bias = attn_bias + attention_mask
631
+
632
+ attention_scores = attention_scores + attn_bias
633
+
634
+ attention_probs = attention_scores.softmax(dim=-1)
635
+ attention_probs = self.dropout(attention_probs)
636
+
637
+ x = einsum(attention_probs, v, "b h q k, b v h d -> b q h d")
638
+
639
+ x = rearrange(x, "b q h d -> b q (h d)", h=self.num_heads)
640
+
641
+ if use_cache:
642
+ if output_attentions:
643
+ return x, attention_probs, cache
644
+ else:
645
+ return x, None, cache
646
+ else:
647
+ if output_attentions:
648
+ return x, attention_probs
649
+ else:
650
+ return x, None
651
+
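The tensor layout used throughout this module is `(batch, seq, heads, head_dim)`. Below is a self-contained sketch of the non-flash fallback computation with toy sizes (not the model's), written against `einops.einsum` explicitly; note that the axis indexing keys and values must share a single name so it is contracted jointly.

```python
import torch
from einops import rearrange, einsum

b, q_len, k_len, h, d = 2, 5, 7, 4, 16          # toy sizes for illustration
q = rearrange(torch.randn(b, q_len, h * d), "b q (h d) -> b q h d", h=h)
k = rearrange(torch.randn(b, k_len, h * d), "b k (h d) -> b k h d", h=h)
v = rearrange(torch.randn(b, k_len, h * d), "b k (h d) -> b k h d", h=h)

scores = einsum(q, k, "b q h d, b k h d -> b h q k") / d ** 0.5
probs = scores.softmax(dim=-1)
# keys and values share the "k" axis, so the weighted sum contracts them together
out = einsum(probs, v, "b h q k, b k h d -> b q h d")
out = rearrange(out, "b q h d -> b q (h d)")
print(out.shape)  # torch.Size([2, 5, 64])
```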
652
+ class EnCodonPreTrainedModel(PreTrainedModel):
653
+ """
654
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
655
+ models.
656
+ """
657
+ base_model_prefix = "encodon"
658
+ supports_gradient_checkpointing = True
659
+
660
+ def _init_weights(self, module):
661
+ """MAGNETO Initialize the weights"""
662
+ if isinstance(module, nn.Linear):
663
+ # gain should be 1 for query and key weights
664
+ is_qk = False
665
+ for n, p in module.named_parameters():
666
+ if "query" in n or "key" in n:
667
+ is_qk = True
668
+ break
669
+ if is_qk:
670
+ nn.init.xavier_normal_(module.weight, gain=1.0)
671
+ else:
672
+ nn.init.xavier_normal_(module.weight, gain=self.config.gamma_init)
673
+ if module.bias is not None:
674
+ module.bias.data.zero_()
675
+
676
+ elif isinstance(module, nn.Embedding):
677
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
678
+ if module.padding_idx is not None:
679
+ module.weight.data[module.padding_idx].zero_()
680
+
681
+ elif isinstance(module, nn.LayerNorm):
682
+ module.bias.data.zero_()
683
+ module.weight.data.fill_(1.0)
684
+
685
+ def _set_gradient_checkpointing(self, module, value=False):
686
+ if isinstance(module, EnCodonStack):
687
+ module.gradient_checkpointing = value
688
+
689
+
690
+ class EnCodonEmbeddings(nn.Module):
691
+ """
692
+ EnCodon Embeddings module. This module contains word, token type, and (absolute) positional embeddings.
693
+ NOTE: This module is adapted from the original HuggingFace implementation.
694
+ NOTE: Absolute positional embeddings are mutually exclusive with rotary embeddings.
695
+ """
696
+
697
+ def __init__(self, config):
698
+ super().__init__()
699
+ self.word_embeddings = nn.Embedding(
700
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
701
+ )
702
+ self.position_embeddings = nn.Embedding(
703
+ config.max_position_embeddings, config.hidden_size
704
+ )
705
+ self.token_type_embeddings = nn.Embedding(
706
+ config.type_vocab_size, config.hidden_size
707
+ )
708
+
709
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
710
+ # any TensorFlow checkpoint file
711
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
712
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
713
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
714
+ self.position_embedding_type = getattr(
715
+ config, "position_embedding_type", "absolute"
716
+ )
717
+ self.register_buffer(
718
+ "position_ids",
719
+ torch.arange(config.max_position_embeddings).expand((1, -1)),
720
+ persistent=False,
721
+ )
722
+ self.register_buffer(
723
+ "token_type_ids",
724
+ torch.zeros(self.position_ids.size(), dtype=torch.long),
725
+ persistent=False,
726
+ )
727
+
728
+ def forward(
729
+ self,
730
+ input_ids: Optional[torch.LongTensor] = None,
731
+ token_type_ids: Optional[torch.LongTensor] = None,
732
+ position_ids: Optional[torch.LongTensor] = None,
733
+ inputs_embeds: Optional[torch.FloatTensor] = None,
734
+ past_key_values_length: int = 0,
735
+ ) -> torch.Tensor:
736
+ if input_ids is not None:
737
+ input_shape = input_ids.size()
738
+ else:
739
+ input_shape = inputs_embeds.size()[:-1]
740
+
741
+ seq_length = input_shape[1]
742
+
743
+ if position_ids is None:
744
+ position_ids = self.position_ids[
745
+ :, past_key_values_length : seq_length + past_key_values_length
746
+ ]
747
+
748
+ # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
749
+ # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
750
+ # issue #5664
751
+ if token_type_ids is None:
752
+ if hasattr(self, "token_type_ids"):
753
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
754
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(
755
+ input_shape[0], seq_length
756
+ )
757
+ token_type_ids = buffered_token_type_ids_expanded
758
+ else:
759
+ token_type_ids = torch.zeros(
760
+ input_shape, dtype=torch.long, device=self.position_ids.device
761
+ )
762
+
763
+ if inputs_embeds is None:
764
+ inputs_embeds = self.word_embeddings(input_ids)
765
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
766
+
767
+ embeddings = inputs_embeds + token_type_embeddings
768
+ if self.position_embedding_type == "absolute":
769
+ position_embeddings = self.position_embeddings(position_ids)
770
+ embeddings += position_embeddings
771
+ embeddings = self.LayerNorm(embeddings)
772
+ embeddings = self.dropout(embeddings)
773
+ return embeddings
774
+
775
+
776
+ class EnCodonAttention(nn.Module):
777
+ """
778
+ EnCodon Attention module. This module supports two types of attention:
779
+ (1) self-attention and (2) dilated as described in Transformers and LongNet papers, respectively.
780
+ """
781
+
782
+ def __init__(self, config, layer_idx=0):
783
+ super().__init__()
784
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
785
+ config, "embedding_size"
786
+ ):
787
+ raise ValueError(
788
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
789
+ f"heads ({config.num_attention_heads})"
790
+ )
791
+
792
+ self.layer_idx = layer_idx
793
+ self.pre_layer_norm = nn.LayerNorm(
794
+ config.hidden_size, eps=config.layer_norm_eps
795
+ )
796
+ self.post_attn_dense = nn.Linear(config.hidden_size, config.hidden_size)
797
+ self.post_layer_norm = nn.LayerNorm(
798
+ config.hidden_size, eps=config.layer_norm_eps
799
+ )
800
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
801
+
802
+ self.self_attention = MultiHeadedSelfAttention(
803
+ q_input_dim=config.hidden_size,
804
+ kv_input_dim=config.hidden_size,
805
+ qk_proj_dim=config.hidden_size,
806
+ v_proj_dim=config.hidden_size,
807
+ num_heads=config.num_attention_heads,
808
+ dropout=config.attention_probs_dropout_prob,
809
+ projection_layer="linear",
810
+ use_flash_attn=config.use_flash_attn,
811
+ use_rotary_emb=config.use_rotary_emb,
812
+ rotary_theta=config.rotary_theta,
813
+ rotary_use_xpos=False,
814
+ )
815
+
816
+ def forward(
817
+ self,
818
+ hidden_states: torch.Tensor,
819
+ attention_mask: Optional[torch.FloatTensor] = None,
820
+ attention_bias: Optional[torch.FloatTensor] = None,
821
+ output_attentions: Optional[bool] = False,
822
+ ) -> Tuple[torch.Tensor]:
823
+ attn_input = self.pre_layer_norm(hidden_states)
824
+ attn_outputs = self.self_attention(
825
+ attn_input,
826
+ attn_input,
827
+ is_causal=False,
828
+ attention_bias=attention_bias,
829
+ attention_mask=attention_mask,
830
+ output_attentions=output_attentions,
831
+ )
832
+
833
+ attn_output = attn_outputs[0]
834
+ attn_output = self.post_layer_norm(attn_output)
835
+ attn_output = self.post_attn_dense(attn_output)
836
+ attn_output = self.dropout(attn_output)
837
+ attn_output = hidden_states + attn_output
838
+ return (attn_output,) + attn_outputs[1:] # add attentions if we output them
839
+
840
+
841
+ class EnCodonFFN(nn.Module):
842
+ """
843
+ EnCodon Position-wise Feed-Forward Network module.
844
+ """
845
+
846
+ def __init__(self, config):
847
+ super().__init__()
848
+ self.pre_layer_norm = nn.LayerNorm(
849
+ config.hidden_size, eps=config.layer_norm_eps
850
+ )
851
+ self.intermediate_dense = nn.Linear(
852
+ config.hidden_size, config.intermediate_size
853
+ )
854
+ self.post_layer_norm = nn.LayerNorm(
855
+ config.intermediate_size, eps=config.layer_norm_eps
856
+ )
857
+ self.post_dense = nn.Linear(config.intermediate_size, config.hidden_size)
858
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
859
+
860
+ if isinstance(config.hidden_act, str):
861
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
862
+ else:
863
+ self.intermediate_act_fn = config.hidden_act
864
+
865
+ def forward(self, input_states: torch.Tensor) -> torch.Tensor:
866
+ hidden_states = self.pre_layer_norm(input_states)
867
+ hidden_states = self.intermediate_dense(hidden_states)
868
+ hidden_states = self.intermediate_act_fn(hidden_states)
869
+ hidden_states = self.post_layer_norm(hidden_states)
870
+ hidden_states = self.post_dense(hidden_states)
871
+ hidden_states = self.dropout(hidden_states)
872
+ return hidden_states + input_states
873
+
874
+
875
+ class EnCodonLayer(nn.Module):
876
+ """
877
+ EnCodon Encoder layer module.
878
+
879
+ This module contains an attention layer followed by a position-wise feed-forward layer.
880
+ """
881
+
882
+ def __init__(self, config):
883
+ super().__init__()
884
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
885
+ self.seq_len_dim = 1
886
+ self.attention = EnCodonAttention(config)
887
+ self.output = EnCodonFFN(config)
888
+
889
+ def forward(
890
+ self,
891
+ hidden_states: torch.Tensor,
892
+ attention_mask: Optional[torch.FloatTensor] = None,
893
+ attention_bias: Optional[torch.FloatTensor] = None,
894
+ output_attentions: Optional[bool] = False,
895
+ ) -> Tuple[torch.Tensor]:
896
+ self_attention_outputs = self.attention(
897
+ hidden_states=hidden_states,
898
+ attention_mask=attention_mask,
899
+ attention_bias=attention_bias,
900
+ output_attentions=output_attentions,
901
+ )
902
+ attention_output = self_attention_outputs[0]
903
+
904
+ outputs = self_attention_outputs[
905
+ 1:
906
+ ] # add self attentions if we output attention weights
907
+
908
+ layer_output = apply_chunking_to_forward(
909
+ self.feed_forward_chunk,
910
+ self.chunk_size_feed_forward,
911
+ self.seq_len_dim,
912
+ attention_output,
913
+ )
914
+ outputs = (layer_output,) + outputs
915
+
916
+ return outputs
917
+
918
+ def feed_forward_chunk(self, attention_output):
919
+ layer_output = self.output(attention_output)
920
+ return layer_output
921
+
922
+
923
+ @dataclass
924
+ class BERTEncoderOutput(ModelOutput):
925
+ last_hidden_state: torch.Tensor
926
+ hidden_states: Optional[Tuple[torch.Tensor]] = None
927
+ attentions: Optional[Tuple[torch.Tensor]] = None
928
+
929
+
930
+ class EnCodonStack(nn.Module):
931
+ """
932
+ EnCodon stack module. This module contains multiple EnCodon layers.
933
+ """
934
+
935
+ def __init__(self, config):
936
+ super().__init__()
937
+ self.config = config
938
+ self.layer = nn.ModuleList(
939
+ [EnCodonLayer(config) for _ in range(config.num_hidden_layers)]
940
+ )
941
+ self.gradient_checkpointing = False
942
+
943
+ def forward(
944
+ self,
945
+ hidden_states: torch.Tensor,
946
+ attention_mask: Optional[torch.FloatTensor] = None,
947
+ attention_bias: Optional[torch.FloatTensor] = None,
948
+ output_attentions: Optional[bool] = False,
949
+ output_hidden_states: Optional[bool] = False,
950
+ return_dict: Optional[bool] = True,
951
+ ):
952
+ all_hidden_states = () if output_hidden_states else None
953
+ all_self_attentions = () if output_attentions else None
954
+
955
+ for i, layer_module in enumerate(self.layer):
956
+ if output_hidden_states:
957
+ all_hidden_states = all_hidden_states + (hidden_states,)
958
+
959
+ layer_outputs = layer_module(
960
+ hidden_states=hidden_states,
961
+ attention_mask=attention_mask,
962
+ attention_bias=attention_bias,
963
+ output_attentions=output_attentions,
964
+ )
965
+
966
+ hidden_states = layer_outputs[0]
967
+ if output_attentions:
968
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
969
+
970
+ if output_hidden_states:
971
+ all_hidden_states = all_hidden_states + (hidden_states,)
972
+
973
+ if not return_dict:
974
+ return tuple(
975
+ v
976
+ for v in [
977
+ hidden_states,
978
+ all_hidden_states,
979
+ all_self_attentions,
980
+ ]
981
+ if v is not None
982
+ )
983
+ return BERTEncoderOutput(
984
+ last_hidden_state=hidden_states,
985
+ hidden_states=all_hidden_states,
986
+ attentions=all_self_attentions,
987
+ )
988
+
989
+
990
+ class BERTPooler(nn.Module):
991
+ """
992
+ BERT Pooler module. This module pools the desired token from the hidden states
993
+ which is typically used for sequence-level classification/regression tasks.
994
+ """
995
+
996
+ def __init__(self, config, pooled_token_position=0):
997
+ super().__init__()
998
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
999
+ self.activation = nn.Tanh() if config.pooler_activation == "tanh" else nn.ReLU()
1000
+ self.pooled_token_position = pooled_token_position
1001
+
1002
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
1003
+ # We "pool" the model by simply taking the hidden state corresponding
1004
+ # to the first token.
1005
+ pooled_token_tensor = hidden_states[:, self.pooled_token_position]
1006
+ pooled_output = self.dense(pooled_token_tensor)
1007
+ pooled_output = self.activation(pooled_output)
1008
+ return pooled_output
1009
+
1010
+
1011
+ class BERTPredictionHeadTransform(nn.Module):
1012
+ def __init__(self, config):
1013
+ super().__init__()
1014
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
1015
+ if isinstance(config.hidden_act, str):
1016
+ self.transform_act_fn = ACT2FN[config.hidden_act]
1017
+ else:
1018
+ self.transform_act_fn = config.hidden_act
1019
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1020
+
1021
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
1022
+ hidden_states = self.dense(hidden_states)
1023
+ hidden_states = self.transform_act_fn(hidden_states)
1024
+ hidden_states = self.LayerNorm(hidden_states)
1025
+ return hidden_states
1026
+
1027
+
1028
+ class BERTLMPredictionHead(nn.Module):
1029
+ def __init__(self, config):
1030
+ super().__init__()
1031
+ self.transform = BERTPredictionHeadTransform(config)
1032
+
1033
+ # The output weights are the same as the input embeddings, but there is
1034
+ # an output-only bias for each token.
1035
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1036
+
1037
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
1038
+
1039
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
1040
+ self.decoder.bias = self.bias
1041
+
1042
+ def forward(self, hidden_states):
1043
+ hidden_states = self.transform(hidden_states)
1044
+ hidden_states = self.decoder(hidden_states)
1045
+ return hidden_states
1046
+
1047
+
1048
+ @dataclass
1049
+ class BERTModelOutput(ModelOutput):
1050
+ last_hidden_state: torch.Tensor
1051
+ pooled_output: Optional[torch.Tensor] = None
1052
+ hidden_states: Optional[Tuple[torch.Tensor]] = None
1053
+ attentions: Optional[Tuple[torch.Tensor]] = None
1054
+
1055
+
1056
+ class EnCodonModule(EnCodonPreTrainedModel):
1057
+ """
1058
+ EnCodon Module
1059
+
1060
+ Parameters
1061
+ ----------
1062
+ config : EnCodonConfig
1063
+ Configuration class for EnCodon model.
1064
+ add_pooling_layer : bool (default: True)
1065
+ Whether to add a pooling layer to the model.
1066
+ pooled_token_position : int (default: 0)
1067
+ The position of the token to be pooled from the hidden states.
1068
+
1069
+ """
1070
+
1071
+ def __init__(self, config, add_pooling_layer=True, pooled_token_position=0):
1072
+ super().__init__(config)
1073
+ self.config = config
1074
+
1075
+ self.embeddings = EnCodonEmbeddings(config)
1076
+ self.encoder = EnCodonStack(config)
1077
+
1078
+ self.pooler = (
1079
+ BERTPooler(config, pooled_token_position=pooled_token_position)
1080
+ if add_pooling_layer
1081
+ else None
1082
+ )
1083
+
1084
+ # Initialize weights and apply final processing
1085
+ self.post_init()
1086
+
1087
+ def get_input_embeddings(self):
1088
+ return self.embeddings.word_embeddings
1089
+
1090
+ def set_input_embeddings(self, value):
1091
+ self.embeddings.word_embeddings = value
1092
+
1093
+ def forward(
1094
+ self,
1095
+ input_ids: Optional[torch.Tensor] = None,
1096
+ attention_mask: Optional[torch.Tensor] = None,
1097
+ attention_bias: Optional[torch.Tensor] = None,
1098
+ token_type_ids: Optional[torch.Tensor] = None,
1099
+ position_ids: Optional[torch.Tensor] = None,
1100
+ inputs_embeds: Optional[torch.Tensor] = None,
1101
+ output_attentions: Optional[bool] = None,
1102
+ output_hidden_states: Optional[bool] = None,
1103
+ return_dict: Optional[bool] = None,
1104
+ **kwargs,
1105
+ ):
1106
+ """
1107
+ Forward pass for the BERT model.
1108
+
1109
+ Parameters
1110
+ ----------
1111
+ input_ids : Optional[torch.Tensor]
1112
+ The input IDs for the model. Expected Shape: (batch_size, seq_length)
1113
+
1114
+ attention_mask : Optional[torch.Tensor]
1115
+ The attention mask for the model. Expected Shape: (batch_size, seq_length)
1116
+ - 1 for tokens that are NOT MASKED
1117
+ - 0 for tokens that are MASKED
1118
+
1119
+ token_type_ids : Optional[torch.Tensor]
1120
+ The token type IDs for the model. Expected Shape: (batch_size, seq_length)
1121
+
1122
        position_ids : Optional[torch.Tensor]
            The position IDs for the model. Expected Shape: (batch_size, seq_length)

        inputs_embeds : Optional[torch.Tensor]
            The input embeddings for the model. Expected Shape: (batch_size, seq_length, hidden_size)

        output_attentions : Optional[bool]
            Whether to output attentions or not.

        output_hidden_states : Optional[bool]
            Whether to output hidden states or not.

        return_dict : Optional[bool]
            Whether to return a dictionary or not.

        Returns
        -------
        BERTModelOutput
            The output of the BERT model.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length), device=device)

        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(
                    batch_size, seq_length
                )
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(
                    input_shape, dtype=torch.long, device=device
                )

        # We can provide a self-attention mask of dimensions
        # [batch_size, from_seq_length, to_seq_length] ourselves, in which case we
        # just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
            attention_mask, input_shape
        )

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            attention_bias=attention_bias,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = encoder_outputs[0]
        pooled_output = (
            self.pooler(sequence_output) if self.pooler is not None else None
        )

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BERTModelOutput(
            last_hidden_state=sequence_output,
            pooled_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

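
# Illustrative sketch of what the additive "extended" attention mask produced by
# get_extended_attention_mask amounts to for a padded batch. The exact fill value
# (dtype minimum vs. a fixed large negative constant) is an implementation detail
# of Transformers; -1e9 is assumed here purely for illustration.
def _sketch_extended_attention_mask():
    import torch

    mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])        # (batch, seq); 0 marks padding
    extended = (1.0 - mask)[:, None, None, :] * -1e9   # (batch, 1, 1, seq)
    # `extended` broadcasts over heads and query positions, so padded keys get a
    # very large negative score added and vanish after the attention softmax.
    return extended

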
@dataclass
class EnCodonOutput(MaskedLMOutput):
    """
    Base class for EnCodon outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Masked language modeling (MLM) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class EnCodon(EnCodonPreTrainedModel):
    config_class = EnCodonConfig

    def __init__(self, config):
        super().__init__(config)

        self.bert = EnCodonModule(config)
        if self.config.lm_type == "bert":
            self.cls = BERTLMPredictionHead(config)
        else:
            self.cls = nn.Sequential(
                nn.Linear(config.hidden_size, config.hidden_size),
                nn.GELU(),
                nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps),
                nn.Linear(config.hidden_size, config.vocab_size),
            )

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        special_tokens_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        nsp_labels: Optional[torch.Tensor] = None,
        div_labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        return_pooled_output: Optional[bool] = False,
        **kwargs,
    ):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked);
            the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        """

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()  # -100 index = padding token
            loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
            )

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return EnCodonOutput(
            loss=loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def get_codon_embeddings(self) -> Module:
        return self.bert.embeddings.word_embeddings

    def freeze_bert(self, layer_indices: Optional[list] = None):
        if layer_indices is None or len(layer_indices) == 0:
            for param in self.bert.parameters():
                param.requires_grad = False
        else:
            for param in self.bert.embeddings.parameters():
                param.requires_grad = False

            if isinstance(layer_indices, int):
                layer_indices = [layer_indices]

            layer_indices = [i % len(self.bert.encoder.layer) for i in layer_indices]

            for i in range(len(self.bert.encoder.layer)):
                if i not in layer_indices:
                    for param in self.bert.encoder.layer[i].parameters():
                        param.requires_grad = False

            for param in self.bert.pooler.parameters():
                param.requires_grad = False

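
# Hypothetical usage sketch for the masked-LM model above. The checkpoint id, the
# tokenizer behaviour, and the AutoModelForMaskedLM auto-class mapping are
# assumptions made for illustration only, not guarantees taken from this file.
def _sketch_masked_lm_usage():
    import torch
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    ckpt = "path-or-hub-id-of-this-checkpoint"  # placeholder
    tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
    model = AutoModelForMaskedLM.from_pretrained(ckpt, trust_remote_code=True).eval()

    # Tokenize a coding sequence into codon tokens and score every position.
    enc = tokenizer("ATGGCCAAGGAGTAA", return_tensors="pt")
    with torch.no_grad():
        out = model(**enc)  # EnCodonOutput
    return out.logits.log_softmax(dim=-1)  # (batch, seq_len, vocab_size)

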
@dataclass
class EnCodonForDMSOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class EnCodonForDMS(EnCodonPreTrainedModel):
    config_class = EnCodonForDMSConfig
    _tied_weights_keys = ["cls.layer.3.weight", "cls.layer.3.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.bert = EnCodonModule(config)
        if self.config.lm_type == "bert":
            self.cls = BERTLMPredictionHead(config)
        else:
            self.cls = nn.Sequential(
                nn.Linear(config.hidden_size, config.hidden_size),
                nn.GELU(),
                nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps),
                nn.Linear(config.hidden_size, config.vocab_size),
            )

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        alt_ids: Optional[torch.Tensor] = None,
        var_positions: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        target: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        return_pooled_output: Optional[bool] = False,
        return_all_logits: Optional[bool] = False,
        **kwargs,
    ):
        r"""
        alt_ids (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Vocabulary ids of the alternate codon for each sequence.
        var_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Position of the variant codon in each sequence.
        target (`torch.FloatTensor` of shape `(batch_size,)` or `(batch_size, sequence_length)`, *optional*):
            Regression targets (e.g. deep mutational scanning scores). In the per-variant case, entries equal to
            `-500.0` are treated as missing and masked out of the loss.
        """

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)  # (batch_size, seq_len, vocab_size)

        # Select the alt codon logits at the variant positions.
        # alt_ids: (batch_size,)
        # var_positions: (batch_size,)

        if return_all_logits:
            return EnCodonForDMSOutput(
                loss=None,
                logits=prediction_scores,
                attentions=outputs.attentions,
            )

        bs, seq_len, vocab_size = prediction_scores.shape

        loss = None
        if var_positions is None and alt_ids is None:
            alt_prediction_scores = prediction_scores.gather(
                2, input_ids.unsqueeze(-1)
            ).squeeze(2)  # (batch_size, seq_len)

            if target is not None:
                expanded_target = target

                if self.config.loss_fn == "mse":
                    loss_fct = nn.MSELoss()
                    loss = loss_fct(alt_prediction_scores, expanded_target)
                elif self.config.loss_fn == "mae":
                    loss_fct = nn.L1Loss()
                    loss = loss_fct(alt_prediction_scores, expanded_target)
                elif self.config.loss_fn == "huber":
                    loss_fct = nn.SmoothL1Loss()
                    loss = loss_fct(alt_prediction_scores, expanded_target)
                else:
                    raise ValueError(f"Invalid loss_fn: {self.config.loss_fn}.")

            alt_prediction_scores = alt_prediction_scores.mean(dim=1)  # (batch_size,)
        else:
            alt_prediction_scores = prediction_scores[
                torch.arange(bs), var_positions, alt_ids
            ]  # (batch_size,)

            if target is not None:
                mask = target != -500.0

                target = target[mask]
                alt_prediction_scores = alt_prediction_scores[mask]

                if self.config.loss_fn == "mse":
                    loss_fct = nn.MSELoss()
                    loss = loss_fct(alt_prediction_scores, target)
                elif self.config.loss_fn == "mae":
                    loss_fct = nn.L1Loss()
                    loss = loss_fct(alt_prediction_scores, target)
                elif self.config.loss_fn == "huber":
                    loss_fct = nn.SmoothL1Loss()
                    loss = loss_fct(alt_prediction_scores, target)
                else:
                    raise ValueError(f"Invalid loss_fn: {self.config.loss_fn}.")

        if not return_dict:
            output = (alt_prediction_scores,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return EnCodonForDMSOutput(
            loss=loss,
            logits=alt_prediction_scores,
            attentions=outputs.attentions,
        )

    def get_codon_embeddings(self) -> Module:
        return self.bert.embeddings.word_embeddings

    def freeze_bert(self, layers_idx: Optional[list] = None):
        if layers_idx is None or len(layers_idx) == 0:
            for param in self.bert.parameters():
                param.requires_grad = False
        else:
            for param in self.bert.embeddings.parameters():
                param.requires_grad = False

            if isinstance(layers_idx, int):
                layers_idx = [layers_idx]

            layers_idx = [i % len(self.bert.encoder.layer) for i in layers_idx]

            for i in range(len(self.bert.encoder.layer)):
                if i not in layers_idx:
                    for param in self.bert.encoder.layer[i].parameters():
                        param.requires_grad = False

            for param in self.bert.pooler.parameters():
                param.requires_grad = False

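
# Minimal sketch of the indexing used by EnCodonForDMS above to score a single
# substitution per sequence: take the logit of the alternate codon (alt_ids) at
# the variant position (var_positions). All tensors here are dummies.
def _sketch_variant_scoring():
    import torch

    bs, seq_len, vocab_size = 2, 8, 70      # dummy sizes for illustration
    prediction_scores = torch.randn(bs, seq_len, vocab_size)
    var_positions = torch.tensor([3, 5])    # codon index of the variant in each sequence
    alt_ids = torch.tensor([12, 40])        # vocabulary id of the alternate codon
    alt_scores = prediction_scores[torch.arange(bs), var_positions, alt_ids]
    return alt_scores                       # shape: (batch_size,)

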
class EnCodonForSequenceTask(EnCodonPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.bert = EnCodonModule(config)

        if config.cls_type.lower() == "cls":
            self.classifier = nn.Linear(
                config.hidden_size, config.num_labels * config.num_tasks
            )
        else:
            raise ValueError(f"Invalid cls_type: {config.cls_type}.")

        self.init_weights()

    def freeze_bert(self, layers_idx: Optional[list] = None):
        if layers_idx is None or len(layers_idx) == 0:
            for param in self.bert.parameters():
                param.requires_grad = False
        else:
            for param in self.bert.embeddings.parameters():
                param.requires_grad = False

            if isinstance(layers_idx, int):
                layers_idx = [layers_idx]

            layers_idx = [i % len(self.bert.encoder.layer) for i in layers_idx]

            for i in range(len(self.bert.encoder.layer)):
                if i not in layers_idx:
                    for param in self.bert.encoder.layer[i].parameters():
                        param.requires_grad = False

            if self.config.cls_type.lower() != "cls":
                for param in self.bert.pooler.parameters():
                    param.requires_grad = False

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        target: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ):
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=return_dict,
        )

        all_hidden_states = outputs[2]

        if self.config.cls_type.lower() not in ["crossattention", "ca", "cls"]:
            logits, _ = self.classifier(all_hidden_states, attention_mask)
            ca = None
        elif self.config.cls_type.lower() in ["crossattention", "ca"]:
            bs, seq_len = input_ids.shape

            query_tasks = self.task_embeddings.weight  # (num_tasks, hidden_size)
            query_tasks = query_tasks.unsqueeze(0).expand(
                bs, -1, -1
            )  # (batch_size, num_tasks, hidden_size)

            cls_outputs = self.classifier(
                query_tasks,
                all_hidden_states,
                attention_mask,
                output_attentions=output_attentions,
            )  # (batch_size, num_tasks, num_labels)

            logits, ca = cls_outputs

            logits = logits.squeeze()
        elif self.config.cls_type.lower() == "cls":
            pooled_output = outputs[1]
            logits = self.classifier(pooled_output)
            ca = None

        loss = None
        if target is not None:
            if self.config.problem_type == "regression":
                if self.config.loss_fn == "mse":
                    loss_fct = nn.MSELoss()
                elif self.config.loss_fn == "mae":
                    loss_fct = nn.L1Loss()
                elif self.config.loss_fn == "huber":
                    loss_fct = nn.SmoothL1Loss()
                else:
                    raise ValueError(f"Invalid loss_fn: {self.config.loss_fn}.")

                logits = logits.view(-1, self.config.num_labels * self.config.num_tasks)
                target = target.view(-1, self.config.num_labels * self.config.num_tasks)

                mask = target != -500.0

                loss = loss_fct(logits[mask], target[mask])
            else:
                loss_fct = nn.CrossEntropyLoss()

                logits = logits.view(-1, self.config.num_labels * self.config.num_tasks)
                target = target.view(-1)

                loss = loss_fct(logits, target)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        if output_attentions:
            if ca is not None:
                attentions = outputs.attentions + [ca]
            else:
                attentions = outputs.attentions
        else:
            attentions = None

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=attentions,
        )
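

# Hypothetical fine-tuning sketch for EnCodonForSequenceTask: freeze the embeddings
# and all but the last two encoder layers before training the task head.
# freeze_bert takes indices modulo the number of layers, so negative indices wrap
# around to the final layers. EnCodonConfig and the config fields used here are
# assumptions for illustration; check configuration_encodon.py for the actual
# task-specific config class and its defaults.
def _sketch_frozen_finetuning():
    config = EnCodonConfig(
        num_labels=1,
        num_tasks=1,
        cls_type="cls",
        problem_type="regression",
        loss_fn="mse",
    )
    model = EnCodonForSequenceTask(config)
    model.freeze_bert(layers_idx=[-1, -2])  # embeddings and all other layers frozen
    trainable = [name for name, p in model.named_parameters() if p.requires_grad]
    return model, trainable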