Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +6 -0
- assets/control.png +3 -0
- assets/subject.png +3 -0
- generation/control/ControlNet/font/DejaVuSans.ttf +3 -0
- generation/control/ControlNet/ldm/modules/image_degradation/utils/test.png +3 -0
- llama/data/MetaMathQA-40K.json +3 -0
- llama/data/MetaMathQA.json +3 -0
- llama/output/cp1e4/ft/adapter_model.safetensors +3 -0
- llama/output/cp1e4/ft/tokenizer.model +3 -0
- llama/output/cp1e5/ft/adapter_model.safetensors +3 -0
- llama/output/cp1e5N/ft/adapter_model.safetensors +3 -0
- llama/output/cp1e5N/ft/tokenizer.model +3 -0
- llama/output/cp3e5/ft/adapter_model.safetensors +3 -0
- llama/output/cp3e5N/ft/adapter_model.safetensors +3 -0
- llama/output/cp3e5N/ft/tokenizer.model +3 -0
- llama/output/cpr1/ft/adapter_model.safetensors +3 -0
- llama/output/cpr1/ft/tokenizer.model +3 -0
- llama/output/cpr2/ft/adapter_model.safetensors +3 -0
- llama/output/cpr2/ft/tokenizer.model +3 -0
- nlu/DeBERTa.egg-info/PKG-INFO +39 -0
- nlu/DeBERTa.egg-info/SOURCES.txt +73 -0
- nlu/DeBERTa.egg-info/dependency_links.txt +1 -0
- nlu/DeBERTa.egg-info/requires.txt +19 -0
- nlu/DeBERTa.egg-info/top_level.txt +2 -0
- nlu/DeBERTa/apps/tasks/task_registry.py +70 -0
- nlu/DeBERTa/data/__init__.py +5 -0
- nlu/DeBERTa/data/async_data.py +38 -0
- nlu/DeBERTa/data/data_sampler.py +76 -0
- nlu/DeBERTa/data/dataloader.py +511 -0
- nlu/DeBERTa/data/dynamic_dataset.py +60 -0
- nlu/DeBERTa/data/example.py +105 -0
- nlu/DeBERTa/deberta/__init__.py +22 -0
- nlu/DeBERTa/deberta/bert.py +308 -0
- nlu/DeBERTa/deberta/cache_utils.py +135 -0
- nlu/DeBERTa/deberta/config.py +90 -0
- nlu/DeBERTa/deberta/da_utils.py +68 -0
- nlu/DeBERTa/deberta/deberta.py +145 -0
- nlu/DeBERTa/deberta/disentangled_attention.py +221 -0
- nlu/DeBERTa/deberta/gpt2_bpe_utils.py +163 -0
- nlu/DeBERTa/deberta/gpt2_tokenizer.py +216 -0
- nlu/DeBERTa/deberta/mlm.py +38 -0
- nlu/DeBERTa/deberta/nnmodule.py +137 -0
- nlu/DeBERTa/deberta/ops.py +228 -0
- nlu/DeBERTa/deberta/pooling.py +88 -0
- nlu/DeBERTa/deberta/pretrained_models.py +2 -0
- nlu/DeBERTa/deberta/spm_tokenizer.py +322 -0
- nlu/DeBERTa/deberta/tokenizers.py +16 -0
- nlu/DeBERTa/optims/__init__.py +16 -0
- nlu/DeBERTa/optims/args.py +100 -0
- nlu/DeBERTa/optims/fp16_optimizer.py +301 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/control.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
assets/subject.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
generation/control/ControlNet/font/DejaVuSans.ttf filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
generation/control/ControlNet/ldm/modules/image_degradation/utils/test.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
llama/data/MetaMathQA-40K.json filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
llama/data/MetaMathQA.json filter=lfs diff=lfs merge=lfs -text
|
assets/control.png
ADDED
|
Git LFS Details
|
assets/subject.png
ADDED
|
Git LFS Details
|
generation/control/ControlNet/font/DejaVuSans.ttf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7da195a74c55bef988d0d48f9508bd5d849425c1770dba5d7bfc6ce9ed848954
|
| 3 |
+
size 757076
|
generation/control/ControlNet/ldm/modules/image_degradation/utils/test.png
ADDED
|
Git LFS Details
|
llama/data/MetaMathQA-40K.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c884f10e8aa1229a6e73a6bba2c9134ee0c7b7de92a02a7b8c9459085a59e117
|
| 3 |
+
size 31076207
|
llama/data/MetaMathQA.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fb39a5d8c05c042ece92eae37dfd5ea414a5979df2bf3ad3b86411bef8205725
|
| 3 |
+
size 395626321
|
llama/output/cp1e4/ft/adapter_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7e1c2fceb4f91331d69364aa56d01dd2103d4e59066f1519f1242a62ecca387a
|
| 3 |
+
size 1082171824
|
llama/output/cp1e4/ft/tokenizer.model
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
|
| 3 |
+
size 499723
|
llama/output/cp1e5/ft/adapter_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f6121d3f7682fd21f70fc78ab9097b22ede67191507c54d44a9bd9c30adf44de
|
| 3 |
+
size 592928
|
llama/output/cp1e5N/ft/adapter_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d85146aea100acda2fd5bb5a011f8d1e14983756bb0c102bf85efe04ac176479
|
| 3 |
+
size 1082171824
|
llama/output/cp1e5N/ft/tokenizer.model
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
|
| 3 |
+
size 499723
|
llama/output/cp3e5/ft/adapter_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1945e74d818ded53f08bc892bb458dd0e6addcd548b2f864dbd16a476a8954ef
|
| 3 |
+
size 1082171824
|
llama/output/cp3e5N/ft/adapter_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a2396d96c0a301cceddf424fbdf7c7f3518311f90140fa9aad9053706288e9fc
|
| 3 |
+
size 1082171824
|
llama/output/cp3e5N/ft/tokenizer.model
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
|
| 3 |
+
size 499723
|
llama/output/cpr1/ft/adapter_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:617c715b246fae47190ca1f8e304e9dbdadf6ac70bbfdd0f3bc3c4b1cd783c0d
|
| 3 |
+
size 1049665904
|
llama/output/cpr1/ft/tokenizer.model
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
|
| 3 |
+
size 499723
|
llama/output/cpr2/ft/adapter_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:daede58d9fd4806298d90f9af12ba478c119afab844244f355f35ab3829eb029
|
| 3 |
+
size 1049665904
|
llama/output/cpr2/ft/tokenizer.model
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
|
| 3 |
+
size 499723
|
nlu/DeBERTa.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.1
|
| 2 |
+
Name: DeBERTa
|
| 3 |
+
Version: 0.1.13
|
| 4 |
+
Summary: Decoding enhanced BERT with Disentangled Attention
|
| 5 |
+
Home-page: https://github.com/microsoft/DeBERTa
|
| 6 |
+
Author: penhe
|
| 7 |
+
Author-email: penhe@microsoft.com
|
| 8 |
+
License: MIT
|
| 9 |
+
Keywords: NLP deep learning transformer pytorch Attention BERT RoBERTa DeBERTa
|
| 10 |
+
Classifier: Programming Language :: Python :: 3
|
| 11 |
+
Classifier: Programming Language :: Python :: 3.6
|
| 12 |
+
Classifier: Programming Language :: Python :: 3.7
|
| 13 |
+
Classifier: Programming Language :: Python :: 3.8
|
| 14 |
+
Classifier: Programming Language :: Python :: 3.9
|
| 15 |
+
Classifier: License :: OSI Approved :: MIT License
|
| 16 |
+
Classifier: Operating System :: OS Independent
|
| 17 |
+
Requires-Python: >=3.6
|
| 18 |
+
Description-Content-Type: text/markdown
|
| 19 |
+
License-File: LICENSE
|
| 20 |
+
Requires-Dist: nltk
|
| 21 |
+
Requires-Dist: spacy
|
| 22 |
+
Requires-Dist: numpy
|
| 23 |
+
Requires-Dist: pytest
|
| 24 |
+
Requires-Dist: regex
|
| 25 |
+
Requires-Dist: scipy
|
| 26 |
+
Requires-Dist: scikit-learn
|
| 27 |
+
Requires-Dist: tqdm
|
| 28 |
+
Requires-Dist: ujson
|
| 29 |
+
Requires-Dist: seqeval
|
| 30 |
+
Requires-Dist: psutil
|
| 31 |
+
Requires-Dist: sentencepiece
|
| 32 |
+
Requires-Dist: torch
|
| 33 |
+
Provides-Extra: docs
|
| 34 |
+
Requires-Dist: recommonmark; extra == "docs"
|
| 35 |
+
Requires-Dist: sphinx; extra == "docs"
|
| 36 |
+
Requires-Dist: sphinx-markdown-tables; extra == "docs"
|
| 37 |
+
Requires-Dist: sphinx-rtd-theme; extra == "docs"
|
| 38 |
+
|
| 39 |
+
deberta long des
|
nlu/DeBERTa.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
LICENSE
|
| 2 |
+
setup.cfg
|
| 3 |
+
setup.py
|
| 4 |
+
DeBERTa/__init__.py
|
| 5 |
+
DeBERTa.egg-info/PKG-INFO
|
| 6 |
+
DeBERTa.egg-info/SOURCES.txt
|
| 7 |
+
DeBERTa.egg-info/dependency_links.txt
|
| 8 |
+
DeBERTa.egg-info/requires.txt
|
| 9 |
+
DeBERTa.egg-info/top_level.txt
|
| 10 |
+
DeBERTa/apps/__init__.py
|
| 11 |
+
DeBERTa/apps/_utils.py
|
| 12 |
+
DeBERTa/apps/run.py
|
| 13 |
+
DeBERTa/apps/models/__init__.py
|
| 14 |
+
DeBERTa/apps/models/masked_language_model.py
|
| 15 |
+
DeBERTa/apps/models/multi_choice.py
|
| 16 |
+
DeBERTa/apps/models/ner.py
|
| 17 |
+
DeBERTa/apps/models/record_qa.py
|
| 18 |
+
DeBERTa/apps/models/replaced_token_detection_model.py
|
| 19 |
+
DeBERTa/apps/models/sequence_classification.py
|
| 20 |
+
DeBERTa/apps/tasks/__init__.py
|
| 21 |
+
DeBERTa/apps/tasks/glue_tasks.py
|
| 22 |
+
DeBERTa/apps/tasks/metrics.py
|
| 23 |
+
DeBERTa/apps/tasks/mlm_task.py
|
| 24 |
+
DeBERTa/apps/tasks/ner_task.py
|
| 25 |
+
DeBERTa/apps/tasks/race_task.py
|
| 26 |
+
DeBERTa/apps/tasks/record_eval.py
|
| 27 |
+
DeBERTa/apps/tasks/rtd_task.py
|
| 28 |
+
DeBERTa/apps/tasks/superglue_tasks.py
|
| 29 |
+
DeBERTa/apps/tasks/task.py
|
| 30 |
+
DeBERTa/apps/tasks/task_registry.py
|
| 31 |
+
DeBERTa/data/__init__.py
|
| 32 |
+
DeBERTa/data/async_data.py
|
| 33 |
+
DeBERTa/data/data_sampler.py
|
| 34 |
+
DeBERTa/data/dataloader.py
|
| 35 |
+
DeBERTa/data/dynamic_dataset.py
|
| 36 |
+
DeBERTa/data/example.py
|
| 37 |
+
DeBERTa/deberta/__init__.py
|
| 38 |
+
DeBERTa/deberta/bert.py
|
| 39 |
+
DeBERTa/deberta/cache_utils.py
|
| 40 |
+
DeBERTa/deberta/config.py
|
| 41 |
+
DeBERTa/deberta/da_utils.py
|
| 42 |
+
DeBERTa/deberta/deberta.py
|
| 43 |
+
DeBERTa/deberta/disentangled_attention.py
|
| 44 |
+
DeBERTa/deberta/gpt2_bpe_utils.py
|
| 45 |
+
DeBERTa/deberta/gpt2_tokenizer.py
|
| 46 |
+
DeBERTa/deberta/mlm.py
|
| 47 |
+
DeBERTa/deberta/nnmodule.py
|
| 48 |
+
DeBERTa/deberta/ops.py
|
| 49 |
+
DeBERTa/deberta/pooling.py
|
| 50 |
+
DeBERTa/deberta/pretrained_models.py
|
| 51 |
+
DeBERTa/deberta/spm_tokenizer.py
|
| 52 |
+
DeBERTa/deberta/tokenizers.py
|
| 53 |
+
DeBERTa/optims/__init__.py
|
| 54 |
+
DeBERTa/optims/args.py
|
| 55 |
+
DeBERTa/optims/fp16_optimizer.py
|
| 56 |
+
DeBERTa/optims/lr_schedulers.py
|
| 57 |
+
DeBERTa/optims/xadam.py
|
| 58 |
+
DeBERTa/sift/__init__.py
|
| 59 |
+
DeBERTa/sift/sift.py
|
| 60 |
+
DeBERTa/training/__init__.py
|
| 61 |
+
DeBERTa/training/_utils.py
|
| 62 |
+
DeBERTa/training/args.py
|
| 63 |
+
DeBERTa/training/dist_launcher.py
|
| 64 |
+
DeBERTa/training/optimizer_utils.py
|
| 65 |
+
DeBERTa/training/trainer.py
|
| 66 |
+
DeBERTa/utils/__init__.py
|
| 67 |
+
DeBERTa/utils/argument_types.py
|
| 68 |
+
DeBERTa/utils/jit_tracing.py
|
| 69 |
+
DeBERTa/utils/logger_util.py
|
| 70 |
+
DeBERTa/utils/xtqdm.py
|
| 71 |
+
adapterlib/__init__.py
|
| 72 |
+
adapterlib/layers.py
|
| 73 |
+
adapterlib/utils.py
|
nlu/DeBERTa.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
nlu/DeBERTa.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
nltk
|
| 2 |
+
spacy
|
| 3 |
+
numpy
|
| 4 |
+
pytest
|
| 5 |
+
regex
|
| 6 |
+
scipy
|
| 7 |
+
scikit-learn
|
| 8 |
+
tqdm
|
| 9 |
+
ujson
|
| 10 |
+
seqeval
|
| 11 |
+
psutil
|
| 12 |
+
sentencepiece
|
| 13 |
+
torch
|
| 14 |
+
|
| 15 |
+
[docs]
|
| 16 |
+
recommonmark
|
| 17 |
+
sphinx
|
| 18 |
+
sphinx-markdown-tables
|
| 19 |
+
sphinx-rtd-theme
|
nlu/DeBERTa.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
DeBERTa
|
| 2 |
+
adapterlib
|
nlu/DeBERTa/apps/tasks/task_registry.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft, Inc. 2020
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
#
|
| 6 |
+
# Author: penhe@microsoft.com
|
| 7 |
+
# Date: 01/25/2019
|
| 8 |
+
#
|
| 9 |
+
|
| 10 |
+
from glob import glob
|
| 11 |
+
import os
|
| 12 |
+
import importlib
|
| 13 |
+
import pdb
|
| 14 |
+
import sys
|
| 15 |
+
from ...utils import get_logger
|
| 16 |
+
from .task import Task
|
| 17 |
+
|
| 18 |
+
__all__ = ['load_tasks', 'register_task', 'get_task']
|
| 19 |
+
tasks={}
|
| 20 |
+
|
| 21 |
+
logger=get_logger()
|
| 22 |
+
|
| 23 |
+
def register_task(name=None, desc=None):
  """Register a Task subclass in the global ``tasks`` registry.

  Usable either as a bare decorator (``@register_task``) or as a
  decorator factory (``@register_task('mytask', 'my description')``).

  Args:
    name: registry key (lower-cased); defaults to the class name. When
      used as a bare decorator this parameter receives the class itself.
    desc: human-readable description; defaults to ``name``.

  Returns:
    The decorated class (bare form) or a decorator (factory form).
  """
  def register_task_x(cls):
    _name = name
    if _name is None:
      _name = cls.__name__

    _desc = desc
    if _desc is None:
      _desc = _name

    _name = _name.lower()
    if _name in tasks:
      logger.warning(f'{_name} already registered in the registry: {tasks[_name]}.')
    # Only Task subclasses implement the interface the runner expects.
    assert issubclass(cls, Task), 'Registered class must be a subclass of Task.'
    tasks[_name] = cls
    cls._meta = {
        'name': _name,
        'desc': _desc}
    return cls

  # Bare-decorator form: ``name`` is actually the class being decorated.
  # Bug fix: use isinstance(name, type) instead of type(name)==type so
  # classes with a custom metaclass are also recognized as classes.
  if isinstance(name, type):
    cls = name
    name = None
    return register_task_x(cls)
  return register_task_x
| 48 |
+
|
| 49 |
+
def load_tasks(task_dir = None):
  """Import every task module so their ``@register_task`` decorators run.

  Scans this package's directory for ``*.py`` modules (skipping names
  that start with ``_``) and imports each one under
  ``DeBERTa.apps.tasks``. When ``task_dir`` is given, it is appended to
  ``sys.path`` and its ``*.py`` modules are imported as top-level
  modules as well.

  Args:
    task_dir: optional directory of user-supplied task modules.
  """
  here = os.path.dirname(os.path.abspath(__file__))
  for path in glob(os.path.join(here, "*.py")):
    mod = os.path.splitext(os.path.basename(path))[0]
    if mod.startswith('_'):
      continue
    importlib.import_module(f'DeBERTa.apps.tasks.{mod}')

  if task_dir:
    assert os.path.exists(task_dir), f"{task_dir} must be a valid directory."
    custom = glob(os.path.join(task_dir, "*.py"))
    sys.path.append(task_dir)
    for path in custom:
      mod = os.path.splitext(os.path.basename(path))[0]
      if not mod.startswith('_'):
        importlib.import_module(f'{mod}')
| 65 |
+
|
| 66 |
+
def get_task(name=None):
  """Look up a registered task class by (case-insensitive) name.

  With no ``name``, returns the whole registry dict. Raises ``KeyError``
  for an unknown task name.
  """
  return tasks if name is None else tasks[name.lower()]
|
nlu/DeBERTa/data/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .example import ExampleInstance,ExampleSet,example_to_feature
|
| 2 |
+
from .dataloader import SequentialDataLoader
|
| 3 |
+
from .dynamic_dataset import *
|
| 4 |
+
from .data_sampler import *
|
| 5 |
+
from .async_data import *
|
nlu/DeBERTa/data/async_data.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# This source code is licensed under the MIT license found in the
|
| 3 |
+
# LICENSE file in the root directory of this source tree.
|
| 4 |
+
#
|
| 5 |
+
# Author: Pengcheng He (penhe@microsoft.com)
|
| 6 |
+
# Date: 05/15/2019
|
| 7 |
+
#
|
| 8 |
+
|
| 9 |
+
from queue import Queue,Empty
|
| 10 |
+
from threading import Thread
|
| 11 |
+
class AsyncDataLoader(object):
  """Wrap a dataloader so batches are prefetched on a background thread.

  A producer thread pulls batches from the wrapped iterator and pushes
  them into a bounded queue; iteration consumes from the queue, letting
  data loading overlap with the consumer's work.
  """

  def __init__(self, dataloader, buffer_size=100):
    # buffer_size bounds the prefetch queue, capping memory usage.
    self.buffer_size = buffer_size
    self.dataloader = dataloader

  def __iter__(self):
    queue = Queue(self.buffer_size)
    dl = iter(self.dataloader)

    def _worker():
      # Bug fix: enqueue the end-of-stream sentinel in a finally block so
      # the consumer cannot block forever if next(dl) raises something
      # other than StopIteration.
      try:
        while True:
          try:
            queue.put(next(dl))
          except StopIteration:
            break
      finally:
        # None signals end-of-stream to the consumer.
        queue.put(None)

    # Bug fix: mark the producer as a daemon thread. If the consumer
    # abandons iteration early, the producer may be blocked on a full
    # queue; a non-daemon thread would then keep the process alive.
    t = Thread(target=_worker, daemon=True)
    t.start()
    while True:
      d = queue.get()
      if d is None:
        break
      yield d
    t.join()

  def __len__(self):
    return len(self.dataloader)
| 38 |
+
|
nlu/DeBERTa/data/data_sampler.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# This source code is licensed under the MIT license found in the
|
| 3 |
+
# LICENSE file in the root directory of this source tree.
|
| 4 |
+
#
|
| 5 |
+
# Author: Pengcheng He (penhe@microsoft.com)
|
| 6 |
+
# Date: 05/15/2019
|
| 7 |
+
#
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import numpy as np
|
| 11 |
+
import math
|
| 12 |
+
import sys
|
| 13 |
+
from torch.utils.data import Sampler
|
| 14 |
+
|
| 15 |
+
__all__=['BatchSampler', 'DistributedBatchSampler', 'RandomSampler', 'SequentialSampler']
|
| 16 |
+
class BatchSampler(Sampler):
  """Group indices produced by ``sampler`` into lists of ``batch_size``.

  The final batch may be smaller when the sampler length is not a
  multiple of ``batch_size``.
  """

  def __init__(self, sampler, batch_size):
    self.sampler = sampler
    self.batch_size = batch_size

  def __iter__(self):
    current = []
    for index in self.sampler:
      current.append(index)
      if len(current) == self.batch_size:
        yield current
        current = []
    # Flush the trailing partial batch, if any.
    if current:
      yield current

  def __len__(self):
    # Ceiling division: a partial trailing batch still counts.
    return -(-len(self.sampler) // self.batch_size)
| 33 |
+
|
| 34 |
+
class DistributedBatchSampler(Sampler):
  """Shard batches from a batch sampler across distributed workers.

  Each batch yielded by ``sampler`` is split into ``world_size`` equal
  chunks, and this instance yields only the chunk belonging to ``rank``.
  An uneven batch is either dropped (``drop_last=True``, which stops
  iteration) or padded by repeating its first element.
  """

  def __init__(self, sampler, rank=0, world_size = 1, drop_last = False):
    self.sampler = sampler
    self.rank = rank
    self.world_size = world_size
    self.drop_last = drop_last

  def __iter__(self):
    for batch in self.sampler:
      remainder = len(batch) % self.world_size
      if remainder:
        if self.drop_last:
          # Uneven batch is assumed to be the last one; stop here.
          break
        # Pad with copies of the first index so every rank receives an
        # equally sized chunk. NOTE: this mutates the batch in place.
        batch.extend(batch[0] for _ in range(self.world_size - remainder))
      chunk = len(batch) // self.world_size
      yield batch[self.rank * chunk:(self.rank + 1) * chunk]

  def __len__(self):
    return len(self.sampler)
| 53 |
+
|
| 54 |
+
class RandomSampler(Sampler):
  """Yield every index in ``[0, total_samples)`` in random order.

  A private ``RandomState`` seeded with ``data_seed`` makes the shuffle
  reproducible and independent of the global numpy RNG. Each new
  iteration reshuffles in place, continuing the same RNG stream.
  """

  def __init__(self, total_samples:int, data_seed:int = 0):
    self.indices = np.array(np.arange(total_samples))
    self.rng = np.random.RandomState(data_seed)

  def __iter__(self):
    self.rng.shuffle(self.indices)
    yield from self.indices

  def __len__(self):
    return len(self.indices)
| 66 |
+
|
| 67 |
+
class SequentialSampler(Sampler):
  """Yield indices ``0 .. total_samples-1`` in order."""

  def __init__(self, total_samples:int):
    self.indices = np.array(np.arange(total_samples))

  def __iter__(self):
    yield from self.indices

  def __len__(self):
    return len(self.indices)
nlu/DeBERTa/data/dataloader.py
ADDED
|
@@ -0,0 +1,511 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import torch
|
| 3 |
+
import torch.multiprocessing as multiprocessing
|
| 4 |
+
from torch._C import _set_worker_signal_handlers, \
|
| 5 |
+
_remove_worker_pids, _error_if_any_worker_fails
|
| 6 |
+
|
| 7 |
+
from packaging import version
|
| 8 |
+
|
| 9 |
+
if version.Version(torch.__version__) >= version.Version('1.0.0'):
|
| 10 |
+
from torch._C import _set_worker_pids
|
| 11 |
+
else:
|
| 12 |
+
from torch._C import _update_worker_pids as _set_worker_pids
|
| 13 |
+
|
| 14 |
+
from torch.utils.data import SequentialSampler, RandomSampler, BatchSampler, Sampler
|
| 15 |
+
import signal
|
| 16 |
+
import functools
|
| 17 |
+
import collections.abc
|
| 18 |
+
import re
|
| 19 |
+
import sys
|
| 20 |
+
import threading
|
| 21 |
+
import traceback
|
| 22 |
+
import os
|
| 23 |
+
import time
|
| 24 |
+
# from torch._six import string_classes
|
| 25 |
+
string_classes = str
|
| 26 |
+
|
| 27 |
+
IS_WINDOWS = sys.platform == "win32"
|
| 28 |
+
if IS_WINDOWS:
|
| 29 |
+
import ctypes
|
| 30 |
+
from ctypes.wintypes import DWORD, BOOL, HANDLE
|
| 31 |
+
|
| 32 |
+
if sys.version_info[0] == 2:
|
| 33 |
+
import Queue as queue
|
| 34 |
+
else:
|
| 35 |
+
import queue
|
| 36 |
+
|
| 37 |
+
__all__ = ['SequentialDataLoader']
|
| 38 |
+
|
| 39 |
+
class ExceptionWrapper(object):
  r"""Wraps an exception plus traceback to communicate across threads.

  Only picklable pieces are kept: the exception type and the formatted
  traceback text (a live traceback object cannot cross process
  boundaries).
  """

  def __init__(self, exc_info):
    exc_type, exc_value, exc_tb = exc_info
    self.exc_type = exc_type
    self.exc_msg = "".join(traceback.format_exception(exc_type, exc_value, exc_tb))
| 45 |
+
|
| 46 |
+
|
| 47 |
+
_use_shared_memory = False
|
| 48 |
+
r"""Whether to use shared memory in default_collate"""
|
| 49 |
+
|
| 50 |
+
MANAGER_STATUS_CHECK_INTERVAL = 5.0
|
| 51 |
+
|
| 52 |
+
if IS_WINDOWS:
  # On Windows, the parent ID of the worker process remains unchanged when the manager process
  # is gone, and the only way to check it through OS is to let the worker have a process handle
  # of the manager and ask if the process status has changed.
  class ManagerWatchdog(object):
    """Lets a worker detect that the manager (parent) process has died."""
    def __init__(self):
      self.manager_pid = os.getppid()

      # Declare argument/result types explicitly so ctypes marshals
      # HANDLE/DWORD values correctly (notably on 64-bit Windows).
      self.kernel32 = ctypes.WinDLL('kernel32', use_last_error=True)
      self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)
      self.kernel32.OpenProcess.restype = HANDLE
      self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)
      self.kernel32.WaitForSingleObject.restype = DWORD

      # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx
      SYNCHRONIZE = 0x00100000
      self.manager_handle = self.kernel32.OpenProcess(SYNCHRONIZE, 0, self.manager_pid)

      if not self.manager_handle:
        raise ctypes.WinError(ctypes.get_last_error())

    def is_alive(self):
      # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
      # With a zero timeout, WaitForSingleObject returns WAIT_OBJECT_0 (0)
      # only once the process handle is signaled, i.e. the manager exited.
      return self.kernel32.WaitForSingleObject(self.manager_handle, 0) != 0
else:
  class ManagerWatchdog(object):
    """POSIX variant: when the manager dies the worker is re-parented, so a
    changed parent PID means the original manager process is gone."""
    def __init__(self):
      self.manager_pid = os.getppid()

    def is_alive(self):
      return os.getppid() == self.manager_pid
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def _worker_loop(dataset, index_queue, data_queue, collate_fn, init_fn, worker_id):
  """Main loop of a DataLoader worker process.

  Pulls ``(batch_idx, batch_indices)`` requests from ``index_queue``,
  collates the corresponding dataset items with ``collate_fn`` and puts
  ``(batch_idx, samples)`` — or an ``ExceptionWrapper`` on failure —
  onto ``data_queue``. Exits on the ``None`` sentinel or once the
  manager process is detected dead.
  """
  global _use_shared_memory
  _use_shared_memory = True

  # Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
  # module's handlers are executed after Python returns from C low-level
  # handlers, likely when the same fatal signal happened again already.
  # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
  _set_worker_signal_handlers()

  # Keep workers from competing with the main process for CPU threads.
  torch.set_num_threads(1)

  if init_fn is not None:
    init_fn(worker_id)

  watchdog = ManagerWatchdog()

  while True:
    try:
      r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
    except queue.Empty:
      # Periodic timeout: check the manager is still alive so orphaned
      # workers terminate instead of blocking on the queue forever.
      if watchdog.is_alive():
        continue
      else:
        break
    if r is None:
      # Explicit shutdown sentinel from the manager.
      break
    idx, batch_indices = r
    try:
      samples = collate_fn([dataset[i] for i in batch_indices])
    except Exception:
      # Ship the formatted traceback back to the main process instead of
      # crashing the worker.
      data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
    else:
      data_queue.put((idx, samples))
      # Drop the local reference promptly so shared-memory storage can
      # be released once the main process is done with it.
      del samples
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def _worker_manager_loop(in_queue, out_queue, done_event, pin_memory, device_id):
  """Forward worker results to the main process, optionally pinning memory.

  Runs on a background thread of the main process: moves ``(idx, batch)``
  pairs from ``in_queue`` to ``out_queue``; when ``pin_memory`` is set,
  batches are copied into page-locked host memory for faster GPU
  transfer. Stops on the ``None`` sentinel, or returns silently when a
  queue error occurs after ``done_event`` is set (shutdown in progress).
  """
  if pin_memory:
    # Pinning must target the CUDA device the batches are destined for.
    torch.cuda.set_device(device_id)

  while True:
    try:
      r = in_queue.get()
    except Exception:
      if done_event.is_set():
        # Shutdown race: the queue may already be closed.
        return
      raise
    if r is None:
      break
    if isinstance(r[1], ExceptionWrapper):
      # Worker errors pass straight through; nothing to pin.
      out_queue.put(r)
      continue
    idx, batch = r
    try:
      if pin_memory:
        batch = pin_memory_batch(batch)
    except Exception:
      out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
    else:
      out_queue.put((idx, batch))
| 146 |
+
|
| 147 |
+
# Maps numpy dtype names to the matching CPU tensor constructor; used by
# default_collate when converting batches of numpy scalars to tensors.
numpy_type_map = {
    'float64': torch.DoubleTensor,
    'float32': torch.FloatTensor,
    'float16': torch.HalfTensor,
    'int64': torch.LongTensor,
    'int32': torch.IntTensor,
    'int16': torch.ShortTensor,
    'int8': torch.CharTensor,
    'uint8': torch.ByteTensor,
}
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def default_collate(batch):
  r"""Puts each data field into a tensor with outer dimension batch size.

  Recursively handles tensors, numpy arrays/scalars, ints, floats,
  strings, mappings and sequences; raises TypeError for anything else.
  """

  error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
  elem_type = type(batch[0])
  if isinstance(batch[0], torch.Tensor):
    out = None
    if _use_shared_memory:
      # If we're in a background process, concatenate directly into a
      # shared memory tensor to avoid an extra copy
      numel = sum([x.numel() for x in batch])
      storage = batch[0].storage()._new_shared(numel)
      out = batch[0].new(storage)
    return torch.stack(batch, 0, out=out)
  elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
      and elem_type.__name__ != 'string_':
    elem = batch[0]
    if elem_type.__name__ == 'ndarray':
      # array of string classes and object
      if re.search('[SaUO]', elem.dtype.str) is not None:
        raise TypeError(error_msg.format(elem.dtype))

      return torch.stack([torch.from_numpy(b) for b in batch], 0)
    if elem.shape == ():  # scalars
      # Convert numpy scalars through a Python number into the matching
      # tensor type (see numpy_type_map above).
      py_type = float if elem.dtype.name.startswith('float') else int
      return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
  elif isinstance(batch[0], int):
    return torch.LongTensor(batch)
  elif isinstance(batch[0], float):
    return torch.DoubleTensor(batch)
  elif isinstance(batch[0], string_classes):
    # Strings are returned as-is (a list), not tensorized.
    return batch
  elif isinstance(batch[0], collections.abc.Mapping):
    # Collate each key independently, preserving the dict structure.
    return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
  elif isinstance(batch[0], collections.abc.Sequence):
    # Transpose list-of-tuples into tuple-of-lists and recurse per field.
    transposed = zip(*batch)
    return [default_collate(samples) for samples in transposed]

  raise TypeError((error_msg.format(type(batch[0]))))
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def pin_memory_batch(batch):
    """Recursively copy every tensor in ``batch`` into pinned (page-locked) memory.

    Tensors are pinned; strings are returned untouched (they are also
    Sequences, so they must be checked first); mappings and sequences are
    rebuilt with their elements pinned recursively; anything else passes
    through unchanged.
    """
    if isinstance(batch, torch.Tensor):
        return batch.pin_memory()
    if isinstance(batch, string_classes):
        return batch
    if isinstance(batch, collections.abc.Mapping):
        return {key: pin_memory_batch(value) for key, value in batch.items()}
    if isinstance(batch, collections.abc.Sequence):
        return [pin_memory_batch(item) for item in batch]
    return batch
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
# Process-wide guard so the SIGCHLD handler below is installed at most once,
# no matter how many DataLoader iterators are created.
_SIGCHLD_handler_set = False
r"""Whether SIGCHLD handler is set for DataLoader worker failures. Only one
handler needs to be set for all DataLoaders in a process."""
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def _set_SIGCHLD_handler():
    """Install a SIGCHLD handler that surfaces DataLoader worker failures.

    The handler calls ``_error_if_any_worker_fails()`` so a crashed worker
    process raises in the parent instead of hanging it, then chains to any
    previously installed handler. No-op on Windows, outside the main thread,
    or if already installed.
    """
    # Windows doesn't support SIGCHLD handler
    if sys.platform == 'win32':
        return
    # can't set signal in child threads
    if not isinstance(threading.current_thread(), threading._MainThread):
        return
    global _SIGCHLD_handler_set
    if _SIGCHLD_handler_set:
        return
    # Remember whatever handler was there before so we can chain to it;
    # SIG_DFL/SIG_IGN/None are not callable and are treated as "no handler".
    previous_handler = signal.getsignal(signal.SIGCHLD)
    if not callable(previous_handler):
        previous_handler = None

    def handler(signum, frame):
        # This following call uses `waitid` with WNOHANG from C side. Therefore,
        # Python can still get and update the process status successfully.
        _error_if_any_worker_fails()
        if previous_handler is not None:
            previous_handler(signum, frame)

    signal.signal(signal.SIGCHLD, handler)
    _SIGCHLD_handler_set = True
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
class _SequentialDataLoaderIter(object):
    r"""Iterates once over the DataLoader's dataset, as specified by the sampler.

    With ``num_workers == 0`` batches are built synchronously in this process.
    Otherwise ``num_workers`` subprocesses pull index lists from per-worker
    queues and push ``(idx, batch)`` results back; this iterator keeps at most
    ``2 * num_workers`` batches in flight and reorders out-of-order results so
    callers always receive batches in sampler order.
    """

    def __init__(self, loader):
        # Snapshot everything we need from the loader so the iterator is
        # self-contained.
        self.dataset = loader.dataset
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        # Pinning is pointless without CUDA, so it is silently disabled.
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout
        self.done_event = threading.Event()

        self.sample_iter = iter(self.batch_sampler)

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            # One index queue per worker: indices are dealt round-robin.
            self.index_queues = [multiprocessing.Queue() for _ in range(self.num_workers)]
            self.worker_queue_idx = 0
            self.worker_result_queue = multiprocessing.SimpleQueue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0   # next batch index to hand to a worker
            self.rcvd_idx = 0   # next batch index the caller expects
            self.reorder_dict = {}  # holds batches that arrived early

            self.workers = [
                multiprocessing.Process(
                    target=_worker_loop,
                    args=(self.dataset, self.index_queues[i],
                          self.worker_result_queue, self.collate_fn, self.worker_init_fn, i))
                for i in range(self.num_workers)]

            if self.pin_memory or self.timeout > 0:
                # A manager thread moves results from the worker queue into
                # data_queue, optionally pinning tensors along the way.
                self.data_queue = queue.Queue()
                if self.pin_memory:
                    maybe_device_id = torch.cuda.current_device()
                else:
                    # do not initialize cuda context if not necessary
                    maybe_device_id = None
                self.worker_manager_thread = threading.Thread(
                    target=_worker_manager_loop,
                    args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
                          maybe_device_id))
                self.worker_manager_thread.daemon = True
                self.worker_manager_thread.start()
            else:
                # No pinning and no timeout: read worker results directly.
                self.data_queue = self.worker_result_queue

            for w in self.workers:
                w.daemon = True  # ensure that the worker exits on process exit
                w.start()

            _set_worker_pids(id(self), tuple(w.pid for w in self.workers))
            _set_SIGCHLD_handler()
            self.worker_pids_set = True

            # prime the prefetch loop: keep 2 batches per worker in flight
            for _ in range(2 * self.num_workers):
                self._put_indices()

    def __len__(self):
        return len(self.batch_sampler)

    def _get_batch(self):
        # Blocking read of the next (idx, batch) pair, honoring the timeout.
        if self.timeout > 0:
            try:
                return self.data_queue.get(timeout=self.timeout)
            except queue.Empty:
                raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
        else:
            return self.data_queue.get()

    def __next__(self):
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = pin_memory_batch(batch)
            return batch

        # check if the next sample has already been generated
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)

        if self.batches_outstanding == 0:
            # Sampler exhausted and nothing in flight: we are done.
            self._shutdown_workers()
            raise StopIteration

        while True:
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self._get_batch()
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # store out-of-order samples
                self.reorder_dict[idx] = batch
                continue
            return self._process_next_batch(batch)

    next = __next__  # Python 2 compatibility

    def __iter__(self):
        return self

    def _put_indices(self):
        # Send one more batch of indices to the next worker (round-robin),
        # unless the sampler is exhausted.
        assert self.batches_outstanding < 2 * self.num_workers
        indices = next(self.sample_iter, None)
        if indices is None:
            return
        self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
        self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
        self.batches_outstanding += 1
        self.send_idx += 1

    def _process_next_batch(self, batch):
        # Advance the expected index, refill the pipeline, and re-raise any
        # exception a worker shipped back wrapped in an ExceptionWrapper.
        self.rcvd_idx += 1
        self._put_indices()
        if isinstance(batch, ExceptionWrapper):
            raise batch.exc_type(batch.exc_msg)
        return batch

    def __getstate__(self):
        # TODO: add limited pickling support for sharing an iterator
        # across multiple threads for HOGWILD.
        # Probably the best way to do this is by moving the sample pushing
        # to a separate thread and then just sharing the data queue
        # but signalling the end is tricky without a non-blocking API
        raise NotImplementedError("_SequentialDataLoaderIter cannot be pickled")

    def _shutdown_workers(self):
        try:
            if not self.shutdown:
                self.shutdown = True
                self.done_event.set()
                # A None index tells each worker to exit its loop.
                for q in self.index_queues:
                    q.put(None)
                # if some workers are waiting to put, make place for them
                try:
                    while not self.worker_result_queue.empty():
                        self.worker_result_queue.get()
                except (FileNotFoundError, ImportError):
                    # Many weird errors can happen here due to Python
                    # shutting down. These are more like obscure Python bugs.
                    # FileNotFoundError can happen when we rebuild the fd
                    # fetched from the queue but the socket is already closed
                    # from the worker side.
                    # ImportError can happen when the unpickler loads the
                    # resource from `get`.
                    pass
                # done_event should be sufficient to exit worker_manager_thread,
                # but be safe here and put another None
                self.worker_result_queue.put(None)
        finally:
            # removes pids no matter what
            if self.worker_pids_set:
                _remove_worker_pids(id(self))
                self.worker_pids_set = False

    def __del__(self):
        if self.num_workers > 0:
            self._shutdown_workers()
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
class SequentialDataLoader(object):
    r"""
    Sequential Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.
    This is modified from Pytorch.DataLoader by disable random state touch as for sequential data loading,
    we don't want it to touch any random state.
    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: 1).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: False).
        sampler (Sampler, optional): defines the strategy to draw samples from
            the dataset. If specified, ``shuffle`` must be False.
        batch_sampler (Sampler, optional): like sampler, but returns a batch of
            indices at a time. Mutually exclusive with batch_size, shuffle,
            sampler, and drop_last.
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means that the data will be loaded in the main process.
            (default: 0)
        collate_fn (callable, optional): merges a list of samples to form a mini-batch.
        pin_memory (bool, optional): If ``True``, the data loader will copy tensors
            into CUDA pinned memory before returning them.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: False)
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: 0)
        worker_init_fn (callable, optional): If not None, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
            input, after seeding and before data loading. (default: None)

    .. note:: By default, each worker will have its PyTorch seed set to
              ``base_seed + worker_id``, where ``base_seed`` is a long generated
              by main process using its RNG. However, seeds for other libraries
              may be duplicated upon initializing workers (e.g., NumPy), causing
              each worker to return identical random numbers. (See
              :ref:`dataloader-workers-random-seed` section in FAQ.) You may
              use ``torch.initial_seed()`` to access the PyTorch seed for each
              worker in :attr:`worker_init_fn`, and use it to set other seeds
              before data loading.

    .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
                 unpicklable object, e.g., a lambda function.
    """

    # Class-level default: stays False until __init__ completes, so
    # __setattr__ can tell construction-time writes from later mutation.
    __initialized = False

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if batch_sampler is not None:
            # batch_sampler fully determines batching, so the scalar batching
            # options must be at their defaults.
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            self.batch_size = None
            self.drop_last = None

        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if self.num_workers < 0:
            raise ValueError('num_workers option cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if batch_sampler is None:
            # Build the default sampler chain: element sampler -> BatchSampler.
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.__initialized = True

    def __setattr__(self, attr, val):
        # Freeze the attributes that the (already-built) batch_sampler depends
        # on; changing them after construction would silently have no effect.
        if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
            raise ValueError('{} attribute should not be set after {} is '
                             'initialized'.format(attr, self.__class__.__name__))

        super(SequentialDataLoader, self).__setattr__(attr, val)

    def __iter__(self):
        return _SequentialDataLoaderIter(self)

    def __len__(self):
        # Number of batches, as defined by the batch sampler.
        return len(self.batch_sampler)
|
| 511 |
+
|
nlu/DeBERTa/data/dynamic_dataset.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft, Inc. 2020
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
#
|
| 6 |
+
# Author: penhe@microsoft.com
|
| 7 |
+
# Date: 05/15/2019
|
| 8 |
+
#
|
| 9 |
+
|
| 10 |
+
import pdb
|
| 11 |
+
from torch.utils.data import Dataset
|
| 12 |
+
import random
|
| 13 |
+
import mmap
|
| 14 |
+
import numpy as np
|
| 15 |
+
from bisect import bisect
|
| 16 |
+
from ..utils import get_logger
|
| 17 |
+
logger=get_logger()
|
| 18 |
+
|
| 19 |
+
__all__ = ['DynamicDataset']
|
| 20 |
+
|
| 21 |
+
class DynamicDataset(Dataset):
    """Dataset that builds features on the fly from a corpus of raw examples.

    The logical dataset size may differ from the corpus size: indices wrap
    around the corpus (modulo), so the dataset can be made larger or smaller
    than the underlying corpus. An optional fixed-seed shuffle of the index
    map gives a deterministic example order.
    """
    def __init__(self, corpus, feature_fn, dataset_size=None, shuffle=False, **kwargs):
        # corpus: indexable with corpus[idx, rng, ext_params]; feature_fn
        # converts one raw example (plus an RNG) into model features.
        self.corpus = corpus
        self.ds_len = len(self.corpus)
        logger.info(f'Total corpus examples: {self.ds_len}')
        self.feature_fn = feature_fn

        if not dataset_size:
            self.dataset_size = self.ds_len
        else:
            self.dataset_size = int(dataset_size)

        self.shuffle = shuffle
        # Back the index map with an anonymous mmap so it is not duplicated
        # per DataLoader worker process.
        # NOTE(review): the 8-bytes-per-entry size assumes numpy's `int` dtype
        # is 64-bit (true on 64-bit Linux/macOS; not on Windows with older
        # numpy) -- confirm target platforms.
        index_buf = mmap.mmap(-1, self.dataset_size*8)
        shuffle_idx = np.ndarray(shape=(self.dataset_size, ), buffer=index_buf, dtype=int)
        shuffle_idx[:] = np.arange(self.dataset_size)[:]
        if self.shuffle:
            # Fixed seed => identical shuffled order in every process/run.
            #rng = np.random.RandomState(0)
            rng = random.Random(0)
            rng.shuffle(shuffle_idx)
        self.shuffle_idx = shuffle_idx
        self.index_offset = 0
        if 'index_offset' in kwargs:
            self.index_offset = kwargs['index_offset']

    def __len__(self):
        return self.dataset_size

    def __getitem__(self, idx):
        # idx may arrive as (index, ext_params) from samplers that attach
        # extra per-item parameters.
        if isinstance(idx, tuple) or isinstance(idx, list):
            idx, ext_params = idx
        else:
            ext_params = None
        idx += self.index_offset
        # Seed the per-example RNG with the (offset) index so feature
        # generation for a given position is reproducible.
        seed = idx
        rng = random.Random(seed)
        # get seq length
        # Double modulo: wrap into the dataset, then into the corpus.
        example_idx = self.shuffle_idx[idx%self.dataset_size]%self.ds_len
        example = self.corpus[example_idx, rng, ext_params]
        return self.feature_fn(example, rng, ext_params = ext_params)
|
nlu/DeBERTa/data/example.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import os
|
| 3 |
+
from collections import OrderedDict
|
| 4 |
+
import numpy as np
|
| 5 |
+
import tempfile
|
| 6 |
+
import numpy as np
|
| 7 |
+
import mmap
|
| 8 |
+
import pickle
|
| 9 |
+
import signal
|
| 10 |
+
import sys
|
| 11 |
+
import pdb
|
| 12 |
+
|
| 13 |
+
from ..utils import xtqdm as tqdm
|
| 14 |
+
|
| 15 |
+
__all__=['ExampleInstance', 'example_to_feature', 'ExampleSet']
|
| 16 |
+
|
| 17 |
+
class ExampleInstance:
    """One training example: a sequence of text segments plus an optional label.

    Any extra keyword arguments are attached to the instance as attributes.
    Indexing and ``len()`` delegate to the segment list.
    """

    def __init__(self, segments, label=None, **kwv):
        self.segments = segments
        self.label = label
        for name, value in kwv.items():
            setattr(self, name, value)

    def __repr__(self):
        return 'segments: {}\nlabel: {}'.format(self.segments, self.label)

    def __getitem__(self, i):
        return self.segments[i]

    def __len__(self):
        return len(self.segments)
|
| 31 |
+
|
| 32 |
+
class ExampleSet:
    """An indexable collection of examples stored as pickled blobs.

    Examples are serialized once at construction (into a numpy array of
    bytes) and deserialized lazily on access, which keeps the resident
    footprint small and the container cheap to share across processes.
    """

    def __init__(self, pairs):
        self._data = np.array([pickle.dumps(p) for p in pairs])
        self.total = len(self._data)

    def __getitem__(self, idx):
        """
        return pair
        """
        # idx may be an (index, rng, ext_params) triple; only the index is
        # used here, the rest is accepted for interface compatibility.
        if isinstance(idx, tuple):
            idx, rng, ext_params = idx
        else:
            rng, ext_params = None, None
        return pickle.loads(self._data[idx])

    def __len__(self):
        return self.total

    def __iter__(self):
        return (self[i] for i in range(self.total))
|
| 55 |
+
|
| 56 |
+
def _truncate_segments(segments, max_num_tokens, rng):
|
| 57 |
+
"""
|
| 58 |
+
Truncate sequence pair according to original BERT implementation:
|
| 59 |
+
https://github.com/google-research/bert/blob/master/create_pretraining_data.py#L391
|
| 60 |
+
"""
|
| 61 |
+
while True:
|
| 62 |
+
if sum(len(s) for s in segments)<=max_num_tokens:
|
| 63 |
+
break
|
| 64 |
+
|
| 65 |
+
segments = sorted(segments, key=lambda s:len(s), reverse=True)
|
| 66 |
+
trunc_tokens = segments[0]
|
| 67 |
+
|
| 68 |
+
assert len(trunc_tokens) >= 1
|
| 69 |
+
|
| 70 |
+
if rng.random() < 0.5:
|
| 71 |
+
trunc_tokens.pop(0)
|
| 72 |
+
else:
|
| 73 |
+
trunc_tokens.pop()
|
| 74 |
+
return segments
|
| 75 |
+
|
| 76 |
+
def example_to_feature(tokenizer, example, max_seq_len=512, rng=None, mask_generator = None, ext_params=None, label_type='int', **kwargs):
|
| 77 |
+
if not rng:
|
| 78 |
+
rng = random
|
| 79 |
+
max_num_tokens = max_seq_len - len(example.segments) - 1
|
| 80 |
+
segments = _truncate_segments([tokenizer.tokenize(s) for s in example.segments], max_num_tokens, rng)
|
| 81 |
+
tokens = ['[CLS]']
|
| 82 |
+
type_ids = [0]
|
| 83 |
+
for i,s in enumerate(segments):
|
| 84 |
+
tokens.extend(s)
|
| 85 |
+
tokens.append('[SEP]')
|
| 86 |
+
type_ids.extend([i]*(len(s)+1))
|
| 87 |
+
if mask_generator:
|
| 88 |
+
tokens, lm_labels = mask_generator.mask_tokens(tokens, rng)
|
| 89 |
+
token_ids = tokenizer.convert_tokens_to_ids(tokens)
|
| 90 |
+
pos_ids = list(range(len(token_ids)))
|
| 91 |
+
input_mask = [1]*len(token_ids)
|
| 92 |
+
features = OrderedDict(input_ids = token_ids,
|
| 93 |
+
type_ids = type_ids,
|
| 94 |
+
position_ids = pos_ids,
|
| 95 |
+
input_mask = input_mask)
|
| 96 |
+
if mask_generator:
|
| 97 |
+
features['lm_labels'] = lm_labels
|
| 98 |
+
padding_size = max(0, max_seq_len - len(token_ids))
|
| 99 |
+
for f in features:
|
| 100 |
+
features[f].extend([0]*padding_size)
|
| 101 |
+
features[f] = torch.tensor(features[f], dtype=torch.int)
|
| 102 |
+
label_type = torch.int if label_type=='int' else torch.float
|
| 103 |
+
if example.label is not None:
|
| 104 |
+
features['labels'] = torch.tensor(example.label, dtype=label_type)
|
| 105 |
+
return features
|
nlu/DeBERTa/deberta/__init__.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Author: penhe@microsoft.com
|
| 3 |
+
# Date: 04/25/2019
|
| 4 |
+
#
|
| 5 |
+
|
| 6 |
+
""" Components for NN
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import absolute_import
|
| 10 |
+
from __future__ import division
|
| 11 |
+
from __future__ import print_function
|
| 12 |
+
|
| 13 |
+
from .tokenizers import *
|
| 14 |
+
from .pooling import *
|
| 15 |
+
from .mlm import MLMPredictionHead
|
| 16 |
+
from .nnmodule import NNModule
|
| 17 |
+
from .deberta import *
|
| 18 |
+
from .disentangled_attention import *
|
| 19 |
+
from .ops import *
|
| 20 |
+
from .bert import *
|
| 21 |
+
from .config import *
|
| 22 |
+
from .cache_utils import *
|
nlu/DeBERTa/deberta/bert.py
ADDED
|
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
| 2 |
+
# Copyright (c) Microsoft, Inc. 2020
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the MIT license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# This piece of code is modified based on https://github.com/huggingface/transformers
|
| 8 |
+
|
| 9 |
+
import copy
|
| 10 |
+
import torch
|
| 11 |
+
from torch import nn
|
| 12 |
+
from collections.abc import Sequence
|
| 13 |
+
from packaging import version
|
| 14 |
+
import numpy as np
|
| 15 |
+
import math
|
| 16 |
+
import os
|
| 17 |
+
import pdb
|
| 18 |
+
|
| 19 |
+
import json
|
| 20 |
+
from .ops import *
|
| 21 |
+
from .disentangled_attention import *
|
| 22 |
+
from .da_utils import *
|
| 23 |
+
|
| 24 |
+
from adapterlib import adapter_dict
|
| 25 |
+
|
| 26 |
+
__all__ = ['BertEncoder', 'BertEmbeddings', 'ACT2FN', 'LayerNorm', 'BertLMPredictionHead']
|
| 27 |
+
|
| 28 |
+
class BertSelfOutput(nn.Module):
    """Output projection for self-attention: dense -> dropout -> residual -> LayerNorm.

    When ``config.inject_adapter != 'linear'`` the projection is an adapter
    module looked up in ``adapter_dict`` (built with the same in/out sizes as
    the linear layer it replaces -- see ``adapterlib`` for its semantics).
    """
    def __init__(self, config):
        super().__init__()
        # self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if config.inject_adapter != 'linear':
            self.dense = adapter_dict[config.inject_adapter](config.hidden_size, config.hidden_size, config=config)
        else:
            self.dense = nn.Linear(config.hidden_size, config.hidden_size)

        self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
        self.dropout = StableDropout(config.hidden_dropout_prob)
        self.config = config

    def forward(self, hidden_states, input_states, mask=None):
        # NOTE(review): `mask` is accepted but never used -- MaskedLayerNorm
        # is called without it here.
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        # In-place residual add of the attention input.
        hidden_states += input_states
        hidden_states = MaskedLayerNorm(self.LayerNorm, hidden_states)
        return hidden_states
|
| 47 |
+
|
| 48 |
+
class BertAttention(nn.Module):
    """Disentangled self-attention followed by the attention output block."""
    def __init__(self, config):
        super().__init__()
        self.self = DisentangledSelfAttention(config)
        self.output = BertSelfOutput(config)
        self.config = config

    def forward(self, hidden_states, attention_mask, return_att=False, query_states=None, relative_pos=None, rel_embeddings=None):
        # Returns the attention output, plus the attention matrix when
        # return_att is True.
        output = self.self(hidden_states, attention_mask, return_att, query_states=query_states, relative_pos=relative_pos, rel_embeddings=rel_embeddings)
        # att_logits_ is unpacked but unused beyond this point.
        self_output, att_matrix, att_logits_=output['hidden_states'], output['attention_probs'], output['attention_logits']
        if query_states is None:
            # Self-attention case: the residual connection uses the input itself.
            query_states = hidden_states
        attention_output = self.output(self_output, query_states, attention_mask)

        if return_att:
            return (attention_output, att_matrix)
        else:
            return attention_output
|
| 66 |
+
|
| 67 |
+
class BertIntermediate(nn.Module):
    """Feed-forward expansion: hidden_size -> intermediate_size -> activation.

    The dense layer is swapped for an adapter module when
    ``config.inject_adapter != 'linear'``.
    """
    def __init__(self, config):
        super().__init__()
        # self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if config.inject_adapter != 'linear':
            self.dense = adapter_dict[config.inject_adapter](config.hidden_size, config.intermediate_size, config=config)
        else:
            self.dense = nn.Linear(config.hidden_size, config.intermediate_size)

        # hidden_act may be a string key into ACT2FN or a callable.
        self.intermediate_act_fn = ACT2FN[config.hidden_act] \
            if isinstance(config.hidden_act, str) else config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states
|
| 83 |
+
|
| 84 |
+
class BertOutput(nn.Module):
    """Feed-forward contraction: intermediate_size -> hidden_size, with
    dropout, residual add, and LayerNorm -- mirrors BertSelfOutput.
    """
    def __init__(self, config):
        super(BertOutput, self).__init__()
        # self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        if config.inject_adapter != 'linear':
            self.dense = adapter_dict[config.inject_adapter](config.intermediate_size, config.hidden_size, config=config)
        else:
            self.dense = nn.Linear(config.intermediate_size, config.hidden_size)

        self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
        self.dropout = StableDropout(config.hidden_dropout_prob)
        self.config = config

    def forward(self, hidden_states, input_states, mask=None):
        # NOTE(review): `mask` is accepted but never used -- MaskedLayerNorm
        # is called without it here.
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        # In-place residual add of the attention-block output.
        hidden_states += input_states
        hidden_states = MaskedLayerNorm(self.LayerNorm, hidden_states)
        return hidden_states
|
| 103 |
+
|
| 104 |
+
class BertLayer(nn.Module):
    """One transformer block: attention + feed-forward, each with its own
    residual/LayerNorm sub-block."""
    def __init__(self, config):
        super(BertLayer, self).__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_states, attention_mask, return_att=False, query_states=None, relative_pos=None, rel_embeddings=None):
        attention_output = self.attention(hidden_states, attention_mask, return_att=return_att, \
            query_states=query_states, relative_pos=relative_pos, rel_embeddings=rel_embeddings)
        # When return_att is set the attention module returns a pair.
        if return_att:
            attention_output, att_matrix = attention_output
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output, attention_mask)
        if return_att:
            return (layer_output, att_matrix)
        else:
            return layer_output
|
| 122 |
+
|
| 123 |
+
class ConvLayer(nn.Module):
    """1-D convolution over the sequence dimension with masking, activation,
    residual add, and LayerNorm (DeBERTa's optional conv branch).

    Config knobs (with defaults): ``conv_kernel_size`` (3), ``conv_groups``
    (1), ``conv_act`` ('tanh').
    """
    def __init__(self, config):
        super().__init__()
        kernel_size = getattr(config, 'conv_kernel_size', 3)
        groups = getattr(config, 'conv_groups', 1)
        self.conv_act = getattr(config, 'conv_act', 'tanh')
        # 'same' padding for odd kernel sizes.
        self.conv = torch.nn.Conv1d(config.hidden_size, config.hidden_size, kernel_size, padding = (kernel_size-1)//2, groups = groups)
        self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
        self.dropout = StableDropout(config.hidden_dropout_prob)
        self.config = config

    def forward(self, hidden_states, residual_states, input_mask):
        # Conv1d expects (batch, channels, seq): permute in, convolve, permute back.
        out = self.conv(hidden_states.permute(0,2,1).contiguous()).permute(0,2,1).contiguous()
        # masked_fill_ requires a bool mask on torch >= 1.2; byte before that.
        if version.Version(torch.__version__) >= version.Version('1.2.0a'):
            rmask = (1-input_mask).bool()
        else:
            rmask = (1-input_mask).byte()
        # Zero out positions where input_mask == 0 (padding).
        out.masked_fill_(rmask.unsqueeze(-1).expand(out.size()), 0)
        out = ACT2FN[self.conv_act](self.dropout(out))
        output_states = MaskedLayerNorm(self.LayerNorm, residual_states + out, input_mask)

        return output_states
|
| 145 |
+
|
| 146 |
+
class BertEncoder(nn.Module):
    """Stack of transformer layers with optional relative-position bias
    and an optional convolution branch fused after the first layer.
    """

    def __init__(self, config):
        super().__init__()
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
        self.relative_attention = getattr(config, 'relative_attention', False)
        if self.relative_attention:
            self.max_relative_positions = getattr(config, 'max_relative_positions', -1)
            if self.max_relative_positions < 1:
                self.max_relative_positions = config.max_position_embeddings
            self.position_buckets = getattr(config, 'position_buckets', -1)
            # The embedding table covers both negative and positive offsets,
            # hence the factor of two; bucketing shrinks the table if enabled.
            pos_ebd_size = self.max_relative_positions * 2
            if self.position_buckets > 0:
                pos_ebd_size = self.position_buckets * 2
            self.rel_embeddings = nn.Embedding(pos_ebd_size, config.hidden_size)

        # 'norm_rel_ebd' is a '|'-separated list of normalizations to apply
        # to the relative embeddings (currently only 'layer_norm' is honored).
        self.norm_rel_ebd = [x.strip() for x in getattr(config, 'norm_rel_ebd', 'none').lower().split('|')]
        if 'layer_norm' in self.norm_rel_ebd:
            self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True)
        self.with_conv = getattr(config, 'conv_kernel_size', 0) > 0
        if self.with_conv:
            self.conv = ConvLayer(config)

    def get_rel_embedding(self):
        """Return the (optionally layer-normalized) relative-position embedding
        weights, or None when relative attention is disabled."""
        if not self.relative_attention:
            return None
        weights = self.rel_embeddings.weight
        if 'layer_norm' in self.norm_rel_ebd:
            weights = self.LayerNorm(weights)
        return weights

    def get_attention_mask(self, attention_mask):
        """Expand a 2-D [B, T] mask into a pairwise 4-D byte mask; add a head
        dimension to 3-D masks; pass 4-D masks through unchanged."""
        if attention_mask.dim() <= 2:
            ext = attention_mask.unsqueeze(1).unsqueeze(2)
            # Outer product of the row and column masks: position (i, j) is
            # visible only if both i and j are valid tokens.
            attention_mask = (ext * ext.squeeze(-2).unsqueeze(-1)).byte()
        elif attention_mask.dim() == 3:
            attention_mask = attention_mask.unsqueeze(1)

        return attention_mask

    def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None):
        """Build the relative-position tensor lazily when relative attention
        is on and the caller did not supply one."""
        if self.relative_attention and relative_pos is None:
            q_len = query_states.size(-2) if query_states is not None else hidden_states.size(-2)
            relative_pos = build_relative_position(q_len, hidden_states.size(-2),
                                                   bucket_size=self.position_buckets,
                                                   max_position=self.max_relative_positions,
                                                   device=hidden_states.device)
        return relative_pos

    def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, return_att=False, query_states=None, relative_pos=None):
        """Run all layers; returns {'hidden_states': [...], 'attention_matrices': [...]}."""
        # Recover a per-token validity mask from whatever mask shape we got.
        if attention_mask.dim() <= 2:
            input_mask = attention_mask
        else:
            input_mask = (attention_mask.sum(-2) > 0).byte()
        attention_mask = self.get_attention_mask(attention_mask)
        relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)

        all_encoder_layers = []
        att_matrices = []
        # hidden_states may be a per-layer sequence of inputs instead of a
        # single tensor; in that case each layer consumes its own entry.
        next_kv = hidden_states[0] if isinstance(hidden_states, Sequence) else hidden_states
        rel_embeddings = self.get_rel_embedding()
        for layer_idx, layer_module in enumerate(self.layer):
            output_states = layer_module(next_kv, attention_mask, return_att,
                                         query_states=query_states,
                                         relative_pos=relative_pos,
                                         rel_embeddings=rel_embeddings)
            if return_att:
                output_states, attn = output_states

            # The conv branch is fused only with the first layer's output.
            if layer_idx == 0 and self.with_conv:
                output_states = self.conv(hidden_states, output_states, input_mask)

            if query_states is not None:
                query_states = output_states
            if isinstance(hidden_states, Sequence):
                next_kv = hidden_states[layer_idx + 1] if layer_idx + 1 < len(self.layer) else None
            else:
                next_kv = output_states

            if output_all_encoded_layers:
                all_encoder_layers.append(output_states)
                if return_att:
                    att_matrices.append(attn)
        if not output_all_encoded_layers:
            # Keep only the final layer's output (and attention, if requested).
            all_encoder_layers.append(output_states)
            if return_att:
                att_matrices.append(attn)
        return {
            'hidden_states': all_encoder_layers,
            'attention_matrices': att_matrices,
        }
|
| 239 |
+
|
| 240 |
+
class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings,
    followed by masked layer normalization and dropout.
    """

    def __init__(self, config):
        super(BertEmbeddings, self).__init__()
        pad_idx = getattr(config, 'padding_idx', 0)
        # Embedding width may be factorized (smaller than hidden_size).
        self.embedding_size = getattr(config, 'embedding_size', config.hidden_size)
        self.word_embeddings = nn.Embedding(config.vocab_size, self.embedding_size, padding_idx=pad_idx)
        # When False, position embeddings are still computed (and returned)
        # but not added into the content embedding.
        self.position_biased_input = getattr(config, 'position_biased_input', True)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.embedding_size)

        if config.type_vocab_size > 0:
            self.token_type_embeddings = nn.Embedding(config.type_vocab_size, self.embedding_size)

        if self.embedding_size != config.hidden_size:
            # Project factorized embeddings up to the model width.
            self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False)
        self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
        self.dropout = StableDropout(config.hidden_dropout_prob)
        self.output_to_half = False
        self.config = config

    def forward(self, input_ids, token_type_ids=None, position_ids=None, mask=None):
        seq_len = input_ids.size(1)
        if position_ids is None:
            position_ids = torch.arange(0, seq_len, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids.long())

        if self.config.type_vocab_size > 0:
            embeddings = embeddings + self.token_type_embeddings(token_type_ids)
        if self.position_biased_input:
            embeddings = embeddings + position_embeddings
        if self.embedding_size != self.config.hidden_size:
            embeddings = self.embed_proj(embeddings)
        embeddings = self.dropout(MaskedLayerNorm(self.LayerNorm, embeddings, mask))
        return {
            'embeddings': embeddings,
            'position_embeddings': position_embeddings}
|
| 287 |
+
|
| 288 |
+
class BertLMPredictionHead(nn.Module):
    """Masked-LM output head: dense transform + activation + layer norm,
    then a vocabulary projection tied to the input embedding matrix."""

    def __init__(self, config, vocab_size):
        super().__init__()
        self.embedding_size = getattr(config, 'embedding_size', config.hidden_size)
        self.dense = nn.Linear(config.hidden_size, self.embedding_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act

        self.LayerNorm = LayerNorm(self.embedding_size, config.layer_norm_eps, elementwise_affine=True)

        # Per-vocabulary-entry bias, kept separate from the tied weights.
        self.bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, hidden_states, embeding_weight):
        transformed = self.transform_act_fn(self.dense(hidden_states))
        # b x s x d
        transformed = MaskedLayerNorm(self.LayerNorm, transformed)

        # b x s x v — project onto the (shared) embedding matrix.
        return torch.matmul(transformed, embeding_weight.t().to(transformed)) + self.bias
|
nlu/DeBERTa/deberta/cache_utils.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft, Inc. 2020
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
#
|
| 6 |
+
# Author: penhe@microsoft.com
|
| 7 |
+
# Date: 05/15/2020
|
| 8 |
+
#
|
| 9 |
+
|
| 10 |
+
import pdb
|
| 11 |
+
import torch
|
| 12 |
+
import os
|
| 13 |
+
import requests
|
| 14 |
+
from .config import ModelConfig
|
| 15 |
+
import pathlib
|
| 16 |
+
from ..utils import xtqdm as tqdm
|
| 17 |
+
from zipfile import ZipFile
|
| 18 |
+
from ..utils import get_logger
|
| 19 |
+
logger = get_logger()
|
| 20 |
+
|
| 21 |
+
__all__ = ['pretrained_models', 'load_model_state', 'load_vocab']
|
| 22 |
+
|
| 23 |
+
class PretrainedModel:
    """Descriptor of a released DeBERTa checkpoint hosted on huggingface.co.

    Holds the download URLs for the model weights, config and vocabulary,
    plus the vocabulary type. Extra keyword arguments are attached to the
    instance verbatim.
    """

    def __init__(self, name, vocab, vocab_type, model='pytorch_model.bin', config='config.json', **kwargs):
        self.__dict__.update(kwargs)
        base_url = f'https://huggingface.co/microsoft/{name}/resolve/main/'
        self.name = name
        self.model_url = base_url + model
        self.config_url = base_url + config
        self.vocab_url = base_url + vocab
        self.vocab_type = vocab_type
|
| 32 |
+
|
| 33 |
+
# Registry of released checkpoints, keyed by the short id users pass in.
# v1 models use a GPT-2 BPE vocabulary; v2/v3 models use SentencePiece.
pretrained_models = {
    key: PretrainedModel(model_name, vocab_file, vocab_type)
    for key, model_name, vocab_file, vocab_type in [
        ('base', 'deberta-base', 'bpe_encoder.bin', 'gpt2'),
        ('large', 'deberta-large', 'bpe_encoder.bin', 'gpt2'),
        ('xlarge', 'deberta-xlarge', 'bpe_encoder.bin', 'gpt2'),
        ('base-mnli', 'deberta-base-mnli', 'bpe_encoder.bin', 'gpt2'),
        ('large-mnli', 'deberta-large-mnli', 'bpe_encoder.bin', 'gpt2'),
        ('xlarge-mnli', 'deberta-xlarge-mnli', 'bpe_encoder.bin', 'gpt2'),
        ('xlarge-v2', 'deberta-v2-xlarge', 'spm.model', 'spm'),
        ('xxlarge-v2', 'deberta-v2-xxlarge', 'spm.model', 'spm'),
        ('xlarge-v2-mnli', 'deberta-v2-xlarge-mnli', 'spm.model', 'spm'),
        ('xxlarge-v2-mnli', 'deberta-v2-xxlarge-mnli', 'spm.model', 'spm'),
        ('deberta-v3-small', 'deberta-v3-small', 'spm.model', 'spm'),
        ('deberta-v3-base', 'deberta-v3-base', 'spm.model', 'spm'),
        ('deberta-v3-large', 'deberta-v3-large', 'spm.model', 'spm'),
        ('mdeberta-v3-base', 'mdeberta-v3-base', 'spm.model', 'spm'),
        ('deberta-v3-xsmall', 'deberta-v3-xsmall', 'spm.model', 'spm'),
    ]
}
|
| 50 |
+
|
| 51 |
+
def download_asset(url, name, tag=None, no_cache=False, cache_dir=None):
    """Download a single asset into the local cache and return its path.

    Args:
        url: Source URL.
        name: File name to store the asset under inside ``cache_dir``.
        tag: Cache namespace used when building the default cache directory;
            defaults to ``'latest'``.
        no_cache: When True, re-download even if a cached copy exists.
        cache_dir: Target directory; defaults to ``~/.~DeBERTa/assets/<tag>/``.

    Returns:
        The path of the cached file.

    Raises:
        Exception: If the HTTP request does not return status 200.
    """
    _tag = tag if tag is not None else 'latest'
    if not cache_dir:
        cache_dir = os.path.join(pathlib.Path.home(), f'.~DeBERTa/assets/{_tag}/')
    os.makedirs(cache_dir, exist_ok=True)
    output = os.path.join(cache_dir, name)
    # Cache hit: nothing to download.
    if os.path.exists(output) and (not no_cache):
        return output

    headers = {'Accept': 'application/octet-stream'}
    resp = requests.get(url, stream=True, headers=headers)
    # FIX: close the streamed response in all cases (it leaked before).
    try:
        if resp.status_code != 200:
            raise Exception(f'Request for {url} return {resp.status_code}, {resp.text}')

        try:
            with open(output, 'wb') as fs:
                total = int(resp.headers['Content-Length']) if 'Content-Length' in resp.headers else -1
                progress = tqdm(total=total, ncols=80, desc=f'Downloading {name}')
                for chunk in resp.iter_content(chunk_size=1024 * 1024):
                    fs.write(chunk)
                    progress.update(len(chunk))
                progress.close()
        # FIX: explicit BaseException instead of a bare `except:` clause —
        # same semantics (partial file is removed, error re-raised), clearer intent.
        except BaseException:
            os.remove(output)
            raise
    finally:
        resp.close()

    return output
|
| 81 |
+
|
| 82 |
+
def load_model_state(path_or_pretrained_id, tag=None, no_cache=False, cache_dir=None):
    """Resolve a local path or a pretrained-model id to ``(state_dict, config)``.

    Args:
        path_or_pretrained_id: Either a filesystem path (directory containing
            ``pytorch_model.bin``, or the weights file itself) or one of the
            keys in ``pretrained_models``.
        tag: Cache namespace for downloads, defaults to ``'latest'``.
        no_cache: When True, re-download cached assets.
        cache_dir: Override for the download cache directory.

    Returns:
        Tuple of (state_dict, ModelConfig or None). Returns (None, None) when
        no path/id is given.
    """
    model_path = path_or_pretrained_id
    if model_path and (not os.path.exists(model_path)) and (path_or_pretrained_id.lower() in pretrained_models):
        # Known released checkpoint: download weights + config into the cache.
        _tag = tag if tag is not None else 'latest'
        if 'deberta-v3-base' in path_or_pretrained_id:
            pretrained = pretrained_models['deberta-v3-base']
        else:
            pretrained = pretrained_models[path_or_pretrained_id.lower()]
        if not cache_dir:
            cache_dir = os.path.join(pathlib.Path.home(), f'.~DeBERTa/assets/{_tag}/{pretrained.name}')
        os.makedirs(cache_dir, exist_ok=True)
        model_path = os.path.join(cache_dir, 'pytorch_model.bin')
        if (not os.path.exists(model_path)) or no_cache:
            download_asset(pretrained.model_url, 'pytorch_model.bin', tag=tag, no_cache=no_cache, cache_dir=cache_dir)
            download_asset(pretrained.config_url, 'model_config.json', tag=tag, no_cache=no_cache, cache_dir=cache_dir)
    elif not model_path:
        return None, None

    # BUG FIX: the original unconditionally appended 'pytorch_model.bin',
    # producing '<file>/pytorch_model.bin' when model_path already pointed at
    # the weights file (as the pretrained branch above does). Only append the
    # file name when model_path is a directory.
    if os.path.isdir(model_path):
        model_path = os.path.join(model_path, 'pytorch_model.bin')
    config_path = os.path.join(os.path.dirname(model_path), 'model_config.json')
    model_state = torch.load(model_path, map_location='cpu')
    logger.info("Loaded pretrained model file {}".format(model_path))
    # Prefer a config embedded in the checkpoint; fall back to the JSON file.
    if 'config' in model_state:
        model_config = ModelConfig.from_dict(model_state['config'])
    elif os.path.exists(config_path):
        model_config = ModelConfig.from_json_file(config_path)
    else:
        model_config = None
    return model_state, model_config
|
| 113 |
+
|
| 114 |
+
def load_vocab(vocab_path=None, vocab_type=None, pretrained_id=None, tag=None, no_cache=False, cache_dir=None):
    """Resolve the vocabulary file and its type.

    For a known pretrained id, the vocabulary is downloaded into the local
    cache and its path/type taken from the registry; otherwise the provided
    arguments are passed through. A missing vocab_type defaults to 'spm'.

    Returns:
        Tuple of (vocab_path, vocab_type).
    """
    if pretrained_id and (pretrained_id.lower() in pretrained_models):
        _tag = 'latest' if tag is None else tag

        pretrained = pretrained_models[pretrained_id.lower()]
        if not cache_dir:
            cache_dir = os.path.join(pathlib.Path.home(), f'.~DeBERTa/assets/{_tag}/{pretrained.name}')
        os.makedirs(cache_dir, exist_ok=True)
        vocab_type = pretrained.vocab_type
        url = pretrained.vocab_url
        outname = os.path.basename(url)
        vocab_path = os.path.join(cache_dir, outname)
        if (not os.path.exists(vocab_path)) or no_cache:
            download_asset(url, outname, tag=tag, no_cache=no_cache, cache_dir=cache_dir)
    if vocab_type is None:
        vocab_type = 'spm'
    return vocab_path, vocab_type
|
| 133 |
+
|
| 134 |
+
def test_download():
    # Smoke check: default vocabulary resolution should run without raising.
    result = load_vocab()
|
nlu/DeBERTa/deberta/config.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import copy
|
| 3 |
+
|
| 4 |
+
__all__=['AbsModelConfig', 'ModelConfig']
|
| 5 |
+
|
| 6 |
+
class AbsModelConfig(object):
    """Base class for model configurations backed by a plain attribute dict."""

    def __init__(self):
        pass

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a `ModelConfig` from a Python dictionary of parameters.

        Nested dicts become nested ``AbsModelConfig`` instances.
        """
        config = cls()
        for key, value in json_object.items():
            config.__dict__[key] = AbsModelConfig.from_dict(value) if isinstance(value, dict) else value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `ModelConfig` from a json file of parameters."""
        with open(json_file, "r", encoding='utf-8') as reader:
            return cls.from_dict(json.loads(reader.read()))

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary (deep copy)."""
        return copy.deepcopy(self.__dict__)

    def to_json_string(self):
        """Serializes this instance to an indented JSON string."""
        def _default(obj):
            # Serialize nested configs as their attribute dicts.
            if isinstance(obj, AbsModelConfig):
                return obj.__dict__
        return json.dumps(self.__dict__, indent=2, sort_keys=True, default=_default) + "\n"
|
| 41 |
+
|
| 42 |
+
class ModelConfig(AbsModelConfig):
    """Configuration class to store the configuration of a :class:`~DeBERTa.deberta.DeBERTa` model.

    Attributes:
        hidden_size (int): Size of the encoder layers and the pooler layer, default: `768`.
        num_hidden_layers (int): Number of hidden layers in the Transformer encoder, default: `12`.
        num_attention_heads (int): Number of attention heads for each attention layer in
            the Transformer encoder, default: `12`.
        intermediate_size (int): The size of the "intermediate" (i.e., feed-forward)
            layer in the Transformer encoder, default: `3072`.
        hidden_act (str): The non-linear activation function (function or string) in the
            encoder and pooler. If string, "gelu", "relu" and "swish" are supported, default: `gelu`.
        hidden_dropout_prob (float): The dropout probability for all fully connected
            layers in the embeddings, encoder, and pooler, default: `0.1`.
        attention_probs_dropout_prob (float): The dropout ratio for the attention
            probabilities, default: `0.1`.
        max_position_embeddings (int): The maximum sequence length that this model might
            ever be used with. Typically set this to something large just in case
            (e.g., 512 or 1024 or 2048), default: `512`.
        type_vocab_size (int): The vocabulary size of the `token_type_ids` passed into
            `DeBERTa` model, default: `0` (token-type embeddings disabled).
        initializer_range (float): The stdev of the _normal_initializer for
            initializing all weight matrices, default: `0.02`.
        layer_norm_eps (float): Epsilon used by layer-normalization layers, default: `1e-7`.
        padding_idx (int): The value used to pad input_ids, default: `0`.
        vocab_size (int): Size of the model vocabulary; `-1` means "unset" and must
            be filled in before building a model, default: `-1`.
        relative_attention (:obj:`bool`): Whether use relative position encoding, default: `False`.
        max_relative_positions (int): The range of relative positions [`-max_position_embeddings`,
            `max_position_embeddings`], default: -1, use the same value as `max_position_embeddings`.
        position_biased_input (:obj:`bool`): Whether add absolute position embedding to content embedding, default: `True`.
        pos_att_type (:obj:`str`): The type of relative position attention, it can be a combination of
            [`p2c`, `c2p`, `p2p`], e.g. "p2c", "p2c|c2p", "p2c|c2p|p2p", default: "None".
    """
    # FIX (documentation): the original docstring claimed `type_vocab_size`
    # defaults to -1 (code sets 0) and typed `initializer_range` as int;
    # also fixed typos and documented `layer_norm_eps`/`vocab_size`.

    def __init__(self):
        """Constructs ModelConfig with BERT-base-like defaults."""

        self.hidden_size = 768
        self.num_hidden_layers = 12
        self.num_attention_heads = 12
        self.hidden_act = "gelu"
        self.intermediate_size = 3072
        self.hidden_dropout_prob = 0.1
        self.attention_probs_dropout_prob = 0.1
        self.max_position_embeddings = 512
        self.type_vocab_size = 0
        self.initializer_range = 0.02
        self.layer_norm_eps = 1e-7
        self.padding_idx = 0
        self.vocab_size = -1
|
nlu/DeBERTa/deberta/da_utils.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import pdb
|
| 3 |
+
from functools import lru_cache
|
| 4 |
+
import numpy as np
|
| 5 |
+
import math
|
| 6 |
+
|
| 7 |
+
__all__=['build_relative_position', 'make_log_bucket_position']
|
| 8 |
+
|
| 9 |
+
@lru_cache(maxsize=128)
def make_log_bucket_dict(bucket_size, max_position, device=None):
    """Build a lookup table mapping each relative offset in
    [-max_position, max_position) to its bucket id.

    Offsets whose magnitude is within bucket_size//2 map to themselves;
    farther offsets are compressed logarithmically, preserving the sign.
    Cached: the table depends only on the two sizes (and the device).
    """
    offsets = torch.arange(-max_position, max_position, device=device)
    sign = torch.sign(offsets)
    mid = bucket_size // 2
    # Inside the linear band substitute a safe magnitude (mid-1) so the log
    # below is well-defined; those entries are overwritten by the final where.
    magnitude = torch.where((offsets < mid) & (offsets > -mid),
                            torch.tensor(mid - 1).to(offsets),
                            torch.abs(offsets))
    log_bucket = torch.ceil(torch.log(magnitude / mid) / math.log((max_position - 1) / mid) * (mid - 1)) + mid
    return torch.where(magnitude <= mid, offsets, (log_bucket * sign).to(offsets)).to(torch.long)
|
| 18 |
+
|
| 19 |
+
# Faster version
|
| 20 |
+
def make_log_bucket_position(relative_pos, bucket_size, max_position):
    """Map a tensor of relative offsets to log-bucket ids via a cached table.

    The offsets are clamped to the valid range and shifted to index into
    the table produced by ``make_log_bucket_dict``.
    """
    shifted = torch.clamp(relative_pos, -max_position + 1, max_position - 1) + max_position
    table = make_log_bucket_dict(bucket_size, max_position, relative_pos.device)
    # Broadcast the 1-D table across all leading dimensions of the input.
    for _ in range(relative_pos.dim() - 1):
        table = table.unsqueeze(0)
    expanded = table.expand(list(relative_pos.size())[:-1] + [table.size(-1)])
    return torch.gather(expanded, index=shifted.long(), dim=-1)
|
| 27 |
+
|
| 28 |
+
@lru_cache(maxsize=128)
def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1, device=None):
    """Build a [1, query_size, key_size] tensor of relative positions
    (query index minus key index), optionally log-bucketed.

    Cached: all arguments are hashable scalars/devices.
    """
    q_idx = torch.arange(0, query_size)
    k_idx = torch.arange(0, key_size)
    if device is not None:
        q_idx = q_idx.to(device)
        k_idx = k_idx.to(device)
    rel_pos = q_idx.view(-1, 1) - k_idx.view(1, -1)
    if bucket_size > 0 and max_position > 0:
        rel_pos = make_log_bucket_position(rel_pos, bucket_size, max_position)
    rel_pos = rel_pos[:query_size, :]
    return rel_pos.unsqueeze(0)
|
| 43 |
+
|
| 44 |
+
def build_relative_position_from_abs(query_pos, key_pos, bucket_size=-1, max_position=-1, device=None):
    """Compute pairwise relative positions from explicit absolute positions.

    Tuples are converted to tensors; the result contains q - k for each
    query/key pair, optionally log-bucketed.
    """
    q_pos = torch.tensor(query_pos) if isinstance(query_pos, tuple) else query_pos
    k_pos = torch.tensor(key_pos) if isinstance(key_pos, tuple) else key_pos

    if device is not None:
        q_pos = q_pos.to(device)
        k_pos = k_pos.to(device)
    rel_pos = q_pos.unsqueeze(-1) - k_pos.unsqueeze(-2)
    if bucket_size > 0 and max_position > 0:
        rel_pos = make_log_bucket_position(rel_pos, bucket_size, max_position)
    return rel_pos
|
| 63 |
+
|
| 64 |
+
def test_log_bucket():
    """Debug helper: inspect the log-bucket mapping over the full range.

    BUG FIX: the original built `x` with np.arange and passed the NumPy
    array to make_log_bucket_position, which requires a torch tensor
    (it calls torch.clamp and reads `.device`) and so crashed.
    """
    x = torch.arange(-511, 511)
    y = make_log_bucket_position(x, 128, 512)
    pdb.set_trace()
|
| 68 |
+
|
nlu/DeBERTa/deberta/deberta.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft, Inc. 2020
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
#
|
| 6 |
+
# Author: penhe@microsoft.com
|
| 7 |
+
# Date: 01/15/2020
|
| 8 |
+
#
|
| 9 |
+
|
| 10 |
+
import copy
|
| 11 |
+
import torch
|
| 12 |
+
import os
|
| 13 |
+
|
| 14 |
+
import json
|
| 15 |
+
from .ops import *
|
| 16 |
+
from .bert import *
|
| 17 |
+
from .config import ModelConfig
|
| 18 |
+
from .cache_utils import load_model_state
|
| 19 |
+
import pdb
|
| 20 |
+
|
| 21 |
+
__all__ = ['DeBERTa']
|
| 22 |
+
|
| 23 |
+
class DeBERTa(torch.nn.Module):
    """ DeBERTa encoder

    Input embedding layer plus a stack of transformer layers with
    disentangled attention.

    Parameters:
        config:
            A model config instance used to build a new model; the schema is
            similar to `BertConfig`, see :class:`~DeBERTa.deberta.ModelConfig`.

        pre_trained:
            A pre-trained DeBERTa model: either a physical path or a released
            configuration id, i.e. [**base, large, base_mnli, large_mnli**]

    """

    def __init__(self, config=None, pre_trained=None):
        super().__init__()
        state = None
        if pre_trained is not None:
            state, model_config = load_model_state(pre_trained)
            if config is not None and model_config is not None:
                # Architecture-defining fields come from the checkpoint; the
                # caller's config may override everything else.
                frozen = ('hidden_size',
                          'intermediate_size',
                          'num_attention_heads',
                          'num_hidden_layers',
                          'vocab_size',
                          'max_position_embeddings')
                for key in config.__dict__:
                    if key not in frozen:
                        model_config.__dict__[key] = config.__dict__[key]
            config = copy.copy(model_config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.config = config
        self.pre_trained = pre_trained
        self.apply_state(state)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, output_all_encoded_layers=True, position_ids=None, return_att=False):
        """
        Args:
            input_ids:
                a torch.LongTensor of shape [batch_size, sequence_length] with
                the word token indices in the vocabulary.

            attention_mask:
                optional input mask ([batch_size, sequence_length], 0/1) or
                attention mask ([batch_size, sequence_length, sequence_length])
                indicating which token pairs may attend to each other.
                Defaults to all-ones.

            token_type_ids:
                optional torch.LongTensor of shape [batch_size, sequence_length]
                with segment indices in [0, 1]. Defaults to all zeros.

            output_all_encoded_layers:
                whether to return every encoder layer's output, default: True.

            return_att:
                whether to also return the self-attention matrices.

        Returns:
            A dict with 'hidden_states', 'attention_matrices' (from the
            encoder) merged with 'embeddings' and 'position_embeddings'
            (from the embedding layer).

        Example::

            input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
            attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
            bert = DeBERTa(pre_trained='base')
            encoder_layers = bert(input_ids, attention_mask=attention_mask)

        """

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        ebd_output = self.embeddings(input_ids.to(torch.long), token_type_ids.to(torch.long), position_ids, attention_mask)
        encoder_output = self.encoder(ebd_output['embeddings'],
                                      attention_mask,
                                      output_all_encoded_layers=output_all_encoded_layers,
                                      return_att=return_att)
        # Merge the embedding-layer outputs into the encoder's result dict.
        encoder_output.update(ebd_output)
        return encoder_output

    def apply_state(self, state=None):
        """ Load state from a previously loaded model state dictionary.

        Args:
            state (:obj:`dict`, optional): state dict as returned by
                torch.module.state_dict(), default: `None`. When `None`, the
                pre-trained state given to the constructor is re-loaded
                (no-op if there was none).
        """
        if self.pre_trained is None and state is None:
            return
        if state is None:
            state, config = load_model_state(self.pre_trained)
            self.config = config

        # Checkpoints may prefix parameter names (e.g. 'deberta.'); detect
        # the prefix from the first embedding key encountered.
        prefix = ''
        for key in state:
            if 'embeddings.' in key:
                if not key.startswith('embeddings.'):
                    prefix = key[:key.index('embeddings.')]
                break

        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        self._load_from_state_dict(state, prefix=prefix, local_metadata=None, strict=True,
                                   missing_keys=missing_keys, unexpected_keys=unexpected_keys,
                                   error_msgs=error_msgs)
|
nlu/DeBERTa/deberta/disentangled_attention.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft, Inc. 2020
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
#
|
| 6 |
+
# Author: penhe@microsoft.com
|
| 7 |
+
# Date: 01/15/2020
|
| 8 |
+
#
|
| 9 |
+
|
| 10 |
+
"""
|
| 11 |
+
Disentangled SelfAttention module
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
import math
|
| 16 |
+
import torch
|
| 17 |
+
from torch import nn
|
| 18 |
+
import functools
|
| 19 |
+
import pdb
|
| 20 |
+
|
| 21 |
+
from .ops import *
|
| 22 |
+
from .da_utils import build_relative_position
|
| 23 |
+
|
| 24 |
+
from ..utils import get_logger
|
| 25 |
+
logger=get_logger()
|
| 26 |
+
|
| 27 |
+
from adapterlib import adapter_dict
|
| 28 |
+
|
| 29 |
+
__all__=['DisentangledSelfAttention']
|
| 30 |
+
class DisentangledSelfAttention(nn.Module):
  """Disentangled self-attention (DeBERTa) with optional adapter-injected projections.

  Content-to-content attention is always computed; content-to-position (c2p),
  position-to-content (p2c) and position-to-position (p2p) relative-position
  terms are added according to ``config.pos_att_type``.
  """
  def __init__(self, config):
    super().__init__()
    self.num_attention_heads = config.num_attention_heads
    _attention_head_size = int(config.hidden_size / config.num_attention_heads)
    self.attention_head_size = getattr(config, 'attention_head_size', _attention_head_size)
    self.all_head_size = self.num_attention_heads * self.attention_head_size

    # Q/K/V projections: when config.inject_adapter names an adapter type,
    # replace the plain linear projection with the corresponding module from
    # adapterlib; 'linear' means the vanilla nn.Linear.
    if config.inject_adapter != 'linear':
      self.query_proj = adapter_dict[config.inject_adapter](config.hidden_size, self.all_head_size, config=config)
    else:
      self.query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)

    if config.inject_adapter != 'linear':
      self.key_proj = adapter_dict[config.inject_adapter](config.hidden_size, self.all_head_size, config=config)
    else:
      self.key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)

    if config.inject_adapter != 'linear':
      self.value_proj = adapter_dict[config.inject_adapter](config.hidden_size, self.all_head_size, config=config)
    else:
      self.value_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)

    # Whether positional projections reuse the content Q/K projections.
    self.share_att_key = getattr(config, 'share_att_key', False)
    # e.g. 'c2p|p2c' -> ['c2p', 'p2c']
    self.pos_att_type = [x.strip() for x in getattr(config, 'pos_att_type', 'c2p').lower().split('|')] # c2p|p2c
    self.relative_attention = getattr(config, 'relative_attention', False)

    if self.relative_attention:
      self.position_buckets = getattr(config, 'position_buckets', -1)
      self.max_relative_positions = getattr(config, 'max_relative_positions', -1)
      if self.max_relative_positions <1:
        self.max_relative_positions = config.max_position_embeddings
      # Size of the relative-position embedding table actually consumed.
      self.pos_ebd_size = self.max_relative_positions
      if self.position_buckets>0:
        self.pos_ebd_size = self.position_buckets
        # For backward compatibility

      self.pos_dropout = StableDropout(config.hidden_dropout_prob)

      # Separate positional projections only when the content projections
      # are not shared with the positional terms.
      if (not self.share_att_key):
        if 'c2p' in self.pos_att_type or 'p2p' in self.pos_att_type:
          self.pos_key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
        if 'p2c' in self.pos_att_type or 'p2p' in self.pos_att_type:
          self.pos_query_proj = nn.Linear(config.hidden_size, self.all_head_size)

    self.dropout = StableDropout(config.attention_probs_dropout_prob)
    # Remap DeBERTa-v1 checkpoint keys (in_proj/q_bias/...) on load.
    self._register_load_state_dict_pre_hook(self._pre_load_hook)

  def transpose_for_scores(self, x, attention_heads):
    # (b, l, H*d) -> (b*H, l, d): flatten batch and heads for bmm.
    new_x_shape = x.size()[:-1] + (attention_heads, -1)
    x = x.view(*new_x_shape)
    return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1))

  def forward(self, hidden_states, attention_mask, return_att=False, query_states=None, relative_pos=None, rel_embeddings=None):
    """Compute disentangled attention; returns a dict with 'hidden_states',
    'attention_probs' and 'attention_logits'.

    NOTE(review): `return_att` is accepted but never read in this body —
    callers appear to consume the returned dict directly; confirm.
    """
    # Self-attention unless an explicit query stream is provided.
    if query_states is None:
      query_states = hidden_states
    # Q/K in float32 for numerically stable score computation.
    query_layer = self.transpose_for_scores(self.query_proj(query_states), self.num_attention_heads).float()
    key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads).float()
    value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads)

    rel_att = None
    # Take the dot product between "query" and "key" to get the raw attention scores.
    # scale_factor counts how many score terms are summed (content + each
    # enabled positional term) so the 1/sqrt(d*k) scaling stays calibrated.
    scale_factor = 1
    if 'c2p' in self.pos_att_type:
      scale_factor += 1
    if 'p2c' in self.pos_att_type:
      scale_factor += 1
    if 'p2p' in self.pos_att_type:
      scale_factor += 1
    scale = 1/math.sqrt(query_layer.size(-1)*scale_factor)
    attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)*scale)
    if self.relative_attention:
      rel_embeddings = self.pos_dropout(rel_embeddings)
      rel_att = self.disentangled_attention_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor)

    if rel_att is not None:
      attention_scores = (attention_scores + rel_att)
    # Subtract the (detached) row max before softmax for numerical stability,
    # then cast back to the dtype of hidden_states.
    attention_scores = (attention_scores - attention_scores.max(dim=-1, keepdim=True).values.detach()).to(hidden_states)
    attention_scores = attention_scores.view(-1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1))

    # bxhxlxd
    _attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1)
    attention_probs = self.dropout(_attention_probs)
    context_layer = torch.bmm(attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer)
    # (b*H, l, d) -> (b, l, H*d)
    context_layer = context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), context_layer.size(-1)).permute(0, 2, 1, 3).contiguous()
    new_context_layer_shape = context_layer.size()[:-2] + (-1,)
    context_layer = context_layer.view(*new_context_layer_shape)

    return {
      'hidden_states': context_layer,
      'attention_probs': _attention_probs,
      'attention_logits': attention_scores
      }

  def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor):
    """Sum of the enabled relative-position score terms (c2p/p2c/p2p)."""
    if relative_pos is None:
      q = query_layer.size(-2)
      relative_pos = build_relative_position(q, key_layer.size(-2), bucket_size = self.position_buckets, \
        max_position = self.max_relative_positions, device=query_layer.device)
    # Normalize relative_pos to 4 dims: b x h x q x k.
    if relative_pos.dim()==2:
      relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
    elif relative_pos.dim()==3:
      relative_pos = relative_pos.unsqueeze(1)
    # bxhxqxk
    elif relative_pos.dim()!=4:
      raise ValueError(f'Relative postion ids must be of dim 2 or 3 or 4. {relative_pos.dim()}')

    att_span = self.pos_ebd_size
    relative_pos = relative_pos.long().to(query_layer.device)

    # Slice the central 2*att_span rows of the relative embedding table.
    rel_embeddings = rel_embeddings[self.pos_ebd_size - att_span:self.pos_ebd_size + att_span, :].unsqueeze(0) #.repeat(query_layer.size(0)//self.num_attention_heads, 1, 1)
    if self.share_att_key:
      # Reuse the content Q/K projections for the positional embeddings.
      pos_query_layer = self.transpose_for_scores(self.query_proj(rel_embeddings), self.num_attention_heads)\
        .repeat(query_layer.size(0)//self.num_attention_heads, 1, 1) #.split(self.all_head_size, dim=-1)
      pos_key_layer = self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads)\
        .repeat(query_layer.size(0)//self.num_attention_heads, 1, 1) #.split(self.all_head_size, dim=-1)
    else:
      if 'c2p' in self.pos_att_type or 'p2p' in self.pos_att_type:
        pos_key_layer = self.transpose_for_scores(self.pos_key_proj(rel_embeddings), self.num_attention_heads)\
          .repeat(query_layer.size(0)//self.num_attention_heads, 1, 1) #.split(self.all_head_size, dim=-1)
      if 'p2c' in self.pos_att_type or 'p2p' in self.pos_att_type:
        pos_query_layer = self.transpose_for_scores(self.pos_query_proj(rel_embeddings), self.num_attention_heads)\
          .repeat(query_layer.size(0)//self.num_attention_heads, 1, 1) #.split(self.all_head_size, dim=-1)

    score = 0
    # content->position
    if 'c2p' in self.pos_att_type:
      scale = 1/math.sqrt(pos_key_layer.size(-1)*scale_factor)
      c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2).to(query_layer)*scale)
      # Map relative positions into [0, 2*att_span) gather indices.
      c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span*2-1).squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)])
      c2p_att = torch.gather(c2p_att, dim=-1, index=c2p_pos)
      score += c2p_att

    # position->content
    if 'p2c' in self.pos_att_type or 'p2p' in self.pos_att_type:
      scale = 1/math.sqrt(pos_query_layer.size(-1)*scale_factor)

      if 'p2c' in self.pos_att_type:
        p2c_att = torch.bmm(pos_query_layer.to(key_layer)*scale, key_layer.transpose(-1, -2))
        # NOTE(review): this gathers with c2p_pos, which is only defined when
        # 'c2p' is also enabled — a 'p2c'-only pos_att_type would raise
        # NameError. Confirm the intended pos_att_type combinations.
        p2c_att = torch.gather(p2c_att, dim=-2, index=c2p_pos)
        score += p2c_att

    # position->position
    if 'p2p' in self.pos_att_type:
      pos_query = pos_query_layer[:,:,att_span:,:]
      p2p_att = torch.matmul(pos_query, pos_key_layer.transpose(-1, -2))
      p2p_att = p2p_att.expand(query_layer.size()[:2] + p2p_att.size()[2:])
      if query_layer.size(-2) != key_layer.size(-2):
        # NOTE(review): pos_index is not defined anywhere in this module;
        # this cross-length 'p2p' path would raise NameError — TODO confirm
        # against the upstream DeBERTa implementation.
        p2p_att = torch.gather(p2p_att, dim=-2, index=pos_index.expand(query_layer.size()[:2] + (pos_index.size(-2), p2p_att.size(-1))))
      p2p_att = torch.gather(p2p_att, dim=-1, index=c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)]))
      score += p2p_att

    return score

  def _pre_load_hook(self, state_dict, prefix, local_metadata, strict,
          missing_keys, unexpected_keys, error_msgs):
    """Translate DeBERTa-v1 checkpoint keys to this module's layout.

    v1 stored one fused in_proj weight plus q/v biases; split it into the
    per-projection query/key/value weights used here, then delete the old keys.
    """
    self_state = self.state_dict()
    if ((prefix + 'query_proj.weight') not in state_dict) and ((prefix + 'in_proj.weight') in state_dict):
      v1_proj = state_dict[prefix+'in_proj.weight']
      # (heads*3*d, hidden) -> (heads, 3*d, hidden), then split into q/k/v.
      v1_proj = v1_proj.unsqueeze(0).reshape(self.num_attention_heads, -1, v1_proj.size(-1))
      q,k,v=v1_proj.chunk(3, dim=1)
      state_dict[prefix + 'query_proj.weight'] = q.reshape(-1, v1_proj.size(-1))
      state_dict[prefix + 'key_proj.weight'] = k.reshape(-1, v1_proj.size(-1))
      # v1 had no key bias; reuse this module's (zero-initialized) one.
      state_dict[prefix + 'key_proj.bias'] = self_state['key_proj.bias']
      state_dict[prefix + 'value_proj.weight'] = v.reshape(-1, v1_proj.size(-1))
      v1_query_bias = state_dict[prefix + 'q_bias']
      state_dict[prefix + 'query_proj.bias'] = v1_query_bias
      v1_value_bias = state_dict[prefix +'v_bias']
      state_dict[prefix + 'value_proj.bias'] = v1_value_bias

      v1_pos_key_proj = state_dict[prefix + 'pos_proj.weight']
      state_dict[prefix + 'pos_key_proj.weight'] = v1_pos_key_proj
      v1_pos_query_proj = state_dict[prefix + 'pos_q_proj.weight']
      state_dict[prefix + 'pos_query_proj.weight'] = v1_pos_query_proj
      v1_pos_query_proj_bias = state_dict[prefix + 'pos_q_proj.bias']
      state_dict[prefix + 'pos_query_proj.bias'] = v1_pos_query_proj_bias
      state_dict[prefix + 'pos_key_proj.bias'] = self_state['pos_key_proj.bias']

      del state_dict[prefix + 'in_proj.weight']
      del state_dict[prefix + 'q_bias']
      del state_dict[prefix + 'v_bias']
      del state_dict[prefix + 'pos_proj.weight']
      del state_dict[prefix + 'pos_q_proj.weight']
      del state_dict[prefix + 'pos_q_proj.bias']
|
nlu/DeBERTa/deberta/gpt2_bpe_utils.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Byte pair encoding utilities from GPT-2.
|
| 3 |
+
|
| 4 |
+
Original source: https://github.com/openai/gpt-2/blob/master/src/encoder.py
|
| 5 |
+
Original license: MIT
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from functools import lru_cache
|
| 9 |
+
import json
|
| 10 |
+
import random
|
| 11 |
+
import unicodedata
|
| 12 |
+
|
| 13 |
+
try:
|
| 14 |
+
import regex as re
|
| 15 |
+
except ImportError:
|
| 16 |
+
raise ImportError('Please install regex with: pip install regex')
|
| 17 |
+
|
| 18 |
+
@lru_cache()
def bytes_to_unicode():
    """Build the GPT-2 byte -> unicode-character lookup table.

    Maps every byte value 0-255 to a printable unicode character so the
    reversible BPE code never has to deal with whitespace/control bytes.
    Printable latin-1 bytes map to themselves; the remaining bytes are
    assigned characters starting at code point 256, in increasing byte order.
    """
    # Bytes that already render as visible characters keep their own glyph.
    base = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    codes = list(base)
    shift = 0
    # Every other byte gets the next free code point above 255.
    for byte in range(2 ** 8):
        if byte not in base:
            base.append(byte)
            codes.append(2 ** 8 + shift)
            shift += 1
    return {byte: chr(code) for byte, code in zip(base, codes)}
|
| 39 |
+
|
| 40 |
+
def get_pairs(word):
    """Return the set of adjacent symbol pairs in *word*.

    *word* is a tuple of symbols (symbols being variable-length strings).
    """
    _ = word[0]  # preserve the original's IndexError on an empty word
    return set(zip(word, word[1:]))
|
| 50 |
+
|
| 51 |
+
class Encoder:
    """GPT-2 byte-pair encoder.

    Args:
        encoder: dict mapping BPE token strings to integer ids.
        bpe_merges: ordered list of merge pairs; earlier entries merge first.
        errors: error policy passed to ``bytes.decode`` when decoding.
    """

    def __init__(self, encoder, bpe_merges, errors='replace'):
        self.encoder = encoder
        self.decoder = {v:k for k,v in self.encoder.items()}
        self.errors = errors # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
        # Lower rank = earlier merge = higher priority.
        self.bpe_ranks = dict(zip([tuple(k) for k in bpe_merges], range(len(bpe_merges))))
        self.cache = {}
        self.random = random.Random(0)

        # Should have added re.IGNORECASE so BPE merges can happen for
        # capitalized versions of contractions.
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

    def bpe(self, token):
        """Apply BPE merges to *token*; returns space-joined subword symbols."""
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            # Merge the highest-priority (lowest-rank) adjacent pair first.
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    # `first` does not occur again; keep the tail unchanged.
                    # (Was a bare `except:`, which also swallowed
                    # KeyboardInterrupt/SystemExit.)
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def split_to_words(self, text):
        """Split *text* into pre-tokenization chunks using the GPT-2 pattern."""
        return list(re.findall(self.pat, text))

    def encode(self, text):
        """Encode *text* to a list of BPE token ids."""
        bpe_tokens = []
        for token in self.split_to_words(text):
            # Map raw UTF-8 bytes to printable surrogate chars before BPE.
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Decode a list of BPE token ids back to a text string."""
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
        return text
| 121 |
+
|
| 122 |
+
def get_encoder(encoder, vocab):
    """Build an :class:`Encoder` from a token->id map and a BPE merge list."""
    return Encoder(encoder=encoder, bpe_merges=vocab)
|
| 127 |
+
|
| 128 |
+
def _is_whitespace(char):
|
| 129 |
+
"""Checks whether `chars` is a whitespace character."""
|
| 130 |
+
# \t, \n, and \r are technically contorl characters but we treat them
|
| 131 |
+
# as whitespace since they are generally considered as such.
|
| 132 |
+
if char == " " or char == "\t" or char == "\n" or char == "\r":
|
| 133 |
+
return True
|
| 134 |
+
cat = unicodedata.category(char)
|
| 135 |
+
if cat == "Zs":
|
| 136 |
+
return True
|
| 137 |
+
return False
|
| 138 |
+
|
| 139 |
+
def _is_control(char):
|
| 140 |
+
"""Checks whether `chars` is a control character."""
|
| 141 |
+
# These are technically control characters but we count them as whitespace
|
| 142 |
+
# characters.
|
| 143 |
+
if char == "\t" or char == "\n" or char == "\r":
|
| 144 |
+
return False
|
| 145 |
+
cat = unicodedata.category(char)
|
| 146 |
+
if cat.startswith("C"):
|
| 147 |
+
return True
|
| 148 |
+
return False
|
| 149 |
+
|
| 150 |
+
def _is_punctuation(char):
|
| 151 |
+
"""Checks whether `chars` is a punctuation character."""
|
| 152 |
+
cp = ord(char)
|
| 153 |
+
# We treat all non-letter/number ASCII as punctuation.
|
| 154 |
+
# Characters such as "^", "$", and "`" are not in the Unicode
|
| 155 |
+
# Punctuation class but we treat them as punctuation anyways, for
|
| 156 |
+
# consistency.
|
| 157 |
+
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
|
| 158 |
+
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
|
| 159 |
+
return True
|
| 160 |
+
cat = unicodedata.category(char)
|
| 161 |
+
if cat.startswith("P"):
|
| 162 |
+
return True
|
| 163 |
+
return False
|
nlu/DeBERTa/deberta/gpt2_tokenizer.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# Copyright (c) Microsoft, Inc. 2020
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the MIT license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
# Author: penhe@microsoft.com
|
| 8 |
+
# Date: 01/15/2020
|
| 9 |
+
#
|
| 10 |
+
|
| 11 |
+
# This piece of code is derived from https://github.com/pytorch/fairseq/blob/master/fairseq/data/encoders/gpt2_bpe.py
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import unicodedata
|
| 15 |
+
import os
|
| 16 |
+
from .gpt2_bpe_utils import get_encoder,_is_control,_is_whitespace,_is_punctuation
|
| 17 |
+
from .cache_utils import load_vocab
|
| 18 |
+
|
| 19 |
+
__all__ = ['GPT2Tokenizer']
|
| 20 |
+
|
| 21 |
+
class GPT2Tokenizer(object):
  """ A wrapper of GPT2 tokenizer with similar interface as BERT tokenizer

  Args:

    vocab_file (:obj:`str`, optional):
      The local path of vocabulary package or the release name of vocabulary in `DeBERTa GitHub releases <https://github.com/microsoft/DeBERTa/releases>`_, \
          e.g. "bpe_encoder", default: `None`.

          If it's `None`, then it will download the vocabulary in the latest release from GitHub. The vocabulary file is a \
          state dictionary with three items, "dict_map", "vocab", "encoder" which correspond to three files used in `RoBERTa`, i.e. `dict.txt`, `vocab.txt` and `encoder.json`. \

          The difference between our wrapped GPT2 tokenizer and RoBERTa wrapped tokenizer are,

          - Special tokens, unlike `RoBERTa` which use `<s>`, `</s>` as the `start` token and `end` token of a sentence. We use `[CLS]` and `[SEP]` as the `start` and `end`\
              token of input sentence which is the same as `BERT`.

          - We remapped the token ids in our dictionary with regarding to the new special tokens, `[PAD]` => 0, `[CLS]` => 1, `[SEP]` => 2, `[UNK]` => 3, `[MASK]` => 50264

    do_lower_case (:obj:`bool`, optional):
      Whether to convert inputs to lower case. **Not used in GPT2 tokenizer**.

    special_tokens (:obj:`list`, optional):
      List of special tokens to be added to the end of the vocabulary.
  """
  def __init__(self, vocab_file=None, do_lower_case=True, special_tokens=None):
    self.pad_token='[PAD]'
    self.sep_token='[SEP]'
    self.unk_token='[UNK]'
    self.cls_token='[CLS]'

    # Symbol table: ids are assigned in insertion order, so the four special
    # tokens below receive ids 0..3 as documented in the class docstring.
    self.symbols = []
    self.count = []
    self.indices = {}
    self.pad_token_id = self.add_symbol(self.pad_token)
    self.cls_token_id = self.add_symbol(self.cls_token)
    self.sep_token_id = self.add_symbol(self.sep_token)
    self.unk_token_id = self.add_symbol(self.unk_token)

    # NOTE(review): despite the docstring, vocab_file=None is passed straight
    # to torch.load and would fail; the module-level `load_vocab` import is
    # unused here and may have been meant to resolve/download vocab_file —
    # TODO confirm.
    self.gpt2_encoder = torch.load(vocab_file)
    self.bpe = get_encoder(self.gpt2_encoder['encoder'], self.gpt2_encoder['vocab'])
    # Populate the remapped dictionary from the packaged dict_map
    # (token -> frequency pairs).
    for w,n in self.gpt2_encoder['dict_map']:
      self.add_symbol(w, n)

    # [MASK] is appended after the full vocabulary, giving it the last id.
    self.mask_token='[MASK]'
    self.mask_id = self.add_symbol(self.mask_token)
    self.special_tokens = ['[MASK]', '[SEP]', '[PAD]', '[UNK]', '[CLS]']
    if special_tokens is not None:
      for t in special_tokens:
        self.add_special_token(t)

    # Aliases kept for BERT-tokenizer interface compatibility.
    self.vocab = self.indices
    self.ids_to_tokens = self.symbols

  def tokenize(self, text):
    """ Convert an input text to tokens.

    Args:
      text (:obj:`str`): input text to be tokenized.

    Returns:
      A list of byte tokens where each token represent the byte id in GPT2 byte dictionary

    Example::
      >>> tokenizer = GPT2Tokenizer()
      >>> text = "Hello world!"
      >>> tokens = tokenizer.tokenize(text)
      >>> print(tokens)
      ['15496', '995', '0']
    """
    bpe = self._encode(text)

    return [t for t in bpe.split(' ') if t]

  def convert_tokens_to_ids(self, tokens):
    """ Convert list of tokens to ids.

    Args:
      tokens (:obj:`list<str>`): list of tokens

    Returns:
      List of ids
    """

    return [self.vocab[t] for t in tokens]

  def convert_ids_to_tokens(self, ids):
    """ Convert list of ids to tokens.

    Args:
      ids (:obj:`list<int>`): list of ids

    Returns:
      List of tokens
    """

    tokens = []
    for i in ids:
      tokens.append(self.ids_to_tokens[i])
    return tokens

  def split_to_words(self, text):
    """Split *text* into pre-tokenization word chunks (delegates to BPE)."""
    return self.bpe.split_to_words(text)

  def decode(self, tokens):
    """ Decode list of tokens to text strings.

    Args:
      tokens (:obj:`list<str>`): list of tokens.

    Returns:
      Text string corresponds to the input tokens.

    Example::
      >>> tokenizer = GPT2Tokenizer()
      >>> text = "Hello world!"
      >>> tokens = tokenizer.tokenize(text)
      >>> print(tokens)
      ['15496', '995', '0']
      >>> tokenizer.decode(tokens)
      'Hello world!'
    """
    # Special tokens are skipped; remaining tokens are GPT-2 byte ids.
    return self.bpe.decode([int(t) for t in tokens if t not in self.special_tokens])

  def add_special_token(self, token):
    """Adds a special token to the dictionary.

    Args:
      token (:obj:`str`): The new token/word to be added to the vocabulary.

    Returns:
      The id of new token in the vocabulary.
    """
    self.special_tokens.append(token)
    return self.add_symbol(token)

  def part_of_whole_word(self, token, is_bos=False):
    """Return True if *token* continues the current word (no leading space,
    not a single whitespace/control/punctuation character)."""
    if is_bos:
      return True
    s = self._decode(token)
    if (len(s)==1 and (_is_whitespace(list(s)[0]) or _is_control(list(s)[0]) or _is_punctuation(list(s)[0]))):
      return False

    # GPT-2 marks word starts with a leading space in the decoded text.
    return not s.startswith(' ')

  def sym(self, id):
    """Return the token string for vocabulary id *id*."""
    return self.ids_to_tokens[id]

  def id(self, sym):
    """Return the vocabulary id for token string *sym*."""
    return self.vocab[sym]

  def _encode(self, x: str) -> str:
    # Space-joined string of GPT-2 byte ids.
    return ' '.join(map(str, self.bpe.encode(x)))

  def _decode(self, x: str) -> str:
    return self.bpe.decode(map(int, x.split()))

  def add_symbol(self, word, n=1):
    """Adds a word to the dictionary.

    Args:
      word (:obj:`str`): The new token/word to be added to the vocabulary.
      n (int, optional): The frequency of the word.

    Returns:
      The id of the new word.
    """
    if word in self.indices:
      # Existing entry: just accumulate its frequency count.
      idx = self.indices[word]
      self.count[idx] = self.count[idx] + n
      return idx
    else:
      idx = len(self.symbols)
      self.indices[word] = idx
      self.symbols.append(word)
      self.count.append(n)
      return idx

  def save_pretrained(self, path: str):
    """Serialize the wrapped GPT-2 encoder package to *path* via torch.save."""
    torch.save(self.gpt2_encoder, path)
|
nlu/DeBERTa/deberta/mlm.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
| 2 |
+
# Copyright (c) Microsoft, Inc. 2020
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the MIT license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# This piece of code is modified based on https://github.com/huggingface/transformers
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch import nn
|
| 11 |
+
import pdb
|
| 12 |
+
|
| 13 |
+
from .bert import LayerNorm,ACT2FN
|
| 14 |
+
|
| 15 |
+
__all__ = ['MLMPredictionHead']
|
| 16 |
+
|
| 17 |
+
class MLMPredictionHead(nn.Module):
  """Masked-language-model prediction head.

  Transforms encoder hidden states and projects them onto the vocabulary by
  multiplying with the (tied) input embedding matrix plus a learned bias.

  NOTE(review): `PreLayerNorm` and `MaskedLayerNorm` are not among this
  module's visible imports (only LayerNorm and ACT2FN come from .bert) —
  confirm they are provided elsewhere (e.g. `.ops`) or this class raises
  NameError at construction/forward time.
  """
  def __init__(self, config, vocab_size):
    super().__init__()
    # embedding_size may differ from hidden_size (factorized embeddings).
    self.embedding_size = getattr(config, 'embedding_size', config.hidden_size)
    self.dense = nn.Linear(config.hidden_size, self.embedding_size)
    # Activation can be given as a name (looked up in ACT2FN) or a callable.
    self.transform_act_fn = ACT2FN[config.hidden_act] \
        if isinstance(config.hidden_act, str) else config.hidden_act

    self.LayerNorm = LayerNorm(self.embedding_size, config.layer_norm_eps)
    # Output bias over the vocabulary (the embedding weight carries no bias).
    self.bias = nn.Parameter(torch.zeros(vocab_size))
    self.pre_norm = PreLayerNorm(config)

  def forward(self, hidden_states, embeding_weight):
    # (sic) `embeding_weight` is the tied input-embedding matrix; the
    # misspelled name is kept as it is part of the public interface.
    hidden_states = self.pre_norm(hidden_states)
    hidden_states = self.dense(hidden_states)
    hidden_states = self.transform_act_fn(hidden_states)
    # b x s x d
    hidden_states = MaskedLayerNorm(self.LayerNorm, hidden_states)

    # b x s x v
    logits = torch.matmul(hidden_states, embeding_weight.t().to(hidden_states)) + self.bias
    return logits
|
nlu/DeBERTa/deberta/nnmodule.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pdb
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
import copy
|
| 5 |
+
from torch import nn
|
| 6 |
+
from .config import ModelConfig
|
| 7 |
+
from ..utils import xtqdm as tqdm
|
| 8 |
+
from .cache_utils import load_model_state
|
| 9 |
+
|
| 10 |
+
from ..utils import get_logger
|
| 11 |
+
logger = get_logger()
|
| 12 |
+
|
| 13 |
+
__all__ = ['NNModule']
|
| 14 |
+
|
| 15 |
+
class NNModule(nn.Module):
    """An abstract class to handle weights initialization and a simple
    interface for downloading and loading pretrained models.

    Args:
        config (:obj:`~DeBERTa.deberta.ModelConfig`): The model config to the module
    """

    def __init__(self, config, *inputs, **kwargs):
        super().__init__()
        self.config = config

    def init_weights(self, module):
        """ Apply Gaussian(mean=0, std=`config.initializer_range`) initialization to the module.

        Args:
            module (:obj:`torch.nn.Module`): The module to apply the initialization.

        Example::

            class MyModule(NNModule):
                def __init__(self, config):
                    # Add construction instructions
                    self.bert = DeBERTa(config)

                    # Add other modules
                    ...

                    # Apply initialization
                    self.apply(self.init_weights)
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def export_onnx(self, onnx_path, input):
        # Subclasses that support ONNX export must override this.
        raise NotImplementedError

    @classmethod
    def load_model(cls, model_path, model_config=None, tag=None, no_cache=False, cache_dir=None , *inputs, **kwargs):
        """ Instantiate a sub-class of NNModule from a pre-trained model file.

        Args:
            model_path (:obj:`str`): Path or name of the pre-trained model which can be either,

                - The path of pre-trained model

                - The pre-trained DeBERTa model name in `DeBERTa GitHub releases <https://github.com/microsoft/DeBERTa/releases>`_, i.e. [**base, base_mnli, large, large_mnli**].

                If `model_path` is `None` or `-`, then the method will create a new sub-class without initialing from pre-trained models.

            model_config (:obj:`str`): The path of model config file. If it's `None`, then the method will try to find the config in order:

                1. ['config'] in the model state dictionary.

                2. `model_config.json` aside the `model_path`.

                If it failed to find a config the method will fail.

            tag (:obj:`str`, optional): The release tag of DeBERTa, default: `None`.

            no_cache (:obj:`bool`, optional): Disable local cache of downloaded models, default: `False`.

            cache_dir (:obj:`str`, optional): The cache directory used to save the downloaded models, default: `None`. If it's `None`, then the models will be saved at `$HOME/.~DeBERTa`

        Return:

            :obj:`NNModule` : The sub-class object.
        """
        # Load config
        if model_config:
            config = ModelConfig.from_json_file(model_config)
        else:
            config = None
        model_config = None
        model_state = None
        # '-' or an empty string means "no pretrained weights requested".
        if (model_path is not None) and (model_path.strip() == '-' or model_path.strip()==''):
            model_path = None
        try:
            model_state, model_config = load_model_state(model_path, tag=tag, no_cache=no_cache, cache_dir=cache_dir)
        except Exception as exp:
            raise Exception(f'Failed to get model {model_path}. Exception: {exp}')

        # Merge configs: architecture-defining fields (sizes, vocab, etc.) are
        # kept from the checkpoint config; every other field — or any of those
        # fields that is missing/negative in the checkpoint — is taken from the
        # user-supplied config.
        if config is not None and model_config is not None:
            for k in config.__dict__:
                if k not in ['hidden_size',
                    'intermediate_size',
                    'num_attention_heads',
                    'num_hidden_layers',
                    'vocab_size',
                    'max_position_embeddings'] or (k not in model_config.__dict__) or (model_config.__dict__[k] < 0):
                    model_config.__dict__[k] = config.__dict__[k]
        if model_config is not None:
            config = copy.copy(model_config)
        vocab_size = config.vocab_size
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        if not model_state:
            return model
        # copy state_dict so _load_from_state_dict can modify it
        state_dict = model_state.copy()

        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        metadata = getattr(state_dict, '_metadata', None)
        def load(module, prefix=''):
            # Recursively load parameters, mirroring nn.Module.load_state_dict
            # but tolerating missing/unexpected keys (collected, not raised).
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')
        load(model)
        logger.warning(f'Missing keys: {missing_keys}, unexpected_keys: {unexpected_keys}, error_msgs: {error_msgs}')
        return model
|
nlu/DeBERTa/deberta/ops.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft, Inc. 2020
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
#
|
| 6 |
+
# Author: penhe@microsoft.com
|
| 7 |
+
# Date: 01/15/2020
|
| 8 |
+
#
|
| 9 |
+
|
| 10 |
+
import pdb
|
| 11 |
+
import math
|
| 12 |
+
from packaging import version
|
| 13 |
+
import torch
|
| 14 |
+
from torch.nn import LayerNorm
|
| 15 |
+
from ..utils.jit_tracing import traceable
|
| 16 |
+
|
| 17 |
+
if version.Version(torch.__version__) >= version.Version('1.0.0'):
|
| 18 |
+
from torch import _softmax_backward_data as _softmax_backward_data
|
| 19 |
+
else:
|
| 20 |
+
from torch import softmax_backward_data as _softmax_backward_data
|
| 21 |
+
|
| 22 |
+
__all__ = ['StableDropout', 'MaskedLayerNorm', 'XSoftmax', 'ACT2FN', 'LayerNorm']
|
| 23 |
+
|
| 24 |
+
@traceable
class XSoftmax(torch.autograd.Function):
    """ Masked Softmax which is optimized for saving memory

    Args:
        input (:obj:`torch.tensor`): The input tensor that will apply softmax.
        mask (:obj:`torch.IntTensor`): The mask matrix where 0 indicates that element will be ignored in the softmax calculation.
        dim (int): The dimension that will apply softmax.

    Example::

        import torch
        from DeBERTa.deberta import XSoftmax
        # Make a tensor
        x = torch.randn([4,20,100])
        # Create a mask
        mask = (x>0).int()
        y = XSoftmax.apply(x, mask, dim=-1)
    """

    @staticmethod
    def forward(self, input, mask, dim):
        """Apply softmax over `dim`, ignoring positions where mask == 0."""
        self.dim = dim
        # Invert the mask: rmask marks positions that must be ignored.
        if version.Version(torch.__version__) >= version.Version('1.2.0a'):
            rmask = ~(mask.bool())
        else:
            rmask = (1-mask).byte() # This line is not supported by Onnx tracing.

        # -inf before softmax gives masked positions zero probability.
        output = input.masked_fill(rmask, float('-inf'))
        output = torch.softmax(output, self.dim)
        # Zero masked positions explicitly (a fully masked row would otherwise
        # produce NaN from softmax over all -inf).
        output.masked_fill_(rmask, 0)
        self.save_for_backward(output)
        return output

    @staticmethod
    def backward(self, grad_output):
        """Gradient of the masked softmax w.r.t. its input (mask/dim get None)."""
        output, = self.saved_tensors
        # _softmax_backward_data's last argument changed from a tensor to a
        # dtype in torch 1.11.
        if version.Version(torch.__version__) >= version.Version('1.11.0a'):
            inputGrad = _softmax_backward_data(grad_output, output, self.dim, output.dtype)
        else:
            inputGrad = _softmax_backward_data(grad_output, output, self.dim, output)
        return inputGrad, None, None

    @staticmethod
    def symbolic(g, self, mask, dim):
        """ONNX export path: rebuild forward() from ONNX-traceable ops."""
        import torch.onnx.symbolic_helper as sym_help
        from torch.onnx.symbolic_opset9 import masked_fill, softmax

        # r_mask = (1 - mask) cast to Byte, mirroring the eager-mode rmask.
        mask_cast_value = g.op("Cast", mask, to_i=sym_help.cast_pytorch_to_onnx['Long'])
        r_mask = g.op("Cast", g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value), to_i=sym_help.cast_pytorch_to_onnx['Byte'])
        output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(float('-inf'))))
        output = softmax(g, output, dim)
        return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
|
| 85 |
+
|
| 86 |
+
class DropoutContext(object):
    """Mutable holder for dropout state shared across XDropout invocations.

    Attributes are plain fields mutated by StableDropout/get_mask:
    a dropout probability, an optional cached mask, a probability scale
    factor, and a flag controlling whether the cached mask may be reused.
    """

    def __init__(self):
        self.reuse_mask = True
        self.scale = 1
        self.mask = None
        self.dropout = 0
|
| 92 |
+
|
| 93 |
+
def get_mask(input, local_context):
    """Resolve a (mask, dropout) pair from either a raw probability or a DropoutContext.

    Args:
        input (:obj:`torch.tensor`): tensor whose shape the sampled mask must match.
        local_context: either a float dropout probability or a DropoutContext
            carrying probability, scale and an optionally cached mask.

    Returns:
        tuple: (mask, dropout). `mask` marks elements to drop (True/1) and may
        be None when dropout == 0; `dropout` is the effective probability
        (context probability times its scale).
    """
    if not isinstance(local_context, DropoutContext):
        dropout = local_context
        mask = None
    else:
        dropout = local_context.dropout
        dropout *= local_context.scale
        # Reuse the cached mask only when the context allows it.
        mask = local_context.mask if local_context.reuse_mask else None

    if dropout>0 and mask is None:
        # Sample a fresh drop mask; bool masks require torch >= 1.2, older
        # versions use byte masks.
        if version.Version(torch.__version__) >= version.Version('1.2.0a'):
            mask=(1-torch.empty_like(input).bernoulli_(1-dropout)).bool()
        else:
            mask=(1-torch.empty_like(input).bernoulli_(1-dropout)).byte()

    if isinstance(local_context, DropoutContext):
        # Cache the first sampled mask so later calls can reuse it.
        if local_context.mask is None:
            local_context.mask = mask

    return mask, dropout
|
| 113 |
+
|
| 114 |
+
@traceable
class XDropout(torch.autograd.Function):
    """Inverted dropout as an autograd Function so the drop mask can be cached
    and reused across calls (see StableDropout) instead of resampled each time."""

    @staticmethod
    def forward(ctx, input, local_ctx):
        mask, dropout = get_mask(input, local_ctx)
        # Inverted-dropout scaling keeps the expected activation unchanged.
        ctx.scale=1.0/(1-dropout)
        if dropout>0:
            ctx.save_for_backward(mask)
            return input.masked_fill(mask, 0)*ctx.scale
        else:
            return input

    @staticmethod
    def backward(ctx, grad_output):
        # scale > 1 exactly when dropout was applied in forward (mask saved).
        if ctx.scale > 1:
            mask, = ctx.saved_tensors
            return grad_output.masked_fill(mask, 0)*ctx.scale, None
        else:
            return grad_output, None
|
| 133 |
+
|
| 134 |
+
class StableDropout(torch.nn.Module):
    """ Optimized dropout module for stabilizing the training

    Args:

        drop_prob (float): the dropout probabilities

    """

    def __init__(self, drop_prob):
        super().__init__()
        self.drop_prob = drop_prob
        self.count = 0
        self.context_stack = None

    def forward(self, x):
        """Apply dropout to `x` when training; identity otherwise.

        Args:
            x (:obj:`torch.tensor`): The input tensor to apply dropout
        """
        if not self.training or self.drop_prob <= 0:
            return x
        return XDropout.apply(x, self.get_context())

    def clear_context(self):
        """Drop all cached dropout contexts and reset the cursor."""
        self.count = 0
        self.context_stack = None

    def init_context(self, reuse_mask=True, scale = 1):
        """(Re)initialize the context stack with the given reuse/scale settings."""
        if self.context_stack is None:
            self.context_stack = []
        self.count = 0
        for ctx in self.context_stack:
            ctx.reuse_mask = reuse_mask
            ctx.scale = scale

    def get_context(self):
        """Return the next DropoutContext, or the raw probability when no
        context stack has been initialized."""
        if self.context_stack is None:
            return self.drop_prob
        # Grow the stack lazily; one context per forward call site.
        if self.count >= len(self.context_stack):
            self.context_stack.append(DropoutContext())
        ctx = self.context_stack[self.count]
        ctx.dropout = self.drop_prob
        self.count += 1
        return ctx
|
| 184 |
+
|
| 185 |
+
def MaskedLayerNorm(layerNorm, input, mask = None):
    """Apply `layerNorm` to `input`, then zero out masked positions to avoid
    inaccurate updates to the LayerNorm module.

    Args:
        layerNorm (:obj:`~DeBERTa.deberta.LayerNorm`): LayerNorm module or function
        input (:obj:`torch.tensor`): The input tensor
        mask (:obj:`torch.IntTensor`): The mask to apply on the output of
            LayerNorm where `0` indicates the output of that element will be
            ignored, i.e. set to `0`

    Example::

        # Create a tensor b x n x d
        x = torch.randn([1,10,100])
        m = torch.tensor([[1,1,1,0,0,0,0,0,0,0]], dtype=torch.int)
        LayerNorm = DeBERTa.deberta.LayerNorm(100)
        y = MaskedLayerNorm(LayerNorm, x, m)
    """
    normed = layerNorm(input).to(input)
    if mask is None:
        return normed
    if mask.dim() != input.dim():
        # A 4-d attention mask (b x 1 x 1 x n) is first flattened to b x n,
        # then expanded to b x n x 1 so it broadcasts over the feature dim.
        if mask.dim() == 4:
            mask = mask.squeeze(1).squeeze(1)
        mask = mask.unsqueeze(2)
    return normed * mask.to(normed.dtype)
|
| 211 |
+
|
| 212 |
+
def gelu(x):
    """Implementation of the gelu activation function.
    For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def swish(x):
    """Self-gated activation: x * sigmoid(x)."""
    return x * torch.sigmoid(x)


def linear_act(x):
    """Identity activation."""
    return x


# Name -> activation callable lookup used by config-driven model construction.
ACT2FN = {
    "gelu": torch.nn.functional.gelu,
    "relu": torch.nn.functional.relu,
    "swish": swish,
    "tanh": torch.tanh,
    "linear": linear_act,
    'sigmoid': torch.sigmoid,
}
|
| 227 |
+
|
| 228 |
+
|
nlu/DeBERTa/deberta/pooling.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Author: penhe@microsoft.com
|
| 3 |
+
# Date: 01/25/2019
|
| 4 |
+
#
|
| 5 |
+
"""
|
| 6 |
+
Pooling functions
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from torch import nn
|
| 10 |
+
import copy
|
| 11 |
+
import json
|
| 12 |
+
import pdb
|
| 13 |
+
from .bert import ACT2FN
|
| 14 |
+
from .ops import StableDropout
|
| 15 |
+
from .config import AbsModelConfig
|
| 16 |
+
|
| 17 |
+
__all__ = ['PoolConfig', 'ContextPooler']
|
| 18 |
+
|
| 19 |
+
class PoolConfig(AbsModelConfig):
    """Configuration class to store the configuration of `pool layer`.

    Parameters:

        config (:class:`~DeBERTa.deberta.ModelConfig`): The model config. The field of pool config will be initialized with the `pooling` field in model config.

    Attributes:

        hidden_size (int): Size of the encoder layers and the pooler layer, default: `768`.

        dropout (float): The dropout rate applied on the output of `[CLS]` token,

        hidden_act (:obj:`str`): The activation function of the projection layer, it can be one of ['gelu', 'tanh'].

    Example::

        # Here is the content of an example model config file in json format

        {
          "hidden_size": 768,
          "num_hidden_layers" 12,
          "num_attention_heads": 12,
          "intermediate_size": 3072,
          ...
          "pooling": {
            "hidden_size": 768,
            "hidden_act": "gelu",
            "dropout": 0.1
          }
        }

    """
    def __init__(self, config=None):
        """Constructs PoolConfig.

        Args:
            `config`: the config of the model. The field of pool config will be initialized with the 'pooling' field in model config.
        """
        # Defaults used when no model config is supplied.
        self.hidden_size = 768
        self.dropout = 0
        self.hidden_act = 'gelu'
        if not config:
            return
        pool_config = getattr(config, 'pooling', config)
        if isinstance(pool_config, dict):
            pool_config = AbsModelConfig.from_dict(pool_config)
        # Missing pooling fields fall back to the model-level hidden size and
        # the defaults above.
        self.hidden_size = getattr(pool_config, 'hidden_size', config.hidden_size)
        self.dropout = getattr(pool_config, 'dropout', 0)
        self.hidden_act = getattr(pool_config, 'hidden_act', 'gelu')
|
| 70 |
+
class ContextPooler(nn.Module):
    """Pools a sequence by projecting the first ([CLS]) token's hidden state
    through dropout, a dense layer and the configured activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = StableDropout(config.dropout)
        self.config = config

    def forward(self, hidden_states, mask = None):
        """Return the pooled representation of `hidden_states[:, 0]`."""
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        cls_state = hidden_states[:, 0]
        cls_state = self.dropout(cls_state)
        projected = self.dense(cls_state)
        return ACT2FN[self.config.hidden_act](projected)

    def output_dim(self):
        """Dimensionality of the pooled output (equals hidden_size)."""
        return self.config.hidden_size
|
nlu/DeBERTa/deberta/pretrained_models.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
nlu/DeBERTa/deberta/spm_tokenizer.py
ADDED
|
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft, Inc. 2020
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
#
|
| 6 |
+
# Author: penhe@microsoft.com
|
| 7 |
+
# Date: 11/15/2020
|
| 8 |
+
#
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
import sentencepiece as sp
|
| 12 |
+
import six
|
| 13 |
+
import unicodedata
|
| 14 |
+
import os
|
| 15 |
+
import regex as re
|
| 16 |
+
from .cache_utils import load_vocab
|
| 17 |
+
from ..utils import get_logger
|
| 18 |
+
logger=get_logger()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
import pdb
|
| 22 |
+
|
| 23 |
+
__all__ = ['SPMTokenizer']
|
| 24 |
+
|
| 25 |
+
class SPMTokenizer:
    """SentencePiece-based tokenizer used by DeBERTa models.

    Wraps a SentencePiece model file and augments its vocabulary with
    BERT-style special tokens ([CLS], [SEP], [PAD], [UNK], [MASK]).

    Args:
        vocab_file (str): path to the SentencePiece model file (must exist).
        do_lower_case (bool): accepted for interface compatibility; not used here.
        special_tokens (list, optional): extra special tokens to register.
        bpe_dropout (float): accepted for interface compatibility; not used here.
        split_by_punct (bool): when True, split text on punctuation before
            running SentencePiece.

    Fix: the original code wrote `self.id_to_tokens` but read
    `self.ids_to_tokens` in `convert_ids_to_tokens`/`sym`, raising
    AttributeError. Both names now refer to the same list.
    """

    def __init__(self, vocab_file, do_lower_case=False, special_tokens=None, bpe_dropout=0, split_by_punct=False):
        self.split_by_punct = split_by_punct
        spm = sp.SentencePieceProcessor()
        assert os.path.exists(vocab_file)
        spm.load(vocab_file)
        bpe_vocab_size = spm.GetPieceSize()
        # Token map
        # <unk> 0+1
        # <s> 1+1
        # </s> 2+1
        self.vocab = {spm.IdToPiece(i): i for i in range(bpe_vocab_size)}
        self.ids_to_tokens = [spm.IdToPiece(i) for i in range(bpe_vocab_size)]
        # Backward-compatible alias (same list object) for code that used the
        # old attribute name.
        self.id_to_tokens = self.ids_to_tokens

        _special_tokens = ['[MASK]', '[SEP]', '[PAD]', '[UNK]', '[CLS]']
        self.special_tokens = []
        if special_tokens is not None:
            _special_tokens.extend(special_tokens)
        for t in _special_tokens:
            self.add_special_token(t)

        self.spm = spm
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

    def tokenize(self, text):
        """Tokenize `text` into SentencePiece pieces, mapping OOV to [UNK]."""
        pieces = self._encode_as_pieces(text)
        def _norm(x):
            if x not in self.vocab or x == '<unk>':
                return '[UNK]'
            else:
                return x
        pieces = [_norm(p) for p in pieces]
        return pieces

    def convert_tokens_to_ids(self, tokens):
        """Map tokens to vocabulary ids; unknown tokens map to id 1."""
        return [self.vocab[t] if t in self.vocab else 1 for t in tokens]

    def convert_ids_to_tokens(self, ids):
        """Map vocabulary ids back to token strings."""
        tokens = []
        for i in ids:
            tokens.append(self.ids_to_tokens[i])
        return tokens

    def decode(self, tokens, start=-1, end=-1, raw_text=None):
        """Decode tokens to text.

        Without `raw_text`, simply decodes the pieces (special tokens dropped).
        With `raw_text`, maps the [start, end) token span back onto the words
        of the original text and returns the corresponding substring.
        """
        if raw_text is None:
            return self.spm.decode_pieces([t for t in tokens if t not in self.special_tokens])
        else:
            words = self.split_to_words(raw_text)
            word_tokens = [self.tokenize(w) for w in words]
            wt = [w for t in word_tokens for w in t]
            if wt != tokens:
                # Re-tokenization mismatch; previously this dropped into pdb.
                logger.warning(f'Mismatch between tokens and re-tokenized words: {tokens} || {wt}')
            # token index -> word index.
            token2words = [0]*len(tokens)
            tid = 0
            for i, w in enumerate(word_tokens):
                for k, t in enumerate(w):
                    token2words[tid] = i
                    tid += 1
            word_start = token2words[start]
            word_end = token2words[end] if end < len(tokens) else len(words)
            text = ''.join(words[word_start:word_end])
            return text

    def add_special_token(self, token):
        """Register `token` as special, adding it to the vocab if needed.

        Returns:
            The id of the token.
        """
        if token not in self.special_tokens:
            self.special_tokens.append(token)
            if token not in self.vocab:
                self.vocab[token] = len(self.vocab)
                self.ids_to_tokens.append(token)
        return self.id(token)

    def part_of_whole_word(self, token, is_bos=False):
        """True when `token` continues the current word (no SPM word-start marker)."""
        if is_bos:
            return True
        if (len(token)==1 and (_is_whitespace(list(token)[0]) or _is_control(list(token)[0]) or _is_punctuation(list(token)[0]))) or token in self.special_tokens:
            return False

        word_start = b'\xe2\x96\x81'.decode('utf-8')
        return not token.startswith(word_start)

    def pad(self):
        return '[PAD]'

    def bos(self):
        return '[CLS]'

    def eos(self):
        return '[SEP]'

    def unk(self):
        return '[UNK]'

    def mask(self):
        return '[MASK]'

    def sym(self, id):
        """Return the token string for `id`."""
        return self.ids_to_tokens[id]

    def id(self, sym):
        """Return the id for token `sym`; unknown tokens map to 1."""
        return self.vocab[sym] if sym in self.vocab else 1

    def _encode_as_pieces(self, text):
        """Run SentencePiece, optionally splitting on punctuation first."""
        text = convert_to_unicode(text)
        if self.split_by_punct:
            words = self._run_split_on_punc(text)
            pieces = [self.spm.encode_as_pieces(w) for w in words]
            return [p for w in pieces for p in w]
        else:
            return self.spm.encode_as_pieces(text)

    def split_to_words(self, text):
        """Split `text` into word strings aligned with SentencePiece pieces."""
        pieces = self._encode_as_pieces(text)
        # U+2581 LOWER ONE EIGHTH BLOCK: SentencePiece's word-start marker.
        word_start = b'\xe2\x96\x81'.decode('utf-8')
        words = []
        offset = 0
        prev_end = 0
        for i, p in enumerate(pieces):
            if p.startswith(word_start):
                if offset > prev_end:
                    words.append(text[prev_end:offset])
                    prev_end = offset
                w = p.replace(word_start, '')
            else:
                w = p
            try:
                s = text.index(w, offset)
                # Look ahead for the next non-empty piece: if it also occurs
                # between offset and s, only advance by one character.
                pn = ""
                k = i+1
                while k < len(pieces):
                    pn = pieces[k].replace(word_start, '')
                    if len(pn) > 0:
                        break
                    k += 1

                if len(pn) > 0 and pn in text[offset:s]:
                    offset = offset + 1
                else:
                    offset = s + len(w)
            except ValueError:
                # Piece not found in the remaining text; advance one character.
                offset = offset + 1

        if prev_end < offset:
            words.append(text[prev_end:offset])

        return words

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text):
        """Splits punctuation on a piece of text."""
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                # Punctuation always forms its own chunk.
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and handled
        # like the all of the other languages.
        if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
            (cp >= 0x3400 and cp <= 0x4DBF) or #
            (cp >= 0x20000 and cp <= 0x2A6DF) or #
            (cp >= 0x2A700 and cp <= 0x2B73F) or #
            (cp >= 0x2B740 and cp <= 0x2B81F) or #
            (cp >= 0x2B820 and cp <= 0x2CEAF) or
            (cp >= 0xF900 and cp <= 0xFAFF) or #
            (cp >= 0x2F800 and cp <= 0x2FA1F)): #
            return True

        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xfffd or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def _is_whitespace(char):
|
| 260 |
+
"""Checks whether `chars` is a whitespace character."""
|
| 261 |
+
# \t, \n, and \r are technically contorl characters but we treat them
|
| 262 |
+
# as whitespace since they are generally considered as such.
|
| 263 |
+
if char == " " or char == "\t" or char == "\n" or char == "\r":
|
| 264 |
+
return True
|
| 265 |
+
cat = unicodedata.category(char)
|
| 266 |
+
if cat == "Zs":
|
| 267 |
+
return True
|
| 268 |
+
return False
|
| 269 |
+
|
| 270 |
+
def _is_control(char):
|
| 271 |
+
"""Checks whether `chars` is a control character."""
|
| 272 |
+
# These are technically control characters but we count them as whitespace
|
| 273 |
+
# characters.
|
| 274 |
+
if char == "\t" or char == "\n" or char == "\r":
|
| 275 |
+
return False
|
| 276 |
+
cat = unicodedata.category(char)
|
| 277 |
+
if cat.startswith("C"):
|
| 278 |
+
return True
|
| 279 |
+
return False
|
| 280 |
+
|
| 281 |
+
def _is_punctuation(char):
|
| 282 |
+
"""Checks whether `chars` is a punctuation character."""
|
| 283 |
+
cp = ord(char)
|
| 284 |
+
# We treat all non-letter/number ASCII as punctuation.
|
| 285 |
+
# Characters such as "^", "$", and "`" are not in the Unicode
|
| 286 |
+
# Punctuation class but we treat them as punctuation anyways, for
|
| 287 |
+
# consistency.
|
| 288 |
+
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
|
| 289 |
+
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
|
| 290 |
+
return True
|
| 291 |
+
cat = unicodedata.category(char)
|
| 292 |
+
if cat.startswith("P"):
|
| 293 |
+
return True
|
| 294 |
+
return False
|
| 295 |
+
|
| 296 |
+
def whitespace_tokenize(text):
  """Strip *text* and split it on runs of whitespace; [] for blank input."""
  stripped = text.strip()
  return stripped.split() if stripped else []
|
| 303 |
+
|
| 304 |
+
def convert_to_unicode(text):
  """Convert *text* to `str`, decoding bytes as UTF-8 (errors ignored).

  Args:
    text: a `str` (returned unchanged) or `bytes` (decoded as UTF-8).

  Returns:
    The text as a `str`.

  Raises:
    ValueError: if *text* is neither `str` nor `bytes`.

  Note: the original routed through `six` with a Python-2 branch that
  referenced the undefined name `unicode`; this codebase is Python-3 only
  (it uses f-strings elsewhere), so the dependency on `six` is dropped
  while preserving the Python-3 behavior exactly.
  """
  if isinstance(text, str):
    return text
  if isinstance(text, bytes):
    return text.decode("utf-8", "ignore")
  raise ValueError("Unsupported string type: %s" % (type(text)))
|
| 322 |
+
|
nlu/DeBERTa/deberta/tokenizers.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
# Author: penhe@microsoft.com
# Date: 04/25/2019
#

""" tokenizers
"""

from .spm_tokenizer import *
from .gpt2_tokenizer import GPT2Tokenizer

__all__ = ['tokenizers']
# Registry mapping a vocabulary-type name to its tokenizer class; callers
# look up the class by the `vocab_type` string (e.g. tokenizers['spm']).
tokenizers={
    'gpt2': GPT2Tokenizer,
    'spm': SPMTokenizer
    }
|
nlu/DeBERTa/optims/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# This source code is licensed under the MIT license found in the
|
| 3 |
+
# LICENSE file in the root directory of this source tree.
|
| 4 |
+
#
|
| 5 |
+
# Author: Pengcheng He (penhe@microsoft.com)
|
| 6 |
+
# Date: 05/15/2019
|
| 7 |
+
#
|
| 8 |
+
|
| 9 |
+
""" optimizers
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from .xadam import XAdam
|
| 13 |
+
from .fp16_optimizer import *
|
| 14 |
+
from .lr_schedulers import SCHEDULES
|
| 15 |
+
from .args import get_args
|
| 16 |
+
|
nlu/DeBERTa/optims/args.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# This source code is licensed under the MIT license found in the
|
| 3 |
+
# LICENSE file in the root directory of this source tree.
|
| 4 |
+
#
|
| 5 |
+
# Author: Pengcheng He (penhe@microsoft.com)
|
| 6 |
+
# Date: 05/15/2019
|
| 7 |
+
#
|
| 8 |
+
|
| 9 |
+
""" Arguments for optimizer
|
| 10 |
+
"""
|
| 11 |
+
import argparse
|
| 12 |
+
from ..utils import boolean_string
|
| 13 |
+
|
| 14 |
+
__all__ = ['get_args']
|
| 15 |
+
def get_args():
  """Build and return an ArgumentParser holding optimizer hyper-parameters.

  Despite the name, this returns the parser itself (with ``add_help=False``
  so it can be composed as a parent parser), not parsed arguments. All
  options live in one "Optimizer" argument group.
  """
  parser=argparse.ArgumentParser(add_help=False, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  group = parser.add_argument_group(title='Optimizer', description='Parameters for the distributed optimizer')
  # Mixed-precision switch; boolean_string parses "true"/"false" strings.
  group.add_argument('--fp16',
          default=False,
          type=boolean_string,
          help="Whether to use 16-bit float precision instead of 32-bit")

  group.add_argument('--loss_scale',
          type=float, default=16384,
          help='Loss scaling, positive power of 2 values can improve fp16 convergence.')

  group.add_argument('--scale_steps',
          type=int, default=250,
          help='The steps to wait to increase the loss scale.')

  # Lookahead optimizer (slow/fast weights); k<=0 disables it.
  group.add_argument('--lookahead_k',
          default=-1,
          type=int,
          help="lookahead k parameter")

  group.add_argument('--lookahead_alpha',
          default=0.5,
          type=float,
          help="lookahead alpha parameter")

  group.add_argument('--with_radam',
          default=False,
          type=boolean_string,
          help="whether to use RAdam")

  group.add_argument('--opt_type',
          type=str.lower,
          default='adam',
          choices=['adam', 'admax'],
          help="The optimizer to be used.")

  group.add_argument("--warmup_proportion",
          default=0.1,
          type=float,
          help="Proportion of training to perform linear learning rate warmup for. "
             "E.g., 0.1 = 10%% of training.")

  group.add_argument("--lr_schedule_ends",
          default=0,
          type=float,
          help="The ended learning rate scale for learning rate scheduling")

  group.add_argument("--lr_schedule",
          default='warmup_linear',
          type=str,
          help="The learning rate scheduler used for traning. " +
             "E.g. warmup_linear, warmup_linear_shift, warmup_cosine, warmup_constant. Default, warmup_linear")

  group.add_argument("--max_grad_norm",
          default=1,
          type=float,
          help="The clip threshold of global gradient norm")

  group.add_argument("--learning_rate",
          default=5e-5,
          type=float,
          help="The initial learning rate for Adam.")

  group.add_argument("--epsilon",
          default=1e-6,
          type=float,
          help="epsilon setting for Adam.")

  group.add_argument("--adam_beta1",
          default=0.9,
          type=float,
          help="The beta1 parameter for Adam.")

  group.add_argument("--adam_beta2",
          default=0.999,
          type=float,
          help="The beta2 parameter for Adam.")

  group.add_argument('--weight_decay',
          type=float,
          default=0.01,
          help="The weight decay rate")

  return parser
|
| 100 |
+
|
nlu/DeBERTa/optims/fp16_optimizer.py
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# This source code is licensed under the MIT license found in the
|
| 3 |
+
# LICENSE file in the root directory of this source tree.
|
| 4 |
+
#
|
| 5 |
+
# Author: Pengcheng He (penhe@microsoft.com)
|
| 6 |
+
# Date: 05/15/2019
|
| 7 |
+
#
|
| 8 |
+
|
| 9 |
+
""" FP16 optimizer wrapper
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from collections import defaultdict
|
| 13 |
+
import numpy as np
|
| 14 |
+
import math
|
| 15 |
+
import torch
|
| 16 |
+
import pdb
|
| 17 |
+
import torch.distributed as dist
|
| 18 |
+
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
| 19 |
+
import ctypes
|
| 20 |
+
|
| 21 |
+
from ..utils import get_logger,boolean_string
|
| 22 |
+
logger=get_logger()
|
| 23 |
+
|
| 24 |
+
__all__ = ['Fp16Optimizer', 'ExpLossScaler', 'get_world_size']
|
| 25 |
+
|
| 26 |
+
def get_world_size():
  """Return the distributed world size, or 1 in single-process mode.

  `dist.get_world_size()` raises when the default process group has not
  been initialized; in that case we fall back to 1. The original used a
  bare ``except:``, which also swallowed KeyboardInterrupt/SystemExit —
  narrowed to ``except Exception``.
  """
  try:
    return dist.get_world_size()
  except Exception:
    return 1
|
| 32 |
+
|
| 33 |
+
def fused_norm(input):
  """Compute the L2 norm of *input*, accumulating in float32.

  The explicit dtype keeps the reduction accurate even when *input* is
  a half-precision tensor; the result is always float32.
  """
  return torch.norm(input, dtype=torch.float32, p=2)
|
| 35 |
+
|
| 36 |
+
class OptParameter(torch.Tensor):
    # Tensor subclass pairing an optimizer "master" buffer (`data`) with an
    # optional lower-precision mirror (`out_data`) that backs the model's
    # real parameters. The gradient is stored on `_xgrad` (exposed via the
    # `grad` property) rather than the usual Tensor `.grad` slot.
    def __new__(cls, data, out_data=None, grad=None, name=None):
        # _make_subclass shares `data`'s storage with the new instance.
        param = torch.Tensor._make_subclass(cls, data)
        param._xgrad = grad
        param.out_data = out_data
        param._name = name
        return param

    @property
    def name(self):
        # Read-only label (e.g. 'master') used for bookkeeping/debugging.
        return self._name

    @property
    def grad(self):
        return self._xgrad

    @grad.setter
    def grad(self, grad):
        # Overrides Tensor.grad so the optimizer can attach a flattened
        # gradient tensor of its own choosing.
        self._xgrad = grad
|
| 55 |
+
|
| 56 |
+
class Fp16Optimizer(object):
  """Mixed-precision optimizer wrapper.

  Flattens each parameter group into one contiguous buffer, keeps an fp32
  "master" copy for fp16 groups, runs the inner optimizer on the master
  parameters, and copies results back to the fp16 tensors that back the
  model. Optionally performs bucketed gradient all-reduce (distributed)
  and lookahead weight averaging.

  Bug fixed vs. original: `_grad_scale`'s inner `_reduce` built the
  single-gradient case as ``grads[0],view(-1)`` (a comma instead of a
  dot — a tuple plus a NameError at runtime); the parallel code at the
  end of `_grad_scale` shows the intended ``grads[0].view(-1)``.
  """

  def __init__(self, param_groups, optimizer_fn, loss_scaler=None, grad_clip_norm = 1.0, lookahead_k = -1, lookahead_alpha = 0.5, rank=-1, distributed=False):
    # All parameters of a group are flattened together; they are assumed
    # to live on the same device.
    groups = []
    original_groups = []
    self.rank = rank
    self.distributed = distributed
    if self.rank < 0:
      self.distributed = False
    for group in param_groups:
      if 'offset' not in group:
        group['offset'] = None
      if ('rank' not in group) or (not self.distributed):
        group['rank'] = -1
        assert group['offset'] is None, f"{group['names']}: {group['offset']}"
      group_rank = group['rank']
      params = group['params']
      if len(params) > 1:
        # Flatten, then re-point each parameter at its view of the buffer
        # so model and optimizer share storage.
        flattened_params = _flatten_dense_tensors([p.data for p in params])
        unflattend_params = _unflatten_dense_tensors(flattened_params, [p.data for p in params])
        for uf, p in zip(unflattend_params, params):
          p.data = uf
      else:
        flattened_params = params[0].data.view(-1)
        if group['offset'] is not None:
          start, length = group['offset']
          flattened_params = flattened_params.narrow(0, start, length)

      if params[0].dtype == torch.half:
        # fp16 group: optimizer state lives in an fp32 master copy.
        # Ranks that don't own the group keep the master copy on CPU to
        # save device memory.
        if self.rank == group_rank or (not self.distributed):
          master_params = flattened_params.clone().to(torch.float).detach_().to(flattened_params.device)
        else:
          master_params = flattened_params.clone().to(torch.float).detach_().cpu()
        group['params'] = [OptParameter(master_params, flattened_params, name='master')]
      else:
        group['params'] = [OptParameter(flattened_params, None, name='master')]

      # Remember the original (unflattened) group for gradient collection.
      o_group = defaultdict(list)
      o_group['names'] = group['names']
      o_group['params'] = params
      o_group['rank'] = group_rank
      o_group['offset'] = group['offset']

      group['names'] = ['master']

      original_groups.append(o_group)
      groups.append(group)
    self.param_groups = groups
    self.loss_scaler = loss_scaler
    self.optimizer = optimizer_fn(self.param_groups)
    self.original_param_groups = original_groups
    self.max_grad_norm = grad_clip_norm
    self.lookahead_k = lookahead_k
    self.lookahead_alpha = lookahead_alpha

  def backward(self, loss):
    """Scale *loss* (when a loss scaler is configured) and backpropagate.

    Returns:
      (loss_scale, step_loss): the scale applied and the unscaled loss value.
    """
    if self.loss_scaler:
      loss_scale, loss, step_loss = self.loss_scaler.scale(loss)
    else:
      loss_scale = 1
      step_loss = loss.item()

    loss.backward()
    return loss_scale, step_loss

  def step(self, lr_scale, loss_scale = 1):
    """Apply one optimizer step.

    Returns False (and shrinks the loss scale) when gradients overflowed
    and the step was skipped; True otherwise.
    """
    grad_scale = self._grad_scale(loss_scale)
    if grad_scale is None or math.isinf(grad_scale):
      # Overflow detected. Guarded: the original called update()
      # unconditionally, crashing when no loss scaler was configured.
      if self.loss_scaler:
        self.loss_scaler.update(False)
      return False

    if self.lookahead_k > 0:
      for p in self.param_groups:
        if 'la_count' not in p:
          # Lazily initialize the slow (lookahead) weights with a copy.
          p['la_count'] = 0
          p['slow_params'] = [x.data.detach().clone().requires_grad_(False) for x in p['params']]
    self.optimizer.step(grad_scale, lr_scale)

    if self.lookahead_k > 0:
      for p in self.param_groups:
        p['la_count'] += 1
        if p['la_count'] == self.lookahead_k:
          p['la_count'] = 0
          # slow <- (1-alpha)*slow + alpha*fast, then fast <- slow, and the
          # fp16 mirror (out_data) is refreshed from the master copy.
          for s, f in zip(p['slow_params'], p['params']):
            s.mul_(1 - self.lookahead_alpha)
            s.add_(f.data.detach() * self.lookahead_alpha)
            f.data.copy_(s, non_blocking=True)
            if hasattr(f, 'out_data') and f.out_data is not None:
              f.out_data.copy_(f.data, non_blocking=True)

    if self.loss_scaler:
      self.loss_scaler.update(True)
    return True

  def zero_grad(self):
    """Drop gradients on both master and original parameters."""
    for group, o_group in zip(self.param_groups, self.original_param_groups):
      for p in group['params']:
        p.grad = None
      for p in o_group['params']:
        p.grad = None

  def _grad_scale(self, loss_scale = 1):
    """Collect, (all-)reduce and flatten gradients onto the master params.

    Returns the combined scale to divide gradients by (loss scale times
    any clipping factor), or None when the global norm is inf/nan.
    """
    named_params = {}
    named_grads = {}
    for g in self.original_param_groups:
      for n, p in zip(g['names'], g['params']):
        named_params[n] = p
        # Missing gradients are treated as zeros so flattening stays aligned.
        named_grads[n] = p.grad if p.grad is not None else torch.zeros_like(p.data)

    wd = get_world_size()
    def _reduce(group):
      # Flatten one bucket of gradients and kick off an async all-reduce.
      grads = [named_grads[n] for n in group]
      if len(grads) > 1:
        flattened_grads = _flatten_dense_tensors(grads)
      else:
        # FIX: was `grads[0],view(-1)` in the original.
        flattened_grads = grads[0].view(-1)

      if wd > 1:
        flattened_grads /= wd
        handle = dist.all_reduce(flattened_grads, async_op=True)
      else:
        handle = None
      return flattened_grads, handle

    def _process_grad(group, flattened_grads, max_grad, norm):
      # Accumulate the squared norm and scatter the reduced bucket back
      # into named_grads.
      grads = [named_grads[n] for n in group]
      norm = norm.to(flattened_grads.device)
      norm = norm + fused_norm(flattened_grads)**2

      if len(grads) > 1:
        unflattend_grads = _unflatten_dense_tensors(flattened_grads, grads)
      else:
        unflattend_grads = [flattened_grads]

      for n, ug in zip(group, unflattend_grads):
        named_grads[n] = ug

      return max_grad, norm

    group_size = 0
    group = []
    max_size = 32*1024*1024  # reduce in ~32M-element buckets
    norm = torch.zeros(1, dtype=torch.float)
    max_grad = 0

    all_grads = []
    # Sort for a bucket order that is stable across ranks; the key makes
    # 'deberta.'- and 'bert.'-prefixed checkpoints sort identically.
    for name in sorted(named_params.keys(), key=lambda x: x.replace('deberta.', 'bert.')):
      group.append(name)
      group_size += named_params[name].data.numel()
      if group_size >= max_size:
        flatten, handle = _reduce(group)
        all_grads.append([handle, flatten, group])
        group = []
        group_size = 0
    if group_size > 0:
      flatten, handle = _reduce(group)
      all_grads.append([handle, flatten, group])
      group = []
      group_size = 0
    for h, fg, group in all_grads:
      if h is not None:
        h.wait()
      max_grad, norm = _process_grad(group, fg, max_grad, norm)

    norm = norm**0.5
    if torch.isnan(norm) or torch.isinf(norm):
      # fp16 overflow: signal the caller to skip this step.
      return None

    grad_scale = loss_scale

    if self.max_grad_norm > 0:
      # Global-norm clipping, folded into the divisor applied by the
      # inner optimizer.
      scale = norm/(loss_scale*self.max_grad_norm)
      if scale > 1:
        grad_scale *= scale

    # Attach the flattened (possibly narrowed) gradients to each master
    # parameter; in distributed mode only the owning rank keeps them.
    for group, o_g in zip(self.param_groups, self.original_param_groups):
      grads = [named_grads[n] for n in o_g['names']]

      if len(grads) > 1:
        flattened_grads = _flatten_dense_tensors(grads)
      else:
        flattened_grads = grads[0].view(-1)
      if group['offset'] is not None:
        start, length = group['offset']
        flattened_grads = flattened_grads.narrow(0, start, length)
      if group['rank'] == self.rank or (not self.distributed):
        group['params'][0].grad = flattened_grads

    return grad_scale
|
| 255 |
+
|
| 256 |
+
class ExpLossScaler:
  """Exponential dynamic loss scaler for fp16 training.

  Halves the scale after an invalid (overflowed) step and doubles it once
  `scale_interval` consecutive steps pass without overflow.

  Bug fixed vs. original: the assert in `scale()` used `self.init_scale`
  as its message, but `__init__` never stored that attribute — a failing
  assert would raise AttributeError instead of reporting the scale. The
  initial scale is now kept on the instance.
  """
  def __init__(self, init_scale=2**16, scale_interval=1000):
    self.init_scale = init_scale  # kept so the sanity check can report it
    self.cur_scale = init_scale
    self.scale_interval = scale_interval
    self.invalid_cnt = 0          # consecutive invalid steps seen
    self.last_scale = 0           # step index of the last scale change
    self.steps = 0
    self.down_scale_smooth = 0    # invalid steps tolerated before halving

  def scale(self, loss):
    """Return (loss_scale, scaled_loss, raw_step_loss) for *loss*."""
    assert self.cur_scale > 0, self.init_scale
    step_loss = loss.float().detach().item()
    if step_loss != 0 and math.isfinite(step_loss):
      loss_scale = self.cur_scale
    else:
      # Zero or non-finite loss: don't scale, so overflow detection can
      # see the raw value.
      loss_scale = 1
    loss = loss.float()*loss_scale
    return (loss_scale, loss, step_loss)

  def update(self, is_valid = True):
    """Record the outcome of a step and adapt the current scale."""
    if not is_valid:
      self.invalid_cnt += 1
      if self.invalid_cnt > self.down_scale_smooth:
        self.cur_scale /= 2
        self.cur_scale = max(self.cur_scale, 1)
        self.last_scale = self.steps
    else:
      self.invalid_cnt = 0
      if self.steps - self.last_scale > self.scale_interval:
        self.cur_scale *= 2
        self.last_scale = self.steps
    self.steps += 1

  def state_dict(self):
    """Return a serializable snapshot of the scaler state."""
    state = defaultdict(float)
    state['steps'] = self.steps
    state['invalid_cnt'] = self.invalid_cnt
    state['cur_scale'] = self.cur_scale
    state['last_scale'] = self.last_scale
    return state

  def load_state_dict(self, state):
    """Restore state produced by `state_dict()`."""
    self.steps = state['steps']
    self.invalid_cnt = state['invalid_cnt']
    self.cur_scale = state['cur_scale']
    self.last_scale = state['last_scale']