upload TE checkpoint
Browse files- LICENSE +178 -0
- README.md +167 -3
- config.json +35 -0
- geneformer.py +930 -0
- model.safetensors +3 -0
LICENSE
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
Apache License
|
| 3 |
+
Version 2.0, January 2004
|
| 4 |
+
http://www.apache.org/licenses/
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 12 |
+
the copyright owner that is granting the License.
|
| 13 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 14 |
+
other entities that control, are controlled by, or are under common
|
| 15 |
+
control with that entity. For the purposes of this definition,
|
| 16 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 17 |
+
direction or management of such entity, whether by contract or
|
| 18 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 19 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 20 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 21 |
+
exercising permissions granted by this License.
|
| 22 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 23 |
+
including but not limited to software source code, documentation
|
| 24 |
+
source, and configuration files.
|
| 25 |
+
"Object" form shall mean any form resulting from mechanical
|
| 26 |
+
transformation or translation of a Source form, including but
|
| 27 |
+
not limited to compiled object code, generated documentation,
|
| 28 |
+
and conversions to other media types.
|
| 29 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 30 |
+
Object form, made available under the License, as indicated by a
|
| 31 |
+
copyright notice that is included in or attached to the work
|
| 32 |
+
(an example is provided in the Appendix below).
|
| 33 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 34 |
+
form, that is based on (or derived from) the Work and for which the
|
| 35 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 36 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 37 |
+
of this License, Derivative Works shall not include works that remain
|
| 38 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 39 |
+
the Work and Derivative Works thereof.
|
| 40 |
+
"Contribution" shall mean any work of authorship, including
|
| 41 |
+
the original version of the Work and any modifications or additions
|
| 42 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 43 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 44 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 45 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 46 |
+
means any form of electronic, verbal, or written communication sent
|
| 47 |
+
to the Licensor or its representatives, including but not limited to
|
| 48 |
+
communication on electronic mailing lists, source code control systems,
|
| 49 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 50 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 51 |
+
excluding communication that is conspicuously marked or otherwise
|
| 52 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 53 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 54 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 55 |
+
subsequently incorporated within the Work.
|
| 56 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 57 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 58 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 59 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 60 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 61 |
+
Work and such Derivative Works in Source or Object form.
|
| 62 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 63 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 64 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 65 |
+
(except as stated in this section) patent license to make, have made,
|
| 66 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 67 |
+
where such license applies only to those patent claims licensable
|
| 68 |
+
by such Contributor that are necessarily infringed by their
|
| 69 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 70 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 71 |
+
institute patent litigation against any entity (including a
|
| 72 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 73 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 74 |
+
or contributory patent infringement, then any patent licenses
|
| 75 |
+
granted to You under this License for that Work shall terminate
|
| 76 |
+
as of the date such litigation is filed.
|
| 77 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 78 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 79 |
+
modifications, and in Source or Object form, provided that You
|
| 80 |
+
meet the following conditions:
|
| 81 |
+
(a) You must give any other recipients of the Work or
|
| 82 |
+
Derivative Works a copy of this License; and
|
| 83 |
+
(b) You must cause any modified files to carry prominent notices
|
| 84 |
+
stating that You changed the files; and
|
| 85 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 86 |
+
that You distribute, all copyright, patent, trademark, and
|
| 87 |
+
attribution notices from the Source form of the Work,
|
| 88 |
+
excluding those notices that do not pertain to any part of
|
| 89 |
+
the Derivative Works; and
|
| 90 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 91 |
+
distribution, then any Derivative Works that You distribute must
|
| 92 |
+
include a readable copy of the attribution notices contained
|
| 93 |
+
within such NOTICE file, excluding those notices that do not
|
| 94 |
+
pertain to any part of the Derivative Works, in at least one
|
| 95 |
+
of the following places: within a NOTICE text file distributed
|
| 96 |
+
as part of the Derivative Works; within the Source form or
|
| 97 |
+
documentation, if provided along with the Derivative Works; or,
|
| 98 |
+
within a display generated by the Derivative Works, if and
|
| 99 |
+
wherever such third-party notices normally appear. The contents
|
| 100 |
+
of the NOTICE file are for informational purposes only and
|
| 101 |
+
do not modify the License. You may add Your own attribution
|
| 102 |
+
notices within Derivative Works that You distribute, alongside
|
| 103 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 104 |
+
that such additional attribution notices cannot be construed
|
| 105 |
+
as modifying the License.
|
| 106 |
+
You may add Your own copyright statement to Your modifications and
|
| 107 |
+
may provide additional or different license terms and conditions
|
| 108 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 109 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 110 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 111 |
+
the conditions stated in this License.
|
| 112 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 113 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 114 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 115 |
+
this License, without any additional terms or conditions.
|
| 116 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 117 |
+
the terms of any separate license agreement you may have executed
|
| 118 |
+
with Licensor regarding such Contributions.
|
| 119 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 120 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 121 |
+
except as required for reasonable and customary use in describing the
|
| 122 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 123 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 124 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 125 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 126 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 127 |
+
implied, including, without limitation, any warranties or conditions
|
| 128 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 129 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 130 |
+
appropriateness of using or redistributing the Work and assume any
|
| 131 |
+
risks associated with Your exercise of permissions under this License.
|
| 132 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 133 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 134 |
+
unless required by applicable law (such as deliberate and grossly
|
| 135 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 136 |
+
liable to You for damages, including any direct, indirect, special,
|
| 137 |
+
incidental, or consequential damages of any character arising as a
|
| 138 |
+
result of this License or out of the use or inability to use the
|
| 139 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 140 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 141 |
+
other commercial damages or losses), even if such Contributor
|
| 142 |
+
has been advised of the possibility of such damages.
|
| 143 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 144 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 145 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 146 |
+
or other liability obligations and/or rights consistent with this
|
| 147 |
+
License. However, in accepting such obligations, You may act only
|
| 148 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 149 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 150 |
+
defend, and hold each Contributor harmless for any liability
|
| 151 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 152 |
+
of your accepting any such warranty or additional liability.
|
| 153 |
+
END OF TERMS AND CONDITIONS
|
| 154 |
+
|
| 155 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 156 |
+
|
| 157 |
+
To apply the Apache License to your work, attach the following
|
| 158 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 159 |
+
replaced with your own identifying information. (Don't include
|
| 160 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 161 |
+
comment syntax for the file format. We also recommend that a
|
| 162 |
+
file or class name and description of purpose be included on the
|
| 163 |
+
same "printed page" as the copyright notice for easier
|
| 164 |
+
identification within third-party archives.
|
| 165 |
+
|
| 166 |
+
Copyright 2022 Theodoris Lab, Gladstone Institute and The HuggingFace Inc. team. All rights reserved.
|
| 167 |
+
Copyright 2025 NVIDIA CORPORATION. All rights reserved.
|
| 168 |
+
|
| 169 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 170 |
+
you may not use this file except in compliance with the License.
|
| 171 |
+
You may obtain a copy of the License at
|
| 172 |
+
|
| 173 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 174 |
+
Unless required by applicable law or agreed to in writing, software
|
| 175 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 176 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 177 |
+
See the License for the specific language governing permissions and
|
| 178 |
+
limitations under the License.
|
README.md
CHANGED
|
@@ -1,3 +1,167 @@
|
|
| 1 |
-
---
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
datasets: ctheodoris/Genecorpus-30M
|
| 3 |
+
library_name: transformers
|
| 4 |
+
license: apache-2.0
|
| 5 |
+
tags:
|
| 6 |
+
- single-cell
|
| 7 |
+
- genomics
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
# Geneformer-10M (TransformerEngine-Optimized) Overview
|
| 11 |
+
|
| 12 |
+
## Description:
|
| 13 |
+
Geneformer is a foundational transformer model pretrained on a large-scale corpus of single-cell transcriptomes to enable context-specific predictions in settings with limited data in network biology.
|
| 14 |
+
|
| 15 |
+
This version of the Geneformer model is optimized with NVIDIA's [TransformerEngine](https://github.com/NVIDIA/TransformerEngine) library. It is based on the original Geneformer V1 model, and (within numerical precision) has identical weights and outputs.
|
| 16 |
+
|
| 17 |
+
This model is ready for commercial/non-commercial use.
|
| 18 |
+
|
| 19 |
+
## Third-Party Community Consideration
|
| 20 |
+
This model is not owned or developed by NVIDIA. This model has been developed and built to a third-party's requirements for this application and use case; see link to Non-NVIDIA Model Card [Geneformer Model Card](https://huggingface.co/ctheodoris/Geneformer).
|
| 21 |
+
|
| 22 |
+
### License/Terms of Use:
|
| 23 |
+
Geneformer is licensed under the [Apache 2.0 license](https://huggingface.co/datasets/choosealicense/licenses/blob/main/markdown/apache-2.0.md).
|
| 24 |
+
|
| 25 |
+
### Deployment Geography:
|
| 26 |
+
Global
|
| 27 |
+
|
| 28 |
+
### Use Case:
|
| 29 |
+
Network biology and therapeutic discovery, particularly in data-limited settings such as rare diseases or diseases affecting hard-to-access tissues.
|
| 30 |
+
|
| 31 |
+
### Release Date:
|
| 32 |
+
Hugging Face 12/19/2025 via [https://huggingface.co/nvidia/geneformer_V1_10M](https://huggingface.co/nvidia/geneformer_V1_10M)
|
| 33 |
+
|
| 34 |
+
## Reference(s):
|
| 35 |
+
* [Transfer learning enables predictions in network biology](https://www.nature.com/articles/s41586-023-06139-9.epdf?sharing_token=u_5LUGVkd3A8zR-f73lU59RgN0jAjWel9jnR3ZoTv0N2UB4yyXENUK50s6uqjXH69sDxh4Z3J4plYCKlVME-W2WSuRiS96vx6t5ex2-krVDS46JkoVvAvJyWtYXIyj74pDWn_DutZq1oAlDaxfvBpUfSKDdBPJ8SKlTId8uT47M%3D) - details of the original model trained on ~30 million transcriptomes in June 2021 and the initial report of the in silico perturbation and cell and gene classification strategies.
|
| 36 |
+
* [Quantized multi-task learning for context-specific representations of gene network dynamics](https://www.biorxiv.org/content/10.1101/2024.08.16.608180v1.full.pdf) - the expanded model, trained on ~104 million transcriptomes, and continual learning, multitask learning, and quantization strategies.
|
| 37 |
+
* See [geneformer.readthedocs.io](https://geneformer.readthedocs.io/) for documentation.
|
| 38 |
+
|
| 39 |
+
## Model Architecture:
|
| 40 |
+
**Architecture Type:** Transformer
|
| 41 |
+
**Network Architecture:** BERT
|
| 42 |
+
|
| 43 |
+
**This model was developed based on:** [Geneformer](https://huggingface.co/ctheodoris/Geneformer) <br>
|
| 44 |
+
**Number of model parameters:** 1 x 10^7
|
| 45 |
+
|
| 46 |
+
## Input:
|
| 47 |
+
**Input Type:** Number (Row represents cell, containing gene names and single cell expression counts) <br>
|
| 48 |
+
**Input Format:** Array [AnnData](https://anndata.readthedocs.io/en/latest/) <br>
|
| 49 |
+
**Input Parameters:** One-Dimensional (1D) <br>
|
| 50 |
+
**Other Properties Related to Input:** This model supports a context length of 2048.
|
| 51 |
+
|
| 52 |
+
## Output:
|
| 53 |
+
**Output Type:** Dense Embedding Predictions <br>
|
| 54 |
+
**Output Format:** Vector <br>
|
| 55 |
+
**Output Parameters:** One-Dimensional (1D) <br>
|
| 56 |
+
**Other Properties Related to Output:** Numeric floating point vector (fp16, bf16, or fp32); Geneformer-10M outputs 256 dimensional embeddings.
|
| 57 |
+
|
| 58 |
+
Our AI models are designed and/or optimized to run on NVIDIA GPU-accelerated systems. By leveraging NVIDIA’s hardware (e.g. GPU cores) and software frameworks (e.g., CUDA libraries), the model achieves faster training and inference times compared to CPU-only solutions.
|
| 59 |
+
|
| 60 |
+
## Software Integration:
|
| 61 |
+
**Runtime Engine(s):**
|
| 62 |
+
* Transformer Engine
|
| 63 |
+
* PyTorch
|
| 64 |
+
|
| 65 |
+
**Supported Hardware Microarchitecture Compatibility:**
|
| 66 |
+
* A100
|
| 67 |
+
* H100
|
| 68 |
+
* H200
|
| 69 |
+
* GB200
|
| 70 |
+
|
| 71 |
+
**Preferred/Supported Operating System(s):**
|
| 72 |
+
* Linux
|
| 73 |
+
|
| 74 |
+
The integration of foundation and fine-tuned models into AI systems requires additional testing using use-case-specific data to ensure safe and effective deployment. Following the V-model methodology, iterative testing and validation at both unit and system levels are essential to mitigate risks, meet technical and functional requirements, and ensure compliance with safety and ethical standards before deployment.
|
| 75 |
+
|
| 76 |
+
## Model Version(s):
|
| 77 |
+
* Geneformer-V1-10M
|
| 78 |
+
* Geneformer-V2-104M
|
| 79 |
+
* Geneformer-V2-316M
|
| 80 |
+
* Geneformer-V2-104M_CLcancer
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
## Training and Evaluation Datasets:
|
| 84 |
+
|
| 85 |
+
## Training Datasets:
|
| 86 |
+
**Link:** [Genecorpus-30M](https://huggingface.co/datasets/ctheodoris/Genecorpus-30M)
|
| 87 |
+
|
| 88 |
+
**Data Modality:**
|
| 89 |
+
* Text (Human single-cell transcriptomes)
|
| 90 |
+
|
| 91 |
+
**Text Training Data Size:**
|
| 92 |
+
* 1 Billion to 10 Trillion Tokens
|
| 93 |
+
|
| 94 |
+
**Data Collection Method by dataset:**
|
| 95 |
+
* Human
|
| 96 |
+
|
| 97 |
+
**Labeling Method by dataset:**
|
| 98 |
+
* N/A
|
| 99 |
+
|
| 100 |
+
**Properties:** The single-cell transcriptomes were assembled from a broad range of publicly available data sources. The researchers collected raw counts from sources like NCBI Gene Expression Omnibus (GEO), Human Cell Atlas, and Tumor Immune Single-cell Hub (TISCH), among others. They excluded cells with high mutational burdens, such as malignant cells and immortalized cell lines, and included only droplet-based sequencing platforms to ensure data comparability. The raw data was then converted into a uniform loom HDF5 file format.
|
| 101 |
+
|
| 102 |
+
## Evaluation Datasets:
|
| 103 |
+
**Link:** [A cross-disorder dosage sensitivity map of the human genome](https://zenodo.org/records/6347673)
|
| 104 |
+
|
| 105 |
+
**Data Collection Method by dataset:**
|
| 106 |
+
* Human
|
| 107 |
+
|
| 108 |
+
**Labeling Method by dataset:**
|
| 109 |
+
* Not Applicable <!-- there are no labels for this dataset -->
|
| 110 |
+
|
| 111 |
+
**Properties:** The data was collected by harmonizing and meta-analyzing rare copy-number variants (rCNVs) from nearly one million individuals across 54 different disorders. This approach created a genome-wide catalog of dosage sensitivity.
|
| 112 |
+
|
| 113 |
+
**Link:** [Single-cell Transcriptome Analysis Reveals Dynamic Cell Populations and Differential Gene Expression Patterns in Control and Aneurysmal Human Aortic Tissue](https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE155468)
|
| 114 |
+
|
| 115 |
+
**Data Collection Method by dataset:**
|
| 116 |
+
* Human
|
| 117 |
+
|
| 118 |
+
**Labeling Method by dataset:**
|
| 119 |
+
* Human
|
| 120 |
+
|
| 121 |
+
**Properties:** The data was collected by performing single-cell RNA sequencing (scRNA-seq) on human ascending aortic tissues. Tissues were obtained from 11 study participants, consisting of 8 patients with ascending thoracic aortic aneurysm (ATAA) and 3 control subjects.
|
| 122 |
+
|
| 123 |
+
**Link:** [Systematic Comparison of High-throughput Single-Cell and Single-Nucleus Transcriptomes during Cardiomyocyte Differentiation](https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE129096)
|
| 124 |
+
|
| 125 |
+
**Data Collection Method by dataset:**
|
| 126 |
+
* Automated
|
| 127 |
+
|
| 128 |
+
**Labeling Method by dataset:**
|
| 129 |
+
* Human
|
| 130 |
+
|
| 131 |
+
**Properties:** The researchers used two different sequencing platforms to collect data from the same biological process: induced pluripotent stem cell (iPSC) differentiation into cardiomyocytes. The two platforms used were Drop-seq (single-cell) and DroNc-seq (single-nucleus). The study involved two iPSC lines and collected data over a 15-day time period.
|
| 132 |
+
|
| 133 |
+
**Link:** [A human cell atlas of fetal gene expression](https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE156793)
|
| 134 |
+
|
| 135 |
+
**Data Collection Method by dataset:**
|
| 136 |
+
* Human
|
| 137 |
+
|
| 138 |
+
**Labeling Method by dataset:**
|
| 139 |
+
* Hybrid: Human, Automated
|
| 140 |
+
|
| 141 |
+
**Properties:** The data was collected by profiling the gene expression of millions of single cells from 15 different human fetal organs.
|
| 142 |
+
|
| 143 |
+
**Link:** [Single-nuclei profiling of human dilated and hypertrophic cardiomyopathy](https://singlecell.broadinstitute.org/single_cell/study/SCP1303/single-nuclei-profiling-of-human-dilated-and-hypertrophic-cardiomyopathy#study-summary)
|
| 144 |
+
|
| 145 |
+
**Data Collection Method by dataset:**
|
| 146 |
+
* Human
|
| 147 |
+
|
| 148 |
+
**Labeling Method by dataset:**
|
| 149 |
+
* Hybrid: Human, Automated
|
| 150 |
+
|
| 151 |
+
**Properties:** The data was collected by performing single-nucleus RNA sequencing (snRNA-seq) on left ventricle samples from human hearts. The study included samples from 11 hearts with dilated cardiomyopathy, 15 hearts with hypertrophic cardiomyopathy, and 16 non-failing hearts. In total, nearly 600,000 nuclei were sequenced.
|
| 152 |
+
|
| 153 |
+
## Inference:
|
| 154 |
+
**Acceleration Engine:** Transformer Engine, PyTorch
|
| 155 |
+
|
| 156 |
+
**Test Hardware:**
|
| 157 |
+
* A100
|
| 158 |
+
* H100
|
| 159 |
+
* H200
|
| 160 |
+
* GB200
|
| 161 |
+
|
| 162 |
+
## Ethical Considerations:
|
| 163 |
+
NVIDIA believes Trustworthy AI is a shared responsibility and we have established policies and practices to enable development for a wide array of AI applications. When downloaded or used in accordance with our terms of service, developers should work with their internal model team to ensure this model meets requirements for the relevant industry and use case and addresses unforeseen product misuse.
|
| 164 |
+
|
| 165 |
+
Users are responsible for ensuring the physical properties of model-generated molecules are appropriately evaluated and comply with applicable safety regulations and ethical standards.
|
| 166 |
+
|
| 167 |
+
Please report model quality, risk, security vulnerabilities or NVIDIA AI Concerns [here](https://www.nvidia.com/en-us/support/submit-security-vulnerability/).
|
config.json
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_attn_implementation_autoset": true,
|
| 3 |
+
"architectures": [
|
| 4 |
+
"BertForMaskedLM"
|
| 5 |
+
],
|
| 6 |
+
"attention_probs_dropout_prob": 0.02,
|
| 7 |
+
"auto_map": {
|
| 8 |
+
"AutoConfig": "geneformer.TEBertConfig",
|
| 9 |
+
"AutoModel": "geneformer.BertModel",
|
| 10 |
+
"AutoModelForMaskedLM": "geneformer.BertForMaskedLM"
|
| 11 |
+
},
|
| 12 |
+
"classifier_dropout": null,
|
| 13 |
+
"framework": "pytorch",
|
| 14 |
+
"fuse_qkv_params": true,
|
| 15 |
+
"gradient_checkpointing": false,
|
| 16 |
+
"hidden_act": "relu",
|
| 17 |
+
"hidden_dropout_prob": 0.02,
|
| 18 |
+
"hidden_size": 256,
|
| 19 |
+
"initializer_range": 0.02,
|
| 20 |
+
"intermediate_size": 512,
|
| 21 |
+
"layer_norm_eps": 1e-12,
|
| 22 |
+
"max_position_embeddings": 2048,
|
| 23 |
+
"micro_batch_size": null,
|
| 24 |
+
"model_type": "bert",
|
| 25 |
+
"num_attention_heads": 4,
|
| 26 |
+
"num_hidden_layers": 6,
|
| 27 |
+
"pad_token_id": 0,
|
| 28 |
+
"position_embedding_type": "absolute",
|
| 29 |
+
"torch_dtype": "float32",
|
| 30 |
+
"transformers_version": "4.51.3",
|
| 31 |
+
"type_vocab_size": 2,
|
| 32 |
+
"use_cache": true,
|
| 33 |
+
"use_te_layers": true,
|
| 34 |
+
"vocab_size": 25426
|
| 35 |
+
}
|
geneformer.py
ADDED
|
@@ -0,0 +1,930 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-Apache2
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
# coding=utf-8
|
| 17 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
| 18 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 19 |
+
#
|
| 20 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 21 |
+
# you may not use this file except in compliance with the License.
|
| 22 |
+
# You may obtain a copy of the License at
|
| 23 |
+
#
|
| 24 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 25 |
+
#
|
| 26 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 27 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 28 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 29 |
+
# See the License for the specific language governing permissions and
|
| 30 |
+
# limitations under the License.
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
"""PyTorch BERT model with and without transformer engine layers.
|
| 34 |
+
|
| 35 |
+
This file is a modified version of the BERT model from the Hugging Face Transformers library.
|
| 36 |
+
It includes a custom BERT encoder that can be used with or without transformer engine layers.
|
| 37 |
+
|
| 38 |
+
The BERT encoder is a modified version of the encoder from the Hugging Face Transformers library.
|
| 39 |
+
It includes a custom BERT layer that can be used with or without transformer engine layers.
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
from typing import ClassVar, List, Optional, Tuple, Union
|
| 43 |
+
|
| 44 |
+
import torch
|
| 45 |
+
import transformer_engine.pytorch as te
|
| 46 |
+
from torch import nn
|
| 47 |
+
from torch.nn import CrossEntropyLoss
|
| 48 |
+
from transformers.modeling_attn_mask_utils import (
|
| 49 |
+
_prepare_4d_attention_mask_for_sdpa,
|
| 50 |
+
_prepare_4d_causal_attention_mask_for_sdpa,
|
| 51 |
+
)
|
| 52 |
+
from transformers.modeling_outputs import (
|
| 53 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
| 54 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
| 55 |
+
MaskedLMOutput,
|
| 56 |
+
)
|
| 57 |
+
from transformers.models.bert.configuration_bert import BertConfig
|
| 58 |
+
from transformers.models.bert.modeling_bert import (
|
| 59 |
+
BertEmbeddings,
|
| 60 |
+
BertLayer,
|
| 61 |
+
BertOnlyMLMHead,
|
| 62 |
+
BertPooler,
|
| 63 |
+
BertPreTrainedModel,
|
| 64 |
+
)
|
| 65 |
+
from transformers.utils import logging
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
logger = logging.get_logger(__name__)
|
| 69 |
+
|
| 70 |
+
_CHECKPOINT_FOR_DOC = "google-bert/bert-base-uncased"
|
| 71 |
+
_CONFIG_FOR_DOC = "BertConfig"
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class TEBertConfig(BertConfig):
|
| 75 |
+
"""Configuration class for the TE BERT model.
|
| 76 |
+
|
| 77 |
+
This class is a subclass of BertConfig, and it adds the following attributes:
|
| 78 |
+
- torch_dtype: The dtype of the model parameters.
|
| 79 |
+
- use_te_layers: Whether to use the TE layers.
|
| 80 |
+
- micro_batch_size: The micro batch size for TE layers.
|
| 81 |
+
"""
|
| 82 |
+
|
| 83 |
+
def __init__(self, **kwargs):
|
| 84 |
+
"""Initialize the TEBertConfig.
|
| 85 |
+
|
| 86 |
+
Args:
|
| 87 |
+
**kwargs: Additional keyword arguments to pass to BertConfig.
|
| 88 |
+
"""
|
| 89 |
+
super().__init__(**kwargs)
|
| 90 |
+
# TODO(@jomitchell): Fix this in JIRA BIONEMO-2406
|
| 91 |
+
torch_dtype = kwargs.get("torch_dtype", torch.bfloat16)
|
| 92 |
+
# Convert string dtype to torch dtype if needed
|
| 93 |
+
if isinstance(torch_dtype, str):
|
| 94 |
+
if torch_dtype == "bfloat16":
|
| 95 |
+
torch_dtype = torch.bfloat16
|
| 96 |
+
elif torch_dtype == "float16":
|
| 97 |
+
torch_dtype = torch.float16
|
| 98 |
+
elif torch_dtype == "float32":
|
| 99 |
+
torch_dtype = torch.float32
|
| 100 |
+
else:
|
| 101 |
+
raise ValueError(f"Unsupported dtype: {torch_dtype}")
|
| 102 |
+
self.torch_dtype = torch_dtype
|
| 103 |
+
self.use_te_layers = kwargs.get("use_te_layers", False)
|
| 104 |
+
self.micro_batch_size = kwargs.get("micro_batch_size", None)
|
| 105 |
+
self.fuse_qkv_params = kwargs.get("fuse_qkv_params", False)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class TEBertLayer(nn.Module):
|
| 109 |
+
"""Custom BERT layer using individual TE components for correct post-norm architecture.
|
| 110 |
+
|
| 111 |
+
This builds a BERT-style post-norm layer using:
|
| 112 |
+
- te.MultiheadAttention (with input_layernorm=False)
|
| 113 |
+
- te.LayerNorm for post-attention normalization as layernorm
|
| 114 |
+
- te.Linear for MLP layers (fc1, fc2) wrapped in layernorm_mlp module
|
| 115 |
+
- te.LayerNorm for post-MLP normalization as layernorm_mlp.layer_norm
|
| 116 |
+
|
| 117 |
+
Parameter naming matches convert.py expectations for weight loading from HF checkpoints.
|
| 118 |
+
|
| 119 |
+
DIVERGENCE FROM TYPICAL TRANSFORMERLAYER:
|
| 120 |
+
This implementation uses POST-norm architecture, which differs significantly from the
|
| 121 |
+
typical TransformerLayer that uses PRE-norm.
|
| 122 |
+
|
| 123 |
+
Geneformer/HF BERT (POST-norm, output_layernorm=True equivalent):
|
| 124 |
+
Input -> Attention -> Dropout -> Residual Add -> LayerNorm
|
| 125 |
+
-> MLP -> Dropout -> Residual Add -> LayerNorm -> Output
|
| 126 |
+
|
| 127 |
+
Typical TransformerLayer (PRE-norm, output_layernorm=False default):
|
| 128 |
+
Input -> [LayerNorm Attn inside MultiheadAttention] -> Dropout -> Residual Add
|
| 129 |
+
-> [LayerNorm MLP inside LayerNormMLP] -> Dropout -> Residual Add -> Output
|
| 130 |
+
|
| 131 |
+
Geneformer applies LayerNorm AFTER residual connections as
|
| 132 |
+
explicit separate modules, whereas typical TransformerLayer applies LayerNorm Before
|
| 133 |
+
operations via input_layernorm=True inside MultiheadAttention and LayerNormMLP modules.
|
| 134 |
+
|
| 135 |
+
For more information, see:
|
| 136 |
+
https://github.com/NVIDIA/TransformerEngine/blob/dd9433e7ad28c12f27da9770be54c9c584e85fa0/transformer_engine/pytorch/transformer.py#L822
|
| 137 |
+
"""
|
| 138 |
+
|
| 139 |
+
def __init__(self, config, layer_number=None):
|
| 140 |
+
"""Initialize the TEBertLayer.
|
| 141 |
+
|
| 142 |
+
Args:
|
| 143 |
+
config: Configuration object containing model parameters.
|
| 144 |
+
layer_number: Optional layer number for identification.
|
| 145 |
+
"""
|
| 146 |
+
super().__init__()
|
| 147 |
+
|
| 148 |
+
self.hidden_size = config.hidden_size
|
| 149 |
+
self.num_attention_heads = config.num_attention_heads
|
| 150 |
+
self.layer_number = layer_number
|
| 151 |
+
self.is_decoder = config.is_decoder
|
| 152 |
+
self.add_cross_attention = config.add_cross_attention
|
| 153 |
+
|
| 154 |
+
# Self-attention using TE MultiheadAttention
|
| 155 |
+
self.self_attention = te.MultiheadAttention(
|
| 156 |
+
hidden_size=config.hidden_size,
|
| 157 |
+
num_attention_heads=config.num_attention_heads,
|
| 158 |
+
num_gqa_groups=config.num_attention_heads,
|
| 159 |
+
attention_dropout=config.attention_probs_dropout_prob,
|
| 160 |
+
input_layernorm=False, # No LayerNorm before attention
|
| 161 |
+
attention_type="self",
|
| 162 |
+
layer_number=layer_number,
|
| 163 |
+
attn_mask_type="padding",
|
| 164 |
+
params_dtype=config.torch_dtype,
|
| 165 |
+
fuse_qkv_params=getattr(config, "fuse_qkv_params", False),
|
| 166 |
+
window_size=(-1, -1), # No sliding window attention
|
| 167 |
+
qkv_format="bshd", # BERT uses [batch, seq, head, dim]
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
# Post-attention TE LayerNorm
|
| 171 |
+
self.layernorm = te.LayerNorm(
|
| 172 |
+
normalized_shape=config.hidden_size,
|
| 173 |
+
eps=config.layer_norm_eps,
|
| 174 |
+
params_dtype=config.torch_dtype,
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
# MLP using TE Linear layers
|
| 178 |
+
self.layernorm_mlp = nn.Module()
|
| 179 |
+
self.layernorm_mlp.fc1 = te.Linear(
|
| 180 |
+
config.hidden_size,
|
| 181 |
+
config.intermediate_size,
|
| 182 |
+
bias=True,
|
| 183 |
+
params_dtype=config.torch_dtype,
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
if config.hidden_act != "relu":
|
| 187 |
+
raise ValueError(f"Geneformer requires hidden_act='relu', got '{config.hidden_act}'")
|
| 188 |
+
self.layernorm_mlp.activation = nn.ReLU()
|
| 189 |
+
|
| 190 |
+
self.layernorm_mlp.fc2 = te.Linear(
|
| 191 |
+
config.intermediate_size,
|
| 192 |
+
config.hidden_size,
|
| 193 |
+
bias=True,
|
| 194 |
+
params_dtype=config.torch_dtype,
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
# Post-MLP LayerNorm
|
| 198 |
+
self.layernorm_mlp.layer_norm = te.LayerNorm(
|
| 199 |
+
normalized_shape=config.hidden_size,
|
| 200 |
+
eps=config.layer_norm_eps,
|
| 201 |
+
params_dtype=config.torch_dtype,
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
# Dropout
|
| 205 |
+
self.attention_dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 206 |
+
self.mlp_dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 207 |
+
|
| 208 |
+
def forward(
|
| 209 |
+
self,
|
| 210 |
+
hidden_states: torch.Tensor,
|
| 211 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 212 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 213 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
| 214 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
| 215 |
+
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
| 216 |
+
output_attentions: Optional[bool] = False,
|
| 217 |
+
) -> Tuple[torch.Tensor]:
|
| 218 |
+
"""Forward pass through the TE BERT layer.
|
| 219 |
+
|
| 220 |
+
Architecture
|
| 221 |
+
Input
|
| 222 |
+
→ Self-Attention
|
| 223 |
+
→ Dropout
|
| 224 |
+
→ Residual Connection
|
| 225 |
+
→ LayerNorm
|
| 226 |
+
→ MLP
|
| 227 |
+
→ Dropout
|
| 228 |
+
→ Residual Connection
|
| 229 |
+
→ LayerNorm
|
| 230 |
+
→ Output
|
| 231 |
+
|
| 232 |
+
This architecture is the key divergence from typical TransformerLayer
|
| 233 |
+
(with output_layernorm=False default) which uses PRE-norm.
|
| 234 |
+
|
| 235 |
+
In PRE-norm TransformerLayer, LayerNorm is applied Before operations:
|
| 236 |
+
- MultiheadAttention with input_layernorm=True applies LayerNorm internally before attention
|
| 237 |
+
- LayerNormMLP applies LayerNorm internally before MLP
|
| 238 |
+
- Residuals bypass these internal LayerNorms
|
| 239 |
+
|
| 240 |
+
In Geneformer's POST-norm, LayerNorm is applied after residual connections as explicit
|
| 241 |
+
separate modules, meaning the normalized output flows to the next layer.
|
| 242 |
+
|
| 243 |
+
Args:
|
| 244 |
+
hidden_states: Input hidden states.
|
| 245 |
+
attention_mask: Attention mask.
|
| 246 |
+
head_mask: Head mask.
|
| 247 |
+
encoder_hidden_states: Encoder hidden states.
|
| 248 |
+
encoder_attention_mask: Encoder attention mask.
|
| 249 |
+
past_key_value: Past key value.
|
| 250 |
+
output_attentions: Whether to output attentions.
|
| 251 |
+
|
| 252 |
+
Returns:
|
| 253 |
+
Tuple of tensors containing the layer output.
|
| 254 |
+
"""
|
| 255 |
+
# Attention mask handling for TE MultiheadAttention, [batch, 1, 1, seq_len], True=masked, False=attend
|
| 256 |
+
te_attention_mask = None
|
| 257 |
+
te_mask_type = "no_mask"
|
| 258 |
+
|
| 259 |
+
if attention_mask is not None:
|
| 260 |
+
# Check if there's actual padding (not all 1s for 2D or not all 0s for 4D)
|
| 261 |
+
if attention_mask.dim() == 2:
|
| 262 |
+
# Standard [batch, seq_len] where 1=attend, 0=masked
|
| 263 |
+
has_padding = not torch.all(attention_mask == 1)
|
| 264 |
+
if has_padding:
|
| 265 |
+
# Convert to TE format: [batch, 1, 1, seq_len], invert polarity
|
| 266 |
+
te_attention_mask = ~attention_mask.bool().unsqueeze(1).unsqueeze(1)
|
| 267 |
+
te_mask_type = "padding"
|
| 268 |
+
elif attention_mask.dim() in [3, 4]:
|
| 269 |
+
# Extended mask with -inf for masked positions
|
| 270 |
+
|
| 271 |
+
has_masking = torch.any(
|
| 272 |
+
attention_mask < -10000.0
|
| 273 |
+
) # Check if it's not a trivial mask (all zeros/no masking)
|
| 274 |
+
if has_masking:
|
| 275 |
+
# Extract padding mask and convert to TE format
|
| 276 |
+
if attention_mask.dim() == 4:
|
| 277 |
+
padding_mask = attention_mask[:, 0, 0, :] # [batch, seq_len]
|
| 278 |
+
else: # dim == 3
|
| 279 |
+
padding_mask = attention_mask[:, 0, :] # [batch, seq_len]
|
| 280 |
+
# -inf to True (masked), 0 to False (attend)
|
| 281 |
+
# Then reshape to [batch, 1, 1, seq_len]
|
| 282 |
+
te_attention_mask = (padding_mask < -10000.0).unsqueeze(1).unsqueeze(1)
|
| 283 |
+
te_mask_type = "padding"
|
| 284 |
+
|
| 285 |
+
# Self-Attention sub-layer
|
| 286 |
+
attention_output = self.self_attention(
|
| 287 |
+
hidden_states,
|
| 288 |
+
attention_mask=te_attention_mask,
|
| 289 |
+
attn_mask_type=te_mask_type,
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
# Residual connection + dropout + LayerNorm (POST-norm)
|
| 293 |
+
attention_output = self.attention_dropout(attention_output)
|
| 294 |
+
hidden_states = hidden_states + attention_output
|
| 295 |
+
hidden_states = self.layernorm(hidden_states)
|
| 296 |
+
|
| 297 |
+
# MLP sub-layer
|
| 298 |
+
mlp_output = self.layernorm_mlp.fc1(hidden_states)
|
| 299 |
+
mlp_output = self.layernorm_mlp.activation(mlp_output)
|
| 300 |
+
mlp_output = self.layernorm_mlp.fc2(mlp_output)
|
| 301 |
+
|
| 302 |
+
# Residual connection + dropout + LayerNorm (POST-norm)
|
| 303 |
+
mlp_output = self.mlp_dropout(mlp_output)
|
| 304 |
+
hidden_states = hidden_states + mlp_output
|
| 305 |
+
hidden_states = self.layernorm_mlp.layer_norm(hidden_states)
|
| 306 |
+
|
| 307 |
+
return (hidden_states,)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
class BertEncoder(nn.Module):
|
| 311 |
+
def __init__(self, config):
|
| 312 |
+
super().__init__()
|
| 313 |
+
self.config = config
|
| 314 |
+
if self.config.use_te_layers:
|
| 315 |
+
self.layer = nn.ModuleList(
|
| 316 |
+
[TEBertLayer(config, layer_number=i + 1) for i in range(config.num_hidden_layers)]
|
| 317 |
+
)
|
| 318 |
+
else:
|
| 319 |
+
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
|
| 320 |
+
self.gradient_checkpointing = False
|
| 321 |
+
|
| 322 |
+
def _process_layer_outputs(
|
| 323 |
+
self,
|
| 324 |
+
layer_outputs,
|
| 325 |
+
hidden_states,
|
| 326 |
+
all_hidden_states,
|
| 327 |
+
all_self_attentions,
|
| 328 |
+
all_cross_attentions,
|
| 329 |
+
output_hidden_states,
|
| 330 |
+
output_attentions,
|
| 331 |
+
use_cache,
|
| 332 |
+
next_decoder_cache,
|
| 333 |
+
):
|
| 334 |
+
"""Process outputs from a single layer."""
|
| 335 |
+
hidden_states = layer_outputs[0]
|
| 336 |
+
|
| 337 |
+
if use_cache and next_decoder_cache is not None:
|
| 338 |
+
next_decoder_cache = (*next_decoder_cache, layer_outputs[-1])
|
| 339 |
+
|
| 340 |
+
if output_attentions and len(layer_outputs) > 1:
|
| 341 |
+
if all_self_attentions is None:
|
| 342 |
+
all_self_attentions = (layer_outputs[1],)
|
| 343 |
+
else:
|
| 344 |
+
all_self_attentions = (*all_self_attentions, layer_outputs[1])
|
| 345 |
+
if self.config.add_cross_attention and len(layer_outputs) > 2:
|
| 346 |
+
if all_cross_attentions is None:
|
| 347 |
+
all_cross_attentions = (layer_outputs[2],)
|
| 348 |
+
else:
|
| 349 |
+
all_cross_attentions = (*all_cross_attentions, layer_outputs[2])
|
| 350 |
+
|
| 351 |
+
return hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions, next_decoder_cache
|
| 352 |
+
|
| 353 |
+
def forward(
|
| 354 |
+
self,
|
| 355 |
+
hidden_states: torch.Tensor,
|
| 356 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 357 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 358 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
| 359 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
| 360 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
| 361 |
+
use_cache: Optional[bool] = None,
|
| 362 |
+
output_attentions: Optional[bool] = False,
|
| 363 |
+
output_hidden_states: Optional[bool] = False,
|
| 364 |
+
return_dict: Optional[bool] = True,
|
| 365 |
+
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
|
| 366 |
+
all_hidden_states = () if output_hidden_states else None
|
| 367 |
+
all_self_attentions = () if output_attentions else None
|
| 368 |
+
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
| 369 |
+
|
| 370 |
+
if self.gradient_checkpointing and self.training:
|
| 371 |
+
if use_cache:
|
| 372 |
+
logger.warning(
|
| 373 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
| 374 |
+
)
|
| 375 |
+
use_cache = False
|
| 376 |
+
|
| 377 |
+
next_decoder_cache = () if use_cache else None
|
| 378 |
+
for i, layer_module in enumerate(self.layer):
|
| 379 |
+
if output_hidden_states:
|
| 380 |
+
if all_hidden_states is None:
|
| 381 |
+
all_hidden_states = (hidden_states,)
|
| 382 |
+
else:
|
| 383 |
+
all_hidden_states = (*all_hidden_states, hidden_states)
|
| 384 |
+
|
| 385 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
| 386 |
+
past_key_value = past_key_values[i] if past_key_values is not None else None
|
| 387 |
+
|
| 388 |
+
if self.gradient_checkpointing and self.training:
|
| 389 |
+
from torch.utils.checkpoint import checkpoint
|
| 390 |
+
|
| 391 |
+
layer_outputs = checkpoint(
|
| 392 |
+
layer_module,
|
| 393 |
+
hidden_states,
|
| 394 |
+
attention_mask,
|
| 395 |
+
layer_head_mask,
|
| 396 |
+
encoder_hidden_states,
|
| 397 |
+
encoder_attention_mask,
|
| 398 |
+
past_key_value,
|
| 399 |
+
output_attentions,
|
| 400 |
+
use_reentrant=False,
|
| 401 |
+
)
|
| 402 |
+
else:
|
| 403 |
+
layer_outputs = layer_module(
|
| 404 |
+
hidden_states,
|
| 405 |
+
attention_mask,
|
| 406 |
+
layer_head_mask,
|
| 407 |
+
encoder_hidden_states,
|
| 408 |
+
encoder_attention_mask,
|
| 409 |
+
past_key_value,
|
| 410 |
+
output_attentions,
|
| 411 |
+
)
|
| 412 |
+
|
| 413 |
+
hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions, next_decoder_cache = (
|
| 414 |
+
self._process_layer_outputs(
|
| 415 |
+
layer_outputs,
|
| 416 |
+
hidden_states,
|
| 417 |
+
all_hidden_states,
|
| 418 |
+
all_self_attentions,
|
| 419 |
+
all_cross_attentions,
|
| 420 |
+
output_hidden_states,
|
| 421 |
+
output_attentions,
|
| 422 |
+
use_cache,
|
| 423 |
+
next_decoder_cache,
|
| 424 |
+
)
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
if output_hidden_states:
|
| 428 |
+
if all_hidden_states is None:
|
| 429 |
+
all_hidden_states = (hidden_states,)
|
| 430 |
+
else:
|
| 431 |
+
all_hidden_states = (*all_hidden_states, hidden_states)
|
| 432 |
+
|
| 433 |
+
if not return_dict:
|
| 434 |
+
return tuple(
|
| 435 |
+
v
|
| 436 |
+
for v in [
|
| 437 |
+
hidden_states,
|
| 438 |
+
next_decoder_cache,
|
| 439 |
+
all_hidden_states,
|
| 440 |
+
all_self_attentions,
|
| 441 |
+
all_cross_attentions,
|
| 442 |
+
]
|
| 443 |
+
if v is not None
|
| 444 |
+
)
|
| 445 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
| 446 |
+
last_hidden_state=hidden_states,
|
| 447 |
+
past_key_values=next_decoder_cache,
|
| 448 |
+
hidden_states=all_hidden_states,
|
| 449 |
+
attentions=all_self_attentions,
|
| 450 |
+
cross_attentions=all_cross_attentions,
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
class BertModel(BertPreTrainedModel):
|
| 455 |
+
"""BERT model for encoding and decoding.
|
| 456 |
+
|
| 457 |
+
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
| 458 |
+
cross-attention is added between the self-attention layers, following the architecture described in [Attention is
|
| 459 |
+
all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
|
| 460 |
+
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
| 461 |
+
|
| 462 |
+
To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
|
| 463 |
+
to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
|
| 464 |
+
`add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
|
| 465 |
+
"""
|
| 466 |
+
|
| 467 |
+
config_class = TEBertConfig
|
| 468 |
+
|
| 469 |
+
# TODO(@jomitchell) Can start swapping layers here for TE layers.
|
| 470 |
+
_no_split_modules: ClassVar[List[str]] = ["BertEmbeddings", "BertLayer", "TEBertLayer"]
|
| 471 |
+
|
| 472 |
+
def __init__(self, config, add_pooling_layer=True):
|
| 473 |
+
"""Initialize the BertModel.
|
| 474 |
+
|
| 475 |
+
Args:
|
| 476 |
+
config: Configuration object containing model parameters.
|
| 477 |
+
add_pooling_layer: Whether to add a pooling layer on top of the encoder.
|
| 478 |
+
"""
|
| 479 |
+
super().__init__(config)
|
| 480 |
+
self.config = config
|
| 481 |
+
|
| 482 |
+
self.embeddings = BertEmbeddings(config)
|
| 483 |
+
self.encoder = BertEncoder(config)
|
| 484 |
+
|
| 485 |
+
self.pooler = BertPooler(config) if add_pooling_layer else None
|
| 486 |
+
|
| 487 |
+
self.attn_implementation = config._attn_implementation
|
| 488 |
+
self.position_embedding_type = config.position_embedding_type
|
| 489 |
+
|
| 490 |
+
# Initialize weights and apply final processing
|
| 491 |
+
self.post_init()
|
| 492 |
+
|
| 493 |
+
def get_input_embeddings(self):
|
| 494 |
+
"""Get the input embeddings."""
|
| 495 |
+
return self.embeddings.word_embeddings
|
| 496 |
+
|
| 497 |
+
def set_input_embeddings(self, value):
|
| 498 |
+
"""Set the input embeddings."""
|
| 499 |
+
self.embeddings.word_embeddings = value
|
| 500 |
+
|
| 501 |
+
def _prune_heads(self, heads_to_prune):
|
| 502 |
+
"""Prunes heads of the model.
|
| 503 |
+
|
| 504 |
+
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
| 505 |
+
class PreTrainedModel.
|
| 506 |
+
"""
|
| 507 |
+
for layer, heads in heads_to_prune.items():
|
| 508 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
| 509 |
+
|
| 510 |
+
def _validate_and_prepare_inputs(
|
| 511 |
+
self,
|
| 512 |
+
input_ids,
|
| 513 |
+
inputs_embeds,
|
| 514 |
+
attention_mask,
|
| 515 |
+
token_type_ids,
|
| 516 |
+
position_ids,
|
| 517 |
+
past_key_values,
|
| 518 |
+
):
|
| 519 |
+
"""Validate inputs and prepare basic input data."""
|
| 520 |
+
if input_ids is not None and inputs_embeds is not None:
|
| 521 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
| 522 |
+
elif input_ids is not None:
|
| 523 |
+
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
| 524 |
+
input_shape = input_ids.size()
|
| 525 |
+
elif inputs_embeds is not None:
|
| 526 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 527 |
+
else:
|
| 528 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
| 529 |
+
|
| 530 |
+
batch_size, seq_length = input_shape
|
| 531 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
| 532 |
+
|
| 533 |
+
# past_key_values_length
|
| 534 |
+
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
| 535 |
+
|
| 536 |
+
if token_type_ids is None:
|
| 537 |
+
if hasattr(self.embeddings, "token_type_ids"):
|
| 538 |
+
buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
|
| 539 |
+
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
|
| 540 |
+
token_type_ids = buffered_token_type_ids_expanded
|
| 541 |
+
else:
|
| 542 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
| 543 |
+
|
| 544 |
+
embedding_output = self.embeddings(
|
| 545 |
+
input_ids=input_ids,
|
| 546 |
+
position_ids=position_ids,
|
| 547 |
+
token_type_ids=token_type_ids,
|
| 548 |
+
inputs_embeds=inputs_embeds,
|
| 549 |
+
past_key_values_length=past_key_values_length,
|
| 550 |
+
)
|
| 551 |
+
|
| 552 |
+
if attention_mask is None:
|
| 553 |
+
attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
|
| 554 |
+
|
| 555 |
+
return (
|
| 556 |
+
input_shape,
|
| 557 |
+
batch_size,
|
| 558 |
+
seq_length,
|
| 559 |
+
device,
|
| 560 |
+
past_key_values_length,
|
| 561 |
+
token_type_ids,
|
| 562 |
+
embedding_output,
|
| 563 |
+
attention_mask,
|
| 564 |
+
)
|
| 565 |
+
|
| 566 |
+
def _prepare_attention_masks(
|
| 567 |
+
self,
|
| 568 |
+
attention_mask,
|
| 569 |
+
input_shape,
|
| 570 |
+
embedding_output,
|
| 571 |
+
past_key_values_length,
|
| 572 |
+
seq_length,
|
| 573 |
+
device,
|
| 574 |
+
head_mask,
|
| 575 |
+
output_attentions,
|
| 576 |
+
encoder_hidden_states,
|
| 577 |
+
encoder_attention_mask,
|
| 578 |
+
):
|
| 579 |
+
"""Prepare attention masks for the forward pass."""
|
| 580 |
+
use_sdpa_attention_masks = (
|
| 581 |
+
self.attn_implementation == "sdpa"
|
| 582 |
+
and self.position_embedding_type == "absolute"
|
| 583 |
+
and head_mask is None
|
| 584 |
+
and not output_attentions
|
| 585 |
+
)
|
| 586 |
+
|
| 587 |
+
# Expand the attention mask
|
| 588 |
+
if use_sdpa_attention_masks and attention_mask.dim() == 2:
|
| 589 |
+
# Expand the attention mask for SDPA.
|
| 590 |
+
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
|
| 591 |
+
if self.config.is_decoder:
|
| 592 |
+
extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
| 593 |
+
attention_mask,
|
| 594 |
+
input_shape,
|
| 595 |
+
embedding_output,
|
| 596 |
+
past_key_values_length,
|
| 597 |
+
)
|
| 598 |
+
else:
|
| 599 |
+
extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
|
| 600 |
+
attention_mask, embedding_output.dtype, tgt_len=seq_length
|
| 601 |
+
)
|
| 602 |
+
else:
|
| 603 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
| 604 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
| 605 |
+
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
|
| 606 |
+
|
| 607 |
+
# If a 2D or 3D attention mask is provided for the cross-attention
|
| 608 |
+
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
| 609 |
+
if self.config.is_decoder and encoder_hidden_states is not None:
|
| 610 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
| 611 |
+
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
| 612 |
+
if encoder_attention_mask is None:
|
| 613 |
+
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
| 614 |
+
|
| 615 |
+
if use_sdpa_attention_masks and encoder_attention_mask.dim() == 2:
|
| 616 |
+
# Expand the attention mask for SDPA.
|
| 617 |
+
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
|
| 618 |
+
encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
|
| 619 |
+
encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length
|
| 620 |
+
)
|
| 621 |
+
else:
|
| 622 |
+
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
| 623 |
+
else:
|
| 624 |
+
encoder_extended_attention_mask = None
|
| 625 |
+
|
| 626 |
+
return extended_attention_mask, encoder_extended_attention_mask
|
| 627 |
+
|
| 628 |
+
def _prepare_inputs_and_masks(
|
| 629 |
+
self,
|
| 630 |
+
input_ids,
|
| 631 |
+
inputs_embeds,
|
| 632 |
+
attention_mask,
|
| 633 |
+
token_type_ids,
|
| 634 |
+
position_ids,
|
| 635 |
+
head_mask,
|
| 636 |
+
past_key_values,
|
| 637 |
+
encoder_hidden_states,
|
| 638 |
+
encoder_attention_mask,
|
| 639 |
+
output_attentions,
|
| 640 |
+
output_hidden_states,
|
| 641 |
+
use_cache,
|
| 642 |
+
return_dict,
|
| 643 |
+
):
|
| 644 |
+
"""Prepare inputs and attention masks for the forward pass."""
|
| 645 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 646 |
+
output_hidden_states = (
|
| 647 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 648 |
+
)
|
| 649 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 650 |
+
|
| 651 |
+
if self.config.is_decoder:
|
| 652 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 653 |
+
else:
|
| 654 |
+
use_cache = False
|
| 655 |
+
|
| 656 |
+
(
|
| 657 |
+
input_shape,
|
| 658 |
+
batch_size,
|
| 659 |
+
seq_length,
|
| 660 |
+
device,
|
| 661 |
+
past_key_values_length,
|
| 662 |
+
token_type_ids,
|
| 663 |
+
embedding_output,
|
| 664 |
+
attention_mask,
|
| 665 |
+
) = self._validate_and_prepare_inputs(
|
| 666 |
+
input_ids,
|
| 667 |
+
inputs_embeds,
|
| 668 |
+
attention_mask,
|
| 669 |
+
token_type_ids,
|
| 670 |
+
position_ids,
|
| 671 |
+
past_key_values,
|
| 672 |
+
)
|
| 673 |
+
|
| 674 |
+
extended_attention_mask, encoder_extended_attention_mask = self._prepare_attention_masks(
|
| 675 |
+
attention_mask,
|
| 676 |
+
input_shape,
|
| 677 |
+
embedding_output,
|
| 678 |
+
past_key_values_length,
|
| 679 |
+
seq_length,
|
| 680 |
+
device,
|
| 681 |
+
head_mask,
|
| 682 |
+
output_attentions,
|
| 683 |
+
encoder_hidden_states,
|
| 684 |
+
encoder_attention_mask,
|
| 685 |
+
)
|
| 686 |
+
|
| 687 |
+
# Prepare head mask if needed
|
| 688 |
+
# 1.0 in head_mask indicate we keep the head
|
| 689 |
+
# attention_probs has shape bsz x n_heads x N x N
|
| 690 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
| 691 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
| 692 |
+
processed_head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
| 693 |
+
|
| 694 |
+
return (
|
| 695 |
+
embedding_output,
|
| 696 |
+
extended_attention_mask,
|
| 697 |
+
processed_head_mask,
|
| 698 |
+
encoder_extended_attention_mask,
|
| 699 |
+
use_cache,
|
| 700 |
+
return_dict,
|
| 701 |
+
)
|
| 702 |
+
|
| 703 |
+
def forward(
|
| 704 |
+
self,
|
| 705 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 706 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 707 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
| 708 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 709 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 710 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 711 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 712 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
| 713 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 714 |
+
use_cache: Optional[bool] = None,
|
| 715 |
+
output_attentions: Optional[bool] = None,
|
| 716 |
+
output_hidden_states: Optional[bool] = None,
|
| 717 |
+
return_dict: Optional[bool] = None,
|
| 718 |
+
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
|
| 719 |
+
r"""Forward pass of the BertModel.
|
| 720 |
+
|
| 721 |
+
Args:
|
| 722 |
+
input_ids (`torch.Tensor`, *optional*): Input token IDs.
|
| 723 |
+
attention_mask (`torch.Tensor`, *optional*): Attention mask.
|
| 724 |
+
token_type_ids (`torch.Tensor`, *optional*): Token type IDs.
|
| 725 |
+
position_ids (`torch.Tensor`, *optional*): Position IDs.
|
| 726 |
+
head_mask (`torch.Tensor`, *optional*): Head mask.
|
| 727 |
+
inputs_embeds (`torch.Tensor`, *optional*): Input embeddings.
|
| 728 |
+
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 729 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
| 730 |
+
the model is configured as a decoder.
|
| 731 |
+
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 732 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
| 733 |
+
the model is configured as a decoder.
|
| 734 |
+
encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, target_length)`, *optional*):
|
| 735 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
| 736 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
|
| 737 |
+
|
| 738 |
+
- 1 for tokens that are **not masked**,
|
| 739 |
+
- 0 for tokens that are **masked**.
|
| 740 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
| 741 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding (see
|
| 742 |
+
`past_key_values`).
|
| 743 |
+
use_cache (`bool`, *optional*):
|
| 744 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
| 745 |
+
`past_key_values`).
|
| 746 |
+
output_attentions (`bool`, *optional*): Whether to output attentions.
|
| 747 |
+
output_hidden_states (`bool`, *optional*): Whether to output hidden states.
|
| 748 |
+
return_dict (`bool`, *optional*): Whether to return a ModelOutput instead of a tuple.
|
| 749 |
+
"""
|
| 750 |
+
(
|
| 751 |
+
embedding_output,
|
| 752 |
+
extended_attention_mask,
|
| 753 |
+
processed_head_mask,
|
| 754 |
+
encoder_extended_attention_mask,
|
| 755 |
+
use_cache,
|
| 756 |
+
return_dict,
|
| 757 |
+
) = self._prepare_inputs_and_masks(
|
| 758 |
+
input_ids,
|
| 759 |
+
inputs_embeds,
|
| 760 |
+
attention_mask,
|
| 761 |
+
token_type_ids,
|
| 762 |
+
position_ids,
|
| 763 |
+
head_mask,
|
| 764 |
+
past_key_values,
|
| 765 |
+
encoder_hidden_states,
|
| 766 |
+
encoder_attention_mask,
|
| 767 |
+
output_attentions,
|
| 768 |
+
output_hidden_states,
|
| 769 |
+
use_cache,
|
| 770 |
+
return_dict,
|
| 771 |
+
)
|
| 772 |
+
|
| 773 |
+
encoder_outputs = self.encoder(
|
| 774 |
+
embedding_output,
|
| 775 |
+
attention_mask=extended_attention_mask,
|
| 776 |
+
head_mask=processed_head_mask,
|
| 777 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 778 |
+
encoder_attention_mask=encoder_extended_attention_mask,
|
| 779 |
+
past_key_values=past_key_values,
|
| 780 |
+
use_cache=use_cache,
|
| 781 |
+
output_attentions=output_attentions,
|
| 782 |
+
output_hidden_states=output_hidden_states,
|
| 783 |
+
return_dict=return_dict,
|
| 784 |
+
)
|
| 785 |
+
sequence_output = encoder_outputs[0]
|
| 786 |
+
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
| 787 |
+
|
| 788 |
+
if not return_dict:
|
| 789 |
+
return (sequence_output, pooled_output, *encoder_outputs[1:])
|
| 790 |
+
|
| 791 |
+
return BaseModelOutputWithPoolingAndCrossAttentions(
|
| 792 |
+
last_hidden_state=sequence_output,
|
| 793 |
+
pooler_output=pooled_output,
|
| 794 |
+
past_key_values=encoder_outputs.past_key_values,
|
| 795 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 796 |
+
attentions=encoder_outputs.attentions,
|
| 797 |
+
cross_attentions=encoder_outputs.cross_attentions,
|
| 798 |
+
)
|
| 799 |
+
|
| 800 |
+
|
| 801 |
+
class BertForMaskedLM(BertPreTrainedModel):
|
| 802 |
+
"""BERT model for masked language modeling."""
|
| 803 |
+
|
| 804 |
+
config_class = TEBertConfig
|
| 805 |
+
_tied_weights_keys: ClassVar[List[str]] = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
|
| 806 |
+
|
| 807 |
+
def __init__(self, config):
|
| 808 |
+
"""Initialize the BertForMaskedLM.
|
| 809 |
+
|
| 810 |
+
Args:
|
| 811 |
+
config: Configuration object containing model parameters.
|
| 812 |
+
"""
|
| 813 |
+
super().__init__(config)
|
| 814 |
+
|
| 815 |
+
if config.is_decoder:
|
| 816 |
+
logger.warning(
|
| 817 |
+
"If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
|
| 818 |
+
"bi-directional self-attention."
|
| 819 |
+
)
|
| 820 |
+
|
| 821 |
+
self.bert = BertModel(config, add_pooling_layer=False)
|
| 822 |
+
self.cls = BertOnlyMLMHead(config)
|
| 823 |
+
|
| 824 |
+
# Initialize weights and apply final processing
|
| 825 |
+
self.post_init()
|
| 826 |
+
|
| 827 |
+
def get_output_embeddings(self):
|
| 828 |
+
"""Get the output embeddings."""
|
| 829 |
+
return self.cls.predictions.decoder
|
| 830 |
+
|
| 831 |
+
def set_output_embeddings(self, new_embeddings):
|
| 832 |
+
"""Set the output embeddings."""
|
| 833 |
+
self.cls.predictions.decoder = new_embeddings
|
| 834 |
+
self.cls.predictions.bias = new_embeddings.bias
|
| 835 |
+
|
| 836 |
+
def forward(
|
| 837 |
+
self,
|
| 838 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 839 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 840 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
| 841 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 842 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 843 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 844 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 845 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
| 846 |
+
labels: Optional[torch.Tensor] = None,
|
| 847 |
+
output_attentions: Optional[bool] = None,
|
| 848 |
+
output_hidden_states: Optional[bool] = None,
|
| 849 |
+
return_dict: Optional[bool] = None,
|
| 850 |
+
) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
|
| 851 |
+
r"""Forward pass for masked language modeling.
|
| 852 |
+
|
| 853 |
+
Labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 854 |
+
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
|
| 855 |
+
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
|
| 856 |
+
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
| 857 |
+
"""
|
| 858 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 859 |
+
|
| 860 |
+
outputs = self.bert(
|
| 861 |
+
input_ids,
|
| 862 |
+
attention_mask=attention_mask,
|
| 863 |
+
token_type_ids=token_type_ids,
|
| 864 |
+
position_ids=position_ids,
|
| 865 |
+
head_mask=head_mask,
|
| 866 |
+
inputs_embeds=inputs_embeds,
|
| 867 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 868 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 869 |
+
output_attentions=output_attentions,
|
| 870 |
+
output_hidden_states=output_hidden_states,
|
| 871 |
+
return_dict=return_dict,
|
| 872 |
+
)
|
| 873 |
+
|
| 874 |
+
sequence_output = outputs[0]
|
| 875 |
+
prediction_scores = self.cls(sequence_output)
|
| 876 |
+
|
| 877 |
+
masked_lm_loss = None
|
| 878 |
+
if labels is not None:
|
| 879 |
+
loss_fct = CrossEntropyLoss() # -100 index = padding token
|
| 880 |
+
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
| 881 |
+
|
| 882 |
+
if not return_dict:
|
| 883 |
+
output = (prediction_scores, *outputs[2:])
|
| 884 |
+
return (masked_lm_loss, *output) if masked_lm_loss is not None else output
|
| 885 |
+
|
| 886 |
+
return MaskedLMOutput(
|
| 887 |
+
loss=masked_lm_loss,
|
| 888 |
+
logits=prediction_scores,
|
| 889 |
+
hidden_states=outputs.hidden_states,
|
| 890 |
+
attentions=outputs.attentions,
|
| 891 |
+
)
|
| 892 |
+
|
| 893 |
+
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
|
| 894 |
+
"""Prepare inputs for generation."""
|
| 895 |
+
input_shape = input_ids.shape
|
| 896 |
+
effective_batch_size = input_shape[0]
|
| 897 |
+
|
| 898 |
+
# add a dummy token
|
| 899 |
+
if self.config.pad_token_id is None:
|
| 900 |
+
raise ValueError("The PAD token should be defined for generation")
|
| 901 |
+
|
| 902 |
+
attention_mask = torch.cat(
|
| 903 |
+
[attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))],
|
| 904 |
+
dim=-1,
|
| 905 |
+
)
|
| 906 |
+
dummy_token = torch.full(
|
| 907 |
+
(effective_batch_size, 1),
|
| 908 |
+
self.config.pad_token_id,
|
| 909 |
+
dtype=torch.long,
|
| 910 |
+
device=input_ids.device,
|
| 911 |
+
)
|
| 912 |
+
input_ids = torch.cat([input_ids, dummy_token], dim=1)
|
| 913 |
+
|
| 914 |
+
return {"input_ids": input_ids, "attention_mask": attention_mask}
|
| 915 |
+
|
| 916 |
+
@classmethod
|
| 917 |
+
def can_generate(cls) -> bool:
|
| 918 |
+
"""Legacy correction: BertForMaskedLM can't call `generate()` from `GenerationMixin`.
|
| 919 |
+
|
| 920 |
+
Even though it has a `prepare_inputs_for_generation` method.
|
| 921 |
+
"""
|
| 922 |
+
return False
|
| 923 |
+
|
| 924 |
+
|
| 925 |
+
__all__ = [
|
| 926 |
+
"BertForMaskedLM",
|
| 927 |
+
"BertLayer",
|
| 928 |
+
"BertModel",
|
| 929 |
+
"TEBertLayer",
|
| 930 |
+
]
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d0dfca398accc3b54ce8d9d1574bf92f393a9289426563cd68fdd2666dd74f09
|
| 3 |
+
size 67302248
|