ohadmo committed
Commit
d6f6cac
·
1 Parent(s): dc47c24

upload TE checkpoint

Files changed (5)
  1. LICENSE +178 -0
  2. README.md +167 -3
  3. config.json +35 -0
  4. geneformer.py +930 -0
  5. model.safetensors +3 -0
LICENSE ADDED
@@ -0,0 +1,178 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2022 Theodoris Lab, Gladstone Institute and The HuggingFace Inc. team. All rights reserved.
+ Copyright 2025 NVIDIA CORPORATION. All rights reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
README.md CHANGED
@@ -1,3 +1,167 @@
- ---
- license: apache-2.0
- ---
+ ---
+ datasets: ctheodoris/Genecorpus-30M
+ library_name: transformers
+ license: apache-2.0
+ tags:
+ - single-cell
+ - genomics
+ ---
+
+ # Geneformer-10M (TransformerEngine-Optimized) Overview
+
+ ## Description:
+ Geneformer is a foundational transformer model pretrained on a large-scale corpus of single-cell transcriptomes to enable context-specific predictions in network biology, particularly in settings with limited data.
+
+ This version of the Geneformer model is optimized with NVIDIA's [TransformerEngine](https://github.com/NVIDIA/TransformerEngine) library. It is based on the original Geneformer V1 model and, within numerical precision, has identical weights and outputs.
+
+ This model is ready for commercial/non-commercial use.
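
As a quick sanity check, the checkpoint loads through the standard `transformers` auto classes. A minimal sketch, assuming `transformer_engine` is installed and a CUDA GPU is available; the random token IDs below are placeholders for the rank-value encodings produced by the Geneformer tokenizer:

```python
import torch
from transformers import AutoModelForMaskedLM

# trust_remote_code pulls in the TE-based architecture from geneformer.py in this repo
model = AutoModelForMaskedLM.from_pretrained(
    "nvidia/geneformer_V1_10M", trust_remote_code=True
).eval().cuda()

# Placeholder IDs; real inputs are rank-value encodings of length <= 2048
input_ids = torch.randint(1, model.config.vocab_size, (1, 2048), device="cuda")
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    out = model(input_ids=input_ids, attention_mask=attention_mask)
print(out.logits.shape)  # torch.Size([1, 2048, 25426])
```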
+
+ ## Third-Party Community Consideration
+ This model is not owned or developed by NVIDIA. It has been developed and built to a third-party's requirements for this application and use case; see the Non-NVIDIA [Geneformer Model Card](https://huggingface.co/ctheodoris/Geneformer).
+
+ ### License/Terms of Use:
+ Geneformer is licensed under the [Apache 2.0 license](https://huggingface.co/datasets/choosealicense/licenses/blob/main/markdown/apache-2.0.md).
+
+ ### Deployment Geography:
+ Global
+
+ ### Use Case:
+ Network biology and therapeutic discovery, particularly in data-limited settings such as rare diseases or diseases affecting hard-to-access tissues.
+
+ ### Release Date:
+ Hugging Face 12/19/2025 via [https://huggingface.co/nvidia/geneformer_V1_10M](https://huggingface.co/nvidia/geneformer_V1_10M)
+
+ ## Reference(s):
+ * [Transfer learning enables predictions in network biology](https://www.nature.com/articles/s41586-023-06139-9.epdf?sharing_token=u_5LUGVkd3A8zR-f73lU59RgN0jAjWel9jnR3ZoTv0N2UB4yyXENUK50s6uqjXH69sDxh4Z3J4plYCKlVME-W2WSuRiS96vx6t5ex2-krVDS46JkoVvAvJyWtYXIyj74pDWn_DutZq1oAlDaxfvBpUfSKDdBPJ8SKlTId8uT47M%3D) - details of the original model trained on ~30 million transcriptomes in June 2021 and the initial report of the in silico perturbation and cell and gene classification strategies.
+ * [Quantized multi-task learning for context-specific representations of gene network dynamics](https://www.biorxiv.org/content/10.1101/2024.08.16.608180v1.full.pdf) - the expanded model, trained on ~104 million transcriptomes, and continual learning, multitask learning, and quantization strategies.
+ * See [geneformer.readthedocs.io](https://geneformer.readthedocs.io/) for documentation.
+
+ ## Model Architecture:
+ **Architecture Type:** Transformer
+ **Network Architecture:** BERT
+
+ **This model was developed based on:** [Geneformer](https://huggingface.co/ctheodoris/Geneformer) <br>
+ **Number of model parameters:** 1 x 10^7
+
+ ## Input:
+ **Input Type:** Number (each row represents a cell and contains gene names with single-cell expression counts) <br>
+ **Input Format:** Array [AnnData](https://anndata.readthedocs.io/en/latest/) <br>
+ **Input Parameters:** One-Dimensional (1D) <br>
+ **Other Properties Related to Input:** This model supports a context length of 2048.
+
+ ## Output:
+ **Output Type:** Dense Embedding Predictions <br>
+ **Output Format:** Vector <br>
+ **Output Parameters:** One-Dimensional (1D) <br>
+ **Other Properties Related to Output:** Numeric floating-point vector (fp16, bf16, or fp32); Geneformer-10M outputs 256-dimensional embeddings.
+
+ Our AI models are designed and/or optimized to run on NVIDIA GPU-accelerated systems. By leveraging NVIDIA's hardware (e.g. GPU cores) and software frameworks (e.g., CUDA libraries), the model achieves faster training and inference times compared to CPU-only solutions.
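
To illustrate the 256-dimensional embedding output, cell embeddings can be pooled from the encoder's last hidden state. A sketch reusing `input_ids`/`attention_mask` from the loading example above; mean pooling over non-padding tokens is one common choice, not the only one:

```python
import torch
from transformers import AutoModel

# AutoModel resolves to the encoder (geneformer.BertModel) via auto_map in config.json
encoder = AutoModel.from_pretrained(
    "nvidia/geneformer_V1_10M", trust_remote_code=True
).eval().cuda()

with torch.no_grad():
    hidden = encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

# Mean-pool over non-padding positions -> one 256-dim embedding per cell
mask = attention_mask.unsqueeze(-1).to(hidden.dtype)  # [batch, seq, 1]
cell_emb = (hidden * mask).sum(dim=1) / mask.sum(dim=1)
print(cell_emb.shape)  # torch.Size([1, 256])
```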
+
+ ## Software Integration:
+ **Runtime Engine(s):**
+ * Transformer Engine
+ * PyTorch
+
+ **Supported Hardware Microarchitecture Compatibility:**
+ * A100
+ * H100
+ * H200
+ * GB200
+
+ **Preferred/Supported Operating System(s):**
+ * Linux
+
+ The integration of foundation and fine-tuned models into AI systems requires additional testing using use-case-specific data to ensure safe and effective deployment. Following the V-model methodology, iterative testing and validation at both unit and system levels are essential to mitigate risks, meet technical and functional requirements, and ensure compliance with safety and ethical standards before deployment.
+
+ ## Model Version(s):
+ * Geneformer-V1-10M
+ * Geneformer-V2-104M
+ * Geneformer-V2-316M
+ * Geneformer-V2-104M_CLcancer
+
+
+ ## Training and Evaluation Datasets:
+
+ ## Training Datasets:
+ **Link:** [Genecorpus-30M](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M)
+
+ **Data Modality:**
+ * Text (Human single-cell transcriptomes)
+
+ **Text Training Data Size:**
+ * 1 Billion to 10 Trillion Tokens
+
+ **Data Collection Method by dataset:**
+ * Human
+
+ **Labeling Method by dataset:**
+ * N/A
+
+ **Properties:** The single-cell transcriptomes were assembled from a broad range of publicly available data sources. The researchers collected raw counts from sources like NCBI Gene Expression Omnibus (GEO), Human Cell Atlas, and Tumor Immune Single-cell Hub (TISCH), among others. They excluded cells with high mutational burdens, such as malignant cells and immortalized cell lines, and included only droplet-based sequencing platforms to ensure data comparability. The raw data was then converted into a uniform loom HDF5 file format.
+
+ ## Evaluation Datasets:
+ **Link:** [A cross-disorder dosage sensitivity map of the human genome](https://zenodo.org/records/6347673)
+
+ **Data Collection Method by dataset:**
+ * Human
+
+ **Labeling Method by dataset:**
+ * Not Applicable <!-- there are no labels for this dataset -->
+
+ **Properties:** The data was collected by harmonizing and meta-analyzing rare copy-number variants (rCNVs) from nearly one million individuals across 54 different disorders. This approach created a genome-wide catalog of dosage sensitivity.
+
+ **Link:** [Single-cell Transcriptome Analysis Reveals Dynamic Cell Populations and Differential Gene Expression Patterns in Control and Aneurysmal Human Aortic Tissue](https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE155468)
+
+ **Data Collection Method by dataset:**
+ * Human
+
+ **Labeling Method by dataset:**
+ * Human
+
+ **Properties:** The data was collected by performing single-cell RNA sequencing (scRNA-seq) on human ascending aortic tissues. Tissues were obtained from 11 study participants, consisting of 8 patients with ascending thoracic aortic aneurysm (ATAA) and 3 control subjects.
+
+ **Link:** [Systematic Comparison of High-throughput Single-Cell and Single-Nucleus Transcriptomes during Cardiomyocyte Differentiation](https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE129096)
+
+ **Data Collection Method by dataset:**
+ * Automated
+
+ **Labeling Method by dataset:**
+ * Human
+
+ **Properties:** The researchers used two different sequencing platforms to collect data from the same biological process: induced pluripotent stem cell (iPSC) differentiation into cardiomyocytes. The two platforms used were Drop-seq (single-cell) and DroNc-seq (single-nucleus). The study involved two iPSC lines and collected data over a 15-day time period.
+
+ **Link:** [A human cell atlas of fetal gene expression](https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE156793)
+
+ **Data Collection Method by dataset:**
+ * Human
+
+ **Labeling Method by dataset:**
+ * Hybrid: Human, Automated
+
+ **Properties:** The data was collected by profiling the gene expression of millions of single cells from 15 different human fetal organs.
+
+ **Link:** [Single-nuclei profiling of human dilated and hypertrophic cardiomyopathy](https://singlecell.broadinstitute.org/single_cell/study/SCP1303/single-nuclei-profiling-of-human-dilated-and-hypertrophic-cardiomyopathy#study-summary)
+
+ **Data Collection Method by dataset:**
+ * Human
+
+ **Labeling Method by dataset:**
+ * Hybrid: Human, Automated
+
+ **Properties:** The data was collected by performing single-nucleus RNA sequencing (snRNA-seq) on left ventricle samples from human hearts. The study included samples from 11 hearts with dilated cardiomyopathy, 15 hearts with hypertrophic cardiomyopathy, and 16 non-failing hearts. In total, nearly 600,000 nuclei were sequenced.
+
+ ## Inference:
+ **Acceleration Engine:** Transformer Engine, PyTorch
+
+ **Test Hardware:**
+ * A100
+ * H100
+ * H200
+ * GB200
+
+ ## Ethical Considerations:
+ NVIDIA believes Trustworthy AI is a shared responsibility and we have established policies and practices to enable development for a wide array of AI applications. When downloaded or used in accordance with our terms of service, developers should work with their internal model team to ensure this model meets requirements for the relevant industry and use case and addresses unforeseen product misuse.
+
+ Users are responsible for ensuring the physical properties of model-generated molecules are appropriately evaluated and comply with applicable safety regulations and ethical standards.
+
+ Please report model quality, risk, security vulnerabilities or NVIDIA AI Concerns [here](https://www.nvidia.com/en-us/support/submit-security-vulnerability/).
config.json ADDED
@@ -0,0 +1,35 @@
+ {
+   "_attn_implementation_autoset": true,
+   "architectures": [
+     "BertForMaskedLM"
+   ],
+   "attention_probs_dropout_prob": 0.02,
+   "auto_map": {
+     "AutoConfig": "geneformer.TEBertConfig",
+     "AutoModel": "geneformer.BertModel",
+     "AutoModelForMaskedLM": "geneformer.BertForMaskedLM"
+   },
+   "classifier_dropout": null,
+   "framework": "pytorch",
+   "fuse_qkv_params": true,
+   "gradient_checkpointing": false,
+   "hidden_act": "relu",
+   "hidden_dropout_prob": 0.02,
+   "hidden_size": 256,
+   "initializer_range": 0.02,
+   "intermediate_size": 512,
+   "layer_norm_eps": 1e-12,
+   "max_position_embeddings": 2048,
+   "micro_batch_size": null,
+   "model_type": "bert",
+   "num_attention_heads": 4,
+   "num_hidden_layers": 6,
+   "pad_token_id": 0,
+   "position_embedding_type": "absolute",
+   "torch_dtype": "float32",
+   "transformers_version": "4.51.3",
+   "type_vocab_size": 2,
+   "use_cache": true,
+   "use_te_layers": true,
+   "vocab_size": 25426
+ }
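
The `auto_map` block is what routes the `transformers` auto classes to the custom code in `geneformer.py`. A small sketch of that resolution (repo id taken from the README's release link):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("nvidia/geneformer_V1_10M", trust_remote_code=True)
print(type(config).__name__)  # TEBertConfig, per "AutoConfig": "geneformer.TEBertConfig"
print(config.use_te_layers)   # True -> BertEncoder will build TEBertLayer blocks
```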
geneformer.py ADDED
@@ -0,0 +1,930 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: LicenseRef-Apache2
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # coding=utf-8
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ """PyTorch BERT model with and without Transformer Engine layers.
+
+ This file is a modified version of the BERT model from the Hugging Face Transformers library.
+ It provides a custom BERT encoder and a custom BERT layer, each of which can run with or
+ without Transformer Engine layers.
+ """
+
+ from typing import ClassVar, List, Optional, Tuple, Union
+
+ import torch
+ import transformer_engine.pytorch as te
+ from torch import nn
+ from torch.nn import CrossEntropyLoss
+ from transformers.modeling_attn_mask_utils import (
+     _prepare_4d_attention_mask_for_sdpa,
+     _prepare_4d_causal_attention_mask_for_sdpa,
+ )
+ from transformers.modeling_outputs import (
+     BaseModelOutputWithPastAndCrossAttentions,
+     BaseModelOutputWithPoolingAndCrossAttentions,
+     MaskedLMOutput,
+ )
+ from transformers.models.bert.configuration_bert import BertConfig
+ from transformers.models.bert.modeling_bert import (
+     BertEmbeddings,
+     BertLayer,
+     BertOnlyMLMHead,
+     BertPooler,
+     BertPreTrainedModel,
+ )
+ from transformers.utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+ _CHECKPOINT_FOR_DOC = "google-bert/bert-base-uncased"
+ _CONFIG_FOR_DOC = "BertConfig"
+
+
+ class TEBertConfig(BertConfig):
+     """Configuration class for the TE BERT model.
+
+     This class is a subclass of BertConfig, and it adds the following attributes:
+     - torch_dtype: The dtype of the model parameters.
+     - use_te_layers: Whether to use the TE layers.
+     - micro_batch_size: The micro batch size for TE layers.
+     - fuse_qkv_params: Whether to fuse the QKV projection parameters.
+     """
+
+     def __init__(self, **kwargs):
+         """Initialize the TEBertConfig.
+
+         Args:
+             **kwargs: Additional keyword arguments to pass to BertConfig.
+         """
+         super().__init__(**kwargs)
+         # TODO(@jomitchell): Fix this in JIRA BIONEMO-2406
+         torch_dtype = kwargs.get("torch_dtype", torch.bfloat16)
+         # Convert string dtype to torch dtype if needed
+         if isinstance(torch_dtype, str):
+             if torch_dtype == "bfloat16":
+                 torch_dtype = torch.bfloat16
+             elif torch_dtype == "float16":
+                 torch_dtype = torch.float16
+             elif torch_dtype == "float32":
+                 torch_dtype = torch.float32
+             else:
+                 raise ValueError(f"Unsupported dtype: {torch_dtype}")
+         self.torch_dtype = torch_dtype
+         self.use_te_layers = kwargs.get("use_te_layers", False)
+         self.micro_batch_size = kwargs.get("micro_batch_size", None)
+         self.fuse_qkv_params = kwargs.get("fuse_qkv_params", False)
+
+
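For illustration, a config equivalent to this repo's config.json can be constructed directly. A sketch, assuming this file is importable as `geneformer`; note how the string `torch_dtype` is converted to the torch dtype object by the code above:

```python
import torch

from geneformer import TEBertConfig  # assumes this file is on the import path

config = TEBertConfig(
    vocab_size=25426, hidden_size=256, num_hidden_layers=6, num_attention_heads=4,
    intermediate_size=512, hidden_act="relu", max_position_embeddings=2048,
    torch_dtype="float32", use_te_layers=True, fuse_qkv_params=True,
)
assert config.torch_dtype is torch.float32  # "float32" string was converted
```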
+ class TEBertLayer(nn.Module):
+     """Custom BERT layer using individual TE components for correct post-norm architecture.
+
+     This builds a BERT-style post-norm layer using:
+     - te.MultiheadAttention (with input_layernorm=False)
+     - te.LayerNorm for post-attention normalization, as layernorm
+     - te.Linear for MLP layers (fc1, fc2), wrapped in the layernorm_mlp module
+     - te.LayerNorm for post-MLP normalization, as layernorm_mlp.layer_norm
+
+     Parameter naming matches convert.py expectations for weight loading from HF checkpoints.
+
+     DIVERGENCE FROM TYPICAL TRANSFORMERLAYER:
+     This implementation uses POST-norm architecture, which differs significantly from the
+     typical TransformerLayer that uses PRE-norm.
+
+     Geneformer/HF BERT (POST-norm, output_layernorm=True equivalent):
+         Input -> Attention -> Dropout -> Residual Add -> LayerNorm
+         -> MLP -> Dropout -> Residual Add -> LayerNorm -> Output
+
+     Typical TransformerLayer (PRE-norm, output_layernorm=False default):
+         Input -> [LayerNorm Attn inside MultiheadAttention] -> Dropout -> Residual Add
+         -> [LayerNorm MLP inside LayerNormMLP] -> Dropout -> Residual Add -> Output
+
+     Geneformer applies LayerNorm AFTER residual connections as explicit separate modules,
+     whereas the typical TransformerLayer applies LayerNorm before operations via
+     input_layernorm=True inside the MultiheadAttention and LayerNormMLP modules.
+
+     For more information, see:
+     https://github.com/NVIDIA/TransformerEngine/blob/dd9433e7ad28c12f27da9770be54c9c584e85fa0/transformer_engine/pytorch/transformer.py#L822
+     """
+
+     def __init__(self, config, layer_number=None):
+         """Initialize the TEBertLayer.
+
+         Args:
+             config: Configuration object containing model parameters.
+             layer_number: Optional layer number for identification.
+         """
+         super().__init__()
+
+         self.hidden_size = config.hidden_size
+         self.num_attention_heads = config.num_attention_heads
+         self.layer_number = layer_number
+         self.is_decoder = config.is_decoder
+         self.add_cross_attention = config.add_cross_attention
+
+         # Self-attention using TE MultiheadAttention
+         self.self_attention = te.MultiheadAttention(
+             hidden_size=config.hidden_size,
+             num_attention_heads=config.num_attention_heads,
+             num_gqa_groups=config.num_attention_heads,
+             attention_dropout=config.attention_probs_dropout_prob,
+             input_layernorm=False,  # No LayerNorm before attention
+             attention_type="self",
+             layer_number=layer_number,
+             attn_mask_type="padding",
+             params_dtype=config.torch_dtype,
+             fuse_qkv_params=getattr(config, "fuse_qkv_params", False),
+             window_size=(-1, -1),  # No sliding window attention
+             qkv_format="bshd",  # BERT uses [batch, seq, head, dim]
+         )
+
+         # Post-attention TE LayerNorm
+         self.layernorm = te.LayerNorm(
+             normalized_shape=config.hidden_size,
+             eps=config.layer_norm_eps,
+             params_dtype=config.torch_dtype,
+         )
+
+         # MLP using TE Linear layers
+         self.layernorm_mlp = nn.Module()
+         self.layernorm_mlp.fc1 = te.Linear(
+             config.hidden_size,
+             config.intermediate_size,
+             bias=True,
+             params_dtype=config.torch_dtype,
+         )
+
+         if config.hidden_act != "relu":
+             raise ValueError(f"Geneformer requires hidden_act='relu', got '{config.hidden_act}'")
+         self.layernorm_mlp.activation = nn.ReLU()
+
+         self.layernorm_mlp.fc2 = te.Linear(
+             config.intermediate_size,
+             config.hidden_size,
+             bias=True,
+             params_dtype=config.torch_dtype,
+         )
+
+         # Post-MLP LayerNorm
+         self.layernorm_mlp.layer_norm = te.LayerNorm(
+             normalized_shape=config.hidden_size,
+             eps=config.layer_norm_eps,
+             params_dtype=config.torch_dtype,
+         )
+
+         # Dropout
+         self.attention_dropout = nn.Dropout(config.hidden_dropout_prob)
+         self.mlp_dropout = nn.Dropout(config.hidden_dropout_prob)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         encoder_hidden_states: Optional[torch.FloatTensor] = None,
+         encoder_attention_mask: Optional[torch.FloatTensor] = None,
+         past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+         output_attentions: Optional[bool] = False,
+     ) -> Tuple[torch.Tensor]:
+         """Forward pass through the TE BERT layer.
+
+         Architecture:
+             Input
+             -> Self-Attention
+             -> Dropout
+             -> Residual Connection
+             -> LayerNorm
+             -> MLP
+             -> Dropout
+             -> Residual Connection
+             -> LayerNorm
+             -> Output
+
+         This architecture is the key divergence from the typical TransformerLayer
+         (with output_layernorm=False default), which uses PRE-norm.
+
+         In a PRE-norm TransformerLayer, LayerNorm is applied before operations:
+         - MultiheadAttention with input_layernorm=True applies LayerNorm internally before attention
+         - LayerNormMLP applies LayerNorm internally before MLP
+         - Residuals bypass these internal LayerNorms
+
+         In Geneformer's POST-norm, LayerNorm is applied after residual connections as explicit
+         separate modules, meaning the normalized output flows to the next layer.
+
+         Args:
+             hidden_states: Input hidden states.
+             attention_mask: Attention mask.
+             head_mask: Head mask.
+             encoder_hidden_states: Encoder hidden states.
+             encoder_attention_mask: Encoder attention mask.
+             past_key_value: Past key value.
+             output_attentions: Whether to output attentions.
+
+         Returns:
+             Tuple of tensors containing the layer output.
+         """
+         # Attention mask handling for TE MultiheadAttention: [batch, 1, 1, seq_len], True=masked, False=attend
+         te_attention_mask = None
+         te_mask_type = "no_mask"
+
+         if attention_mask is not None:
+             # Check if there's actual padding (not all 1s for 2D or not all 0s for 4D)
+             if attention_mask.dim() == 2:
+                 # Standard [batch, seq_len] where 1=attend, 0=masked
+                 has_padding = not torch.all(attention_mask == 1)
+                 if has_padding:
+                     # Convert to TE format: [batch, 1, 1, seq_len], invert polarity
+                     te_attention_mask = ~attention_mask.bool().unsqueeze(1).unsqueeze(1)
+                     te_mask_type = "padding"
+             elif attention_mask.dim() in [3, 4]:
+                 # Extended mask with -inf for masked positions
+
+                 has_masking = torch.any(
+                     attention_mask < -10000.0
+                 )  # Check if it's not a trivial mask (all zeros/no masking)
+                 if has_masking:
+                     # Extract padding mask and convert to TE format
+                     if attention_mask.dim() == 4:
+                         padding_mask = attention_mask[:, 0, 0, :]  # [batch, seq_len]
+                     else:  # dim == 3
+                         padding_mask = attention_mask[:, 0, :]  # [batch, seq_len]
+                     # -inf to True (masked), 0 to False (attend)
+                     # Then reshape to [batch, 1, 1, seq_len]
+                     te_attention_mask = (padding_mask < -10000.0).unsqueeze(1).unsqueeze(1)
+                     te_mask_type = "padding"
+
+         # Self-Attention sub-layer
+         attention_output = self.self_attention(
+             hidden_states,
+             attention_mask=te_attention_mask,
+             attn_mask_type=te_mask_type,
+         )
+
+         # Residual connection + dropout + LayerNorm (POST-norm)
+         attention_output = self.attention_dropout(attention_output)
+         hidden_states = hidden_states + attention_output
+         hidden_states = self.layernorm(hidden_states)
+
+         # MLP sub-layer
+         mlp_output = self.layernorm_mlp.fc1(hidden_states)
+         mlp_output = self.layernorm_mlp.activation(mlp_output)
+         mlp_output = self.layernorm_mlp.fc2(mlp_output)
+
+         # Residual connection + dropout + LayerNorm (POST-norm)
+         mlp_output = self.mlp_dropout(mlp_output)
+         hidden_states = hidden_states + mlp_output
+         hidden_states = self.layernorm_mlp.layer_norm(hidden_states)
+
+         return (hidden_states,)
+
+
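A worked example of the 2D mask conversion in `forward` above, with toy values; this is the same inversion and reshaping the layer applies before calling `te.MultiheadAttention`:

```python
import torch

# HF-style 2D mask: 1 = attend, 0 = padding
hf_mask = torch.tensor([[1, 1, 1, 0]])

# TE-style mask: [batch, 1, 1, seq_len] with True = masked
te_mask = ~hf_mask.bool().unsqueeze(1).unsqueeze(1)
print(te_mask)        # tensor([[[[False, False, False,  True]]]])
print(te_mask.shape)  # torch.Size([1, 1, 1, 4])
```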
+ class BertEncoder(nn.Module):
+     """BERT encoder that stacks either TE layers (TEBertLayer) or standard HF layers (BertLayer)."""
+
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         if self.config.use_te_layers:
+             self.layer = nn.ModuleList(
+                 [TEBertLayer(config, layer_number=i + 1) for i in range(config.num_hidden_layers)]
+             )
+         else:
+             self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
+         self.gradient_checkpointing = False
+
+     def _process_layer_outputs(
+         self,
+         layer_outputs,
+         hidden_states,
+         all_hidden_states,
+         all_self_attentions,
+         all_cross_attentions,
+         output_hidden_states,
+         output_attentions,
+         use_cache,
+         next_decoder_cache,
+     ):
+         """Process outputs from a single layer."""
+         hidden_states = layer_outputs[0]
+
+         if use_cache and next_decoder_cache is not None:
+             next_decoder_cache = (*next_decoder_cache, layer_outputs[-1])
+
+         if output_attentions and len(layer_outputs) > 1:
+             if all_self_attentions is None:
+                 all_self_attentions = (layer_outputs[1],)
+             else:
+                 all_self_attentions = (*all_self_attentions, layer_outputs[1])
+             if self.config.add_cross_attention and len(layer_outputs) > 2:
+                 if all_cross_attentions is None:
+                     all_cross_attentions = (layer_outputs[2],)
+                 else:
+                     all_cross_attentions = (*all_cross_attentions, layer_outputs[2])
+
+         return hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions, next_decoder_cache
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         encoder_hidden_states: Optional[torch.FloatTensor] = None,
+         encoder_attention_mask: Optional[torch.FloatTensor] = None,
+         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = False,
+         output_hidden_states: Optional[bool] = False,
+         return_dict: Optional[bool] = True,
+     ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
+         """Run the stacked layers, optionally collecting hidden states and attentions."""
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attentions = () if output_attentions else None
+         all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+
+         if self.gradient_checkpointing and self.training:
+             if use_cache:
+                 logger.warning(
+                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                 )
+                 use_cache = False
+
+         next_decoder_cache = () if use_cache else None
+         for i, layer_module in enumerate(self.layer):
+             if output_hidden_states:
+                 if all_hidden_states is None:
+                     all_hidden_states = (hidden_states,)
+                 else:
+                     all_hidden_states = (*all_hidden_states, hidden_states)
+
+             layer_head_mask = head_mask[i] if head_mask is not None else None
+             past_key_value = past_key_values[i] if past_key_values is not None else None
+
+             if self.gradient_checkpointing and self.training:
+                 from torch.utils.checkpoint import checkpoint
+
+                 layer_outputs = checkpoint(
+                     layer_module,
+                     hidden_states,
+                     attention_mask,
+                     layer_head_mask,
+                     encoder_hidden_states,
+                     encoder_attention_mask,
+                     past_key_value,
+                     output_attentions,
+                     use_reentrant=False,
+                 )
+             else:
+                 layer_outputs = layer_module(
+                     hidden_states,
+                     attention_mask,
+                     layer_head_mask,
+                     encoder_hidden_states,
+                     encoder_attention_mask,
+                     past_key_value,
+                     output_attentions,
+                 )
+
+             hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions, next_decoder_cache = (
+                 self._process_layer_outputs(
+                     layer_outputs,
+                     hidden_states,
+                     all_hidden_states,
+                     all_self_attentions,
+                     all_cross_attentions,
+                     output_hidden_states,
+                     output_attentions,
+                     use_cache,
+                     next_decoder_cache,
+                 )
+             )
+
+         if output_hidden_states:
+             if all_hidden_states is None:
+                 all_hidden_states = (hidden_states,)
+             else:
+                 all_hidden_states = (*all_hidden_states, hidden_states)
+
+         if not return_dict:
+             return tuple(
+                 v
+                 for v in [
+                     hidden_states,
+                     next_decoder_cache,
+                     all_hidden_states,
+                     all_self_attentions,
+                     all_cross_attentions,
+                 ]
+                 if v is not None
+             )
+         return BaseModelOutputWithPastAndCrossAttentions(
+             last_hidden_state=hidden_states,
+             past_key_values=next_decoder_cache,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attentions,
+             cross_attentions=all_cross_attentions,
+         )
+
+
+ class BertModel(BertPreTrainedModel):
+     """BERT model for encoding and decoding.
+
+     The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+     cross-attention is added between the self-attention layers, following the architecture described in [Attention is
+     all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+     Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+
+     To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
+     to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument and
+     `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
+     """
+
+     config_class = TEBertConfig
+
+     # TODO(@jomitchell) Can start swapping layers here for TE layers.
+     _no_split_modules: ClassVar[List[str]] = ["BertEmbeddings", "BertLayer", "TEBertLayer"]
+
+     def __init__(self, config, add_pooling_layer=True):
+         """Initialize the BertModel.
+
+         Args:
+             config: Configuration object containing model parameters.
+             add_pooling_layer: Whether to add a pooling layer on top of the encoder.
+         """
+         super().__init__(config)
+         self.config = config
+
+         self.embeddings = BertEmbeddings(config)
+         self.encoder = BertEncoder(config)
+
+         self.pooler = BertPooler(config) if add_pooling_layer else None
+
+         self.attn_implementation = config._attn_implementation
+         self.position_embedding_type = config.position_embedding_type
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self):
+         """Get the input embeddings."""
+         return self.embeddings.word_embeddings
+
+     def set_input_embeddings(self, value):
+         """Set the input embeddings."""
+         self.embeddings.word_embeddings = value
+
+     def _prune_heads(self, heads_to_prune):
+         """Prunes heads of the model.
+
+         heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See base
+         class PreTrainedModel.
+         """
+         for layer, heads in heads_to_prune.items():
+             self.encoder.layer[layer].attention.prune_heads(heads)
+
+     def _validate_and_prepare_inputs(
+         self,
+         input_ids,
+         inputs_embeds,
+         attention_mask,
+         token_type_ids,
+         position_ids,
+         past_key_values,
+     ):
+         """Validate inputs and prepare basic input data."""
+         if input_ids is not None and inputs_embeds is not None:
+             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+         elif input_ids is not None:
+             self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+             input_shape = input_ids.size()
+         elif inputs_embeds is not None:
+             input_shape = inputs_embeds.size()[:-1]
+         else:
+             raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+         batch_size, seq_length = input_shape
+         device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+         # past_key_values_length
+         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+         if token_type_ids is None:
+             if hasattr(self.embeddings, "token_type_ids"):
+                 buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+                 buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+                 token_type_ids = buffered_token_type_ids_expanded
+             else:
+                 token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+         embedding_output = self.embeddings(
+             input_ids=input_ids,
+             position_ids=position_ids,
+             token_type_ids=token_type_ids,
+             inputs_embeds=inputs_embeds,
+             past_key_values_length=past_key_values_length,
+         )
+
+         if attention_mask is None:
+             attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
+
+         return (
+             input_shape,
+             batch_size,
+             seq_length,
+             device,
+             past_key_values_length,
+             token_type_ids,
+             embedding_output,
+             attention_mask,
+         )
+
+     def _prepare_attention_masks(
+         self,
+         attention_mask,
+         input_shape,
+         embedding_output,
+         past_key_values_length,
+         seq_length,
+         device,
+         head_mask,
+         output_attentions,
+         encoder_hidden_states,
+         encoder_attention_mask,
+     ):
+         """Prepare attention masks for the forward pass."""
+         use_sdpa_attention_masks = (
+             self.attn_implementation == "sdpa"
+             and self.position_embedding_type == "absolute"
+             and head_mask is None
+             and not output_attentions
+         )
+
+         # Expand the attention mask
+         if use_sdpa_attention_masks and attention_mask.dim() == 2:
+             # Expand the attention mask for SDPA.
+             # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
+             if self.config.is_decoder:
+                 extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                     attention_mask,
+                     input_shape,
+                     embedding_output,
+                     past_key_values_length,
+                 )
+             else:
+                 extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+                     attention_mask, embedding_output.dtype, tgt_len=seq_length
+                 )
+         else:
+             # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+             # ourselves in which case we just need to make it broadcastable to all heads.
+             extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
+
+         # If a 2D or 3D attention mask is provided for the cross-attention
+         # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
+         if self.config.is_decoder and encoder_hidden_states is not None:
+             encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+             encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+             if encoder_attention_mask is None:
+                 encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+
+             if use_sdpa_attention_masks and encoder_attention_mask.dim() == 2:
+                 # Expand the attention mask for SDPA.
+                 # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
+                 encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+                     encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length
+                 )
+             else:
+                 encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+         else:
+             encoder_extended_attention_mask = None
+
+         return extended_attention_mask, encoder_extended_attention_mask
+
+     def _prepare_inputs_and_masks(
+         self,
+         input_ids,
+         inputs_embeds,
+         attention_mask,
+         token_type_ids,
+         position_ids,
+         head_mask,
+         past_key_values,
+         encoder_hidden_states,
+         encoder_attention_mask,
+         output_attentions,
+         output_hidden_states,
+         use_cache,
+         return_dict,
+     ):
+         """Prepare inputs and attention masks for the forward pass."""
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         if self.config.is_decoder:
+             use_cache = use_cache if use_cache is not None else self.config.use_cache
+         else:
+             use_cache = False
+
+         (
+             input_shape,
+             batch_size,
+             seq_length,
+             device,
+             past_key_values_length,
+             token_type_ids,
+             embedding_output,
+             attention_mask,
+         ) = self._validate_and_prepare_inputs(
+             input_ids,
+             inputs_embeds,
+             attention_mask,
+             token_type_ids,
+             position_ids,
+             past_key_values,
+         )
+
+         extended_attention_mask, encoder_extended_attention_mask = self._prepare_attention_masks(
+             attention_mask,
+             input_shape,
+             embedding_output,
+             past_key_values_length,
+             seq_length,
+             device,
+             head_mask,
+             output_attentions,
+             encoder_hidden_states,
+             encoder_attention_mask,
+         )
+
+         # Prepare head mask if needed
+         # 1.0 in head_mask indicates we keep the head
+         # attention_probs has shape bsz x n_heads x N x N
+         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+         processed_head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+         return (
+             embedding_output,
+             extended_attention_mask,
+             processed_head_mask,
+             encoder_extended_attention_mask,
+             use_cache,
+             return_dict,
+         )
+
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         encoder_attention_mask: Optional[torch.Tensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+         r"""Forward pass of the BertModel.
+
+         Args:
+             input_ids (`torch.Tensor`, *optional*): Input token IDs.
+             attention_mask (`torch.Tensor`, *optional*): Attention mask.
+             token_type_ids (`torch.Tensor`, *optional*): Token type IDs.
+             position_ids (`torch.Tensor`, *optional*): Position IDs.
+             head_mask (`torch.Tensor`, *optional*): Head mask.
+             inputs_embeds (`torch.Tensor`, *optional*): Input embeddings.
+             encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                 Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+                 the model is configured as a decoder.
+             encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, target_length)`, *optional*):
+                 Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+                 the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+                 - 1 for tokens that are **not masked**,
+                 - 0 for tokens that are **masked**.
+             past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+                 Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding (see
+                 `past_key_values`).
+             use_cache (`bool`, *optional*):
+                 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+                 `past_key_values`).
+             output_attentions (`bool`, *optional*): Whether to output attentions.
+             output_hidden_states (`bool`, *optional*): Whether to output hidden states.
+             return_dict (`bool`, *optional*): Whether to return a ModelOutput instead of a tuple.
+         """
+         (
+             embedding_output,
+             extended_attention_mask,
+             processed_head_mask,
+             encoder_extended_attention_mask,
+             use_cache,
+             return_dict,
+         ) = self._prepare_inputs_and_masks(
+             input_ids,
+             inputs_embeds,
+             attention_mask,
+             token_type_ids,
+             position_ids,
+             head_mask,
+             past_key_values,
+             encoder_hidden_states,
+             encoder_attention_mask,
+             output_attentions,
+             output_hidden_states,
+             use_cache,
+             return_dict,
+         )
+
+         encoder_outputs = self.encoder(
+             embedding_output,
+             attention_mask=extended_attention_mask,
+             head_mask=processed_head_mask,
+             encoder_hidden_states=encoder_hidden_states,
+             encoder_attention_mask=encoder_extended_attention_mask,
+             past_key_values=past_key_values,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         sequence_output = encoder_outputs[0]
+         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+         if not return_dict:
+             return (sequence_output, pooled_output, *encoder_outputs[1:])
+
+         return BaseModelOutputWithPoolingAndCrossAttentions(
+             last_hidden_state=sequence_output,
+             pooler_output=pooled_output,
+             past_key_values=encoder_outputs.past_key_values,
+             hidden_states=encoder_outputs.hidden_states,
+             attentions=encoder_outputs.attentions,
+             cross_attentions=encoder_outputs.cross_attentions,
+         )
+
+
+ class BertForMaskedLM(BertPreTrainedModel):
+     """BERT model for masked language modeling."""
+
+     config_class = TEBertConfig
+     _tied_weights_keys: ClassVar[List[str]] = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
+
+     def __init__(self, config):
+         """Initialize the BertForMaskedLM.
+
+         Args:
+             config: Configuration object containing model parameters.
+         """
+         super().__init__(config)
+
+         if config.is_decoder:
+             logger.warning(
+                 "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
+                 "bi-directional self-attention."
+             )
+
+         self.bert = BertModel(config, add_pooling_layer=False)
+         self.cls = BertOnlyMLMHead(config)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_output_embeddings(self):
+         """Get the output embeddings."""
+         return self.cls.predictions.decoder
+
+     def set_output_embeddings(self, new_embeddings):
+         """Set the output embeddings."""
+         self.cls.predictions.decoder = new_embeddings
+         self.cls.predictions.bias = new_embeddings.bias
+
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         encoder_attention_mask: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
+         r"""Forward pass for masked language modeling.
+
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+             config.vocab_size]` (see the `input_ids` docstring). Tokens with indices set to `-100` are ignored
+             (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.bert(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             encoder_hidden_states=encoder_hidden_states,
+             encoder_attention_mask=encoder_attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         sequence_output = outputs[0]
+         prediction_scores = self.cls(sequence_output)
+
+         masked_lm_loss = None
+         if labels is not None:
+             loss_fct = CrossEntropyLoss()  # -100 index = padding token
+             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+         if not return_dict:
+             output = (prediction_scores, *outputs[2:])
+             return (masked_lm_loss, *output) if masked_lm_loss is not None else output
+
+         return MaskedLMOutput(
+             loss=masked_lm_loss,
+             logits=prediction_scores,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+     def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
+         """Prepare inputs for generation."""
+         input_shape = input_ids.shape
+         effective_batch_size = input_shape[0]
+
+         # add a dummy token
+         if self.config.pad_token_id is None:
+             raise ValueError("The PAD token should be defined for generation")
+
+         attention_mask = torch.cat(
+             [attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))],
+             dim=-1,
+         )
+         dummy_token = torch.full(
+             (effective_batch_size, 1),
+             self.config.pad_token_id,
+             dtype=torch.long,
+             device=input_ids.device,
+         )
+         input_ids = torch.cat([input_ids, dummy_token], dim=1)
+
+         return {"input_ids": input_ids, "attention_mask": attention_mask}
+
+     @classmethod
+     def can_generate(cls) -> bool:
+         """Legacy correction: BertForMaskedLM can't call `generate()` from `GenerationMixin`,
+         even though it has a `prepare_inputs_for_generation` method.
+         """
+         return False
+
+
+ __all__ = [
+     "BertForMaskedLM",
+     "BertLayer",
+     "BertModel",
+     "TEBertLayer",
+ ]
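
Tying the classes together, a sketch of the masked-language-modeling path with a randomly initialized model. It assumes a CUDA GPU with `transformer_engine` installed and that this file is importable as `geneformer`; `-100` labels are excluded from the loss, as described in the forward docstring:

```python
import torch

from geneformer import BertForMaskedLM, TEBertConfig

config = TEBertConfig(
    vocab_size=25426, hidden_size=256, num_hidden_layers=6, num_attention_heads=4,
    intermediate_size=512, hidden_act="relu", max_position_embeddings=2048,
    torch_dtype="float32", use_te_layers=True,
)
model = BertForMaskedLM(config).cuda().eval()  # TE layers expect a CUDA device

input_ids = torch.randint(1, config.vocab_size, (2, 16), device="cuda")
labels = input_ids.clone()
labels[:, :-2] = -100  # score only the last two positions of each sequence

with torch.no_grad():
    out = model(input_ids=input_ids, labels=labels)
print(out.loss, out.logits.shape)  # scalar loss, torch.Size([2, 16, 25426])
```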
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d0dfca398accc3b54ce8d9d1574bf92f393a9289426563cd68fdd2666dd74f09
+ size 67302248