nvan13 commited on
Commit
ab0f6ec
·
verified ·
1 Parent(s): f4dcc30

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +6 -0
  2. assets/control.png +3 -0
  3. assets/subject.png +3 -0
  4. generation/control/ControlNet/font/DejaVuSans.ttf +3 -0
  5. generation/control/ControlNet/ldm/modules/image_degradation/utils/test.png +3 -0
  6. llama/data/MetaMathQA-40K.json +3 -0
  7. llama/data/MetaMathQA.json +3 -0
  8. llama/output/cp1e4/ft/adapter_model.safetensors +3 -0
  9. llama/output/cp1e4/ft/tokenizer.model +3 -0
  10. llama/output/cp1e5/ft/adapter_model.safetensors +3 -0
  11. llama/output/cp1e5N/ft/adapter_model.safetensors +3 -0
  12. llama/output/cp1e5N/ft/tokenizer.model +3 -0
  13. llama/output/cp3e5/ft/adapter_model.safetensors +3 -0
  14. llama/output/cp3e5N/ft/adapter_model.safetensors +3 -0
  15. llama/output/cp3e5N/ft/tokenizer.model +3 -0
  16. llama/output/cpr1/ft/adapter_model.safetensors +3 -0
  17. llama/output/cpr1/ft/tokenizer.model +3 -0
  18. llama/output/cpr2/ft/adapter_model.safetensors +3 -0
  19. llama/output/cpr2/ft/tokenizer.model +3 -0
  20. nlu/DeBERTa.egg-info/PKG-INFO +39 -0
  21. nlu/DeBERTa.egg-info/SOURCES.txt +73 -0
  22. nlu/DeBERTa.egg-info/dependency_links.txt +1 -0
  23. nlu/DeBERTa.egg-info/requires.txt +19 -0
  24. nlu/DeBERTa.egg-info/top_level.txt +2 -0
  25. nlu/DeBERTa/apps/tasks/task_registry.py +70 -0
  26. nlu/DeBERTa/data/__init__.py +5 -0
  27. nlu/DeBERTa/data/async_data.py +38 -0
  28. nlu/DeBERTa/data/data_sampler.py +76 -0
  29. nlu/DeBERTa/data/dataloader.py +511 -0
  30. nlu/DeBERTa/data/dynamic_dataset.py +60 -0
  31. nlu/DeBERTa/data/example.py +105 -0
  32. nlu/DeBERTa/deberta/__init__.py +22 -0
  33. nlu/DeBERTa/deberta/bert.py +308 -0
  34. nlu/DeBERTa/deberta/cache_utils.py +135 -0
  35. nlu/DeBERTa/deberta/config.py +90 -0
  36. nlu/DeBERTa/deberta/da_utils.py +68 -0
  37. nlu/DeBERTa/deberta/deberta.py +145 -0
  38. nlu/DeBERTa/deberta/disentangled_attention.py +221 -0
  39. nlu/DeBERTa/deberta/gpt2_bpe_utils.py +163 -0
  40. nlu/DeBERTa/deberta/gpt2_tokenizer.py +216 -0
  41. nlu/DeBERTa/deberta/mlm.py +38 -0
  42. nlu/DeBERTa/deberta/nnmodule.py +137 -0
  43. nlu/DeBERTa/deberta/ops.py +228 -0
  44. nlu/DeBERTa/deberta/pooling.py +88 -0
  45. nlu/DeBERTa/deberta/pretrained_models.py +2 -0
  46. nlu/DeBERTa/deberta/spm_tokenizer.py +322 -0
  47. nlu/DeBERTa/deberta/tokenizers.py +16 -0
  48. nlu/DeBERTa/optims/__init__.py +16 -0
  49. nlu/DeBERTa/optims/args.py +100 -0
  50. nlu/DeBERTa/optims/fp16_optimizer.py +301 -0
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/control.png filter=lfs diff=lfs merge=lfs -text
37
+ assets/subject.png filter=lfs diff=lfs merge=lfs -text
38
+ generation/control/ControlNet/font/DejaVuSans.ttf filter=lfs diff=lfs merge=lfs -text
39
+ generation/control/ControlNet/ldm/modules/image_degradation/utils/test.png filter=lfs diff=lfs merge=lfs -text
40
+ llama/data/MetaMathQA-40K.json filter=lfs diff=lfs merge=lfs -text
41
+ llama/data/MetaMathQA.json filter=lfs diff=lfs merge=lfs -text
assets/control.png ADDED

Git LFS Details

  • SHA256: b1943c7d2d2042fd1f5455f7c85509c7fc2299221d3118caf8369807b99ff451
  • Pointer size: 132 Bytes
  • Size of remote file: 1.05 MB
assets/subject.png ADDED

Git LFS Details

  • SHA256: d115037067258634d251581e308b6509fd9b8190b6084d00a211b6886dd379c7
  • Pointer size: 131 Bytes
  • Size of remote file: 966 kB
generation/control/ControlNet/font/DejaVuSans.ttf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7da195a74c55bef988d0d48f9508bd5d849425c1770dba5d7bfc6ce9ed848954
3
+ size 757076
generation/control/ControlNet/ldm/modules/image_degradation/utils/test.png ADDED

Git LFS Details

  • SHA256: 92e516278f0d3e85e84cfb55b43338e12d5896a0ee3833aafdf378025457d753
  • Pointer size: 131 Bytes
  • Size of remote file: 441 kB
llama/data/MetaMathQA-40K.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c884f10e8aa1229a6e73a6bba2c9134ee0c7b7de92a02a7b8c9459085a59e117
3
+ size 31076207
llama/data/MetaMathQA.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fb39a5d8c05c042ece92eae37dfd5ea414a5979df2bf3ad3b86411bef8205725
3
+ size 395626321
llama/output/cp1e4/ft/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7e1c2fceb4f91331d69364aa56d01dd2103d4e59066f1519f1242a62ecca387a
3
+ size 1082171824
llama/output/cp1e4/ft/tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
3
+ size 499723
llama/output/cp1e5/ft/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f6121d3f7682fd21f70fc78ab9097b22ede67191507c54d44a9bd9c30adf44de
3
+ size 592928
llama/output/cp1e5N/ft/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d85146aea100acda2fd5bb5a011f8d1e14983756bb0c102bf85efe04ac176479
3
+ size 1082171824
llama/output/cp1e5N/ft/tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
3
+ size 499723
llama/output/cp3e5/ft/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1945e74d818ded53f08bc892bb458dd0e6addcd548b2f864dbd16a476a8954ef
3
+ size 1082171824
llama/output/cp3e5N/ft/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a2396d96c0a301cceddf424fbdf7c7f3518311f90140fa9aad9053706288e9fc
3
+ size 1082171824
llama/output/cp3e5N/ft/tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
3
+ size 499723
llama/output/cpr1/ft/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:617c715b246fae47190ca1f8e304e9dbdadf6ac70bbfdd0f3bc3c4b1cd783c0d
3
+ size 1049665904
llama/output/cpr1/ft/tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
3
+ size 499723
llama/output/cpr2/ft/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:daede58d9fd4806298d90f9af12ba478c119afab844244f355f35ab3829eb029
3
+ size 1049665904
llama/output/cpr2/ft/tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
3
+ size 499723
nlu/DeBERTa.egg-info/PKG-INFO ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.1
2
+ Name: DeBERTa
3
+ Version: 0.1.13
4
+ Summary: Decoding enhanced BERT with Disentangled Attention
5
+ Home-page: https://github.com/microsoft/DeBERTa
6
+ Author: penhe
7
+ Author-email: penhe@microsoft.com
8
+ License: MIT
9
+ Keywords: NLP deep learning transformer pytorch Attention BERT RoBERTa DeBERTa
10
+ Classifier: Programming Language :: Python :: 3
11
+ Classifier: Programming Language :: Python :: 3.6
12
+ Classifier: Programming Language :: Python :: 3.7
13
+ Classifier: Programming Language :: Python :: 3.8
14
+ Classifier: Programming Language :: Python :: 3.9
15
+ Classifier: License :: OSI Approved :: MIT License
16
+ Classifier: Operating System :: OS Independent
17
+ Requires-Python: >=3.6
18
+ Description-Content-Type: text/markdown
19
+ License-File: LICENSE
20
+ Requires-Dist: nltk
21
+ Requires-Dist: spacy
22
+ Requires-Dist: numpy
23
+ Requires-Dist: pytest
24
+ Requires-Dist: regex
25
+ Requires-Dist: scipy
26
+ Requires-Dist: scikit-learn
27
+ Requires-Dist: tqdm
28
+ Requires-Dist: ujson
29
+ Requires-Dist: seqeval
30
+ Requires-Dist: psutil
31
+ Requires-Dist: sentencepiece
32
+ Requires-Dist: torch
33
+ Provides-Extra: docs
34
+ Requires-Dist: recommonmark; extra == "docs"
35
+ Requires-Dist: sphinx; extra == "docs"
36
+ Requires-Dist: sphinx-markdown-tables; extra == "docs"
37
+ Requires-Dist: sphinx-rtd-theme; extra == "docs"
38
+
39
+ deberta long des
nlu/DeBERTa.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ LICENSE
2
+ setup.cfg
3
+ setup.py
4
+ DeBERTa/__init__.py
5
+ DeBERTa.egg-info/PKG-INFO
6
+ DeBERTa.egg-info/SOURCES.txt
7
+ DeBERTa.egg-info/dependency_links.txt
8
+ DeBERTa.egg-info/requires.txt
9
+ DeBERTa.egg-info/top_level.txt
10
+ DeBERTa/apps/__init__.py
11
+ DeBERTa/apps/_utils.py
12
+ DeBERTa/apps/run.py
13
+ DeBERTa/apps/models/__init__.py
14
+ DeBERTa/apps/models/masked_language_model.py
15
+ DeBERTa/apps/models/multi_choice.py
16
+ DeBERTa/apps/models/ner.py
17
+ DeBERTa/apps/models/record_qa.py
18
+ DeBERTa/apps/models/replaced_token_detection_model.py
19
+ DeBERTa/apps/models/sequence_classification.py
20
+ DeBERTa/apps/tasks/__init__.py
21
+ DeBERTa/apps/tasks/glue_tasks.py
22
+ DeBERTa/apps/tasks/metrics.py
23
+ DeBERTa/apps/tasks/mlm_task.py
24
+ DeBERTa/apps/tasks/ner_task.py
25
+ DeBERTa/apps/tasks/race_task.py
26
+ DeBERTa/apps/tasks/record_eval.py
27
+ DeBERTa/apps/tasks/rtd_task.py
28
+ DeBERTa/apps/tasks/superglue_tasks.py
29
+ DeBERTa/apps/tasks/task.py
30
+ DeBERTa/apps/tasks/task_registry.py
31
+ DeBERTa/data/__init__.py
32
+ DeBERTa/data/async_data.py
33
+ DeBERTa/data/data_sampler.py
34
+ DeBERTa/data/dataloader.py
35
+ DeBERTa/data/dynamic_dataset.py
36
+ DeBERTa/data/example.py
37
+ DeBERTa/deberta/__init__.py
38
+ DeBERTa/deberta/bert.py
39
+ DeBERTa/deberta/cache_utils.py
40
+ DeBERTa/deberta/config.py
41
+ DeBERTa/deberta/da_utils.py
42
+ DeBERTa/deberta/deberta.py
43
+ DeBERTa/deberta/disentangled_attention.py
44
+ DeBERTa/deberta/gpt2_bpe_utils.py
45
+ DeBERTa/deberta/gpt2_tokenizer.py
46
+ DeBERTa/deberta/mlm.py
47
+ DeBERTa/deberta/nnmodule.py
48
+ DeBERTa/deberta/ops.py
49
+ DeBERTa/deberta/pooling.py
50
+ DeBERTa/deberta/pretrained_models.py
51
+ DeBERTa/deberta/spm_tokenizer.py
52
+ DeBERTa/deberta/tokenizers.py
53
+ DeBERTa/optims/__init__.py
54
+ DeBERTa/optims/args.py
55
+ DeBERTa/optims/fp16_optimizer.py
56
+ DeBERTa/optims/lr_schedulers.py
57
+ DeBERTa/optims/xadam.py
58
+ DeBERTa/sift/__init__.py
59
+ DeBERTa/sift/sift.py
60
+ DeBERTa/training/__init__.py
61
+ DeBERTa/training/_utils.py
62
+ DeBERTa/training/args.py
63
+ DeBERTa/training/dist_launcher.py
64
+ DeBERTa/training/optimizer_utils.py
65
+ DeBERTa/training/trainer.py
66
+ DeBERTa/utils/__init__.py
67
+ DeBERTa/utils/argument_types.py
68
+ DeBERTa/utils/jit_tracing.py
69
+ DeBERTa/utils/logger_util.py
70
+ DeBERTa/utils/xtqdm.py
71
+ adapterlib/__init__.py
72
+ adapterlib/layers.py
73
+ adapterlib/utils.py
nlu/DeBERTa.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
nlu/DeBERTa.egg-info/requires.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ nltk
2
+ spacy
3
+ numpy
4
+ pytest
5
+ regex
6
+ scipy
7
+ scikit-learn
8
+ tqdm
9
+ ujson
10
+ seqeval
11
+ psutil
12
+ sentencepiece
13
+ torch
14
+
15
+ [docs]
16
+ recommonmark
17
+ sphinx
18
+ sphinx-markdown-tables
19
+ sphinx-rtd-theme
nlu/DeBERTa.egg-info/top_level.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ DeBERTa
2
+ adapterlib
nlu/DeBERTa/apps/tasks/task_registry.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft, Inc. 2020
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ #
6
+ # Author: penhe@microsoft.com
7
+ # Date: 01/25/2019
8
+ #
9
+
10
+ from glob import glob
11
+ import os
12
+ import importlib
13
+ import pdb
14
+ import sys
15
+ from ...utils import get_logger
16
+ from .task import Task
17
+
18
+ __all__ = ['load_tasks', 'register_task', 'get_task']
19
+ tasks={}
20
+
21
+ logger=get_logger()
22
+
23
def register_task(name=None, desc=None):
    """Class decorator that adds a Task subclass to the global ``tasks`` registry.

    Supports both the parameterized form ``@register_task(name=..., desc=...)``
    and the bare form ``@register_task`` applied directly to the class.

    Args:
        name: Registry key (lower-cased). Defaults to the class name. In the
            bare-decorator form this argument receives the class itself.
        desc: Human-readable description. Defaults to the resolved name.

    Returns:
        The decorated class (registration is a side effect on ``tasks``).
    """
    def register_task_x(cls):
        _name = name
        if _name is None:
            _name = cls.__name__

        _desc = desc
        if _desc is None:
            _desc = _name

        # Keys are case-insensitive; warn (don't fail) on re-registration so
        # re-importing a task module is harmless.
        _name = _name.lower()
        if _name in tasks:
            logger.warning(f'{_name} already registered in the registry: {tasks[_name]}.')
        assert issubclass(cls, Task), 'Registered class must be a subclass of Task.'
        tasks[_name] = cls
        cls._meta = {
            'name': _name,
            'desc': _desc}
        return cls

    # Bare-decorator form: `name` is actually the class being decorated.
    # isinstance(name, type) — unlike the brittle `type(name) == type` — also
    # accepts classes created with a custom metaclass.
    if isinstance(name, type):
        cls = name
        name = None
        return register_task_x(cls)
    return register_task_x
48
+
49
def load_tasks(task_dir = None):
    """Import every task module so their @register_task decorators populate `tasks`.

    Args:
        task_dir: Optional directory of user-supplied task modules to load in
            addition to the built-in ones shipped next to this file.
    """
    # Import every non-private .py module in this package directory; importing
    # runs the @register_task decorators, which fill the registry.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    sys_tasks = glob(os.path.join(script_dir, "*.py"))
    for t in sys_tasks:
        m = os.path.splitext(os.path.basename(t))[0]
        if not m.startswith('_'):
            importlib.import_module(f'DeBERTa.apps.tasks.{m}')

    # Optionally load custom task modules from an external directory. Note the
    # directory is appended to sys.path so the modules import as top-level names.
    if task_dir:
        assert os.path.exists(task_dir), f"{task_dir} must be a valid directory."
        customer_tasks = glob(os.path.join(task_dir, "*.py"))
        sys.path.append(task_dir)
        for t in customer_tasks:
            m = os.path.splitext(os.path.basename(t))[0]
            if not m.startswith('_'):
                importlib.import_module(f'{m}')
65
+
66
def get_task(name=None):
    """Look up a registered task by case-insensitive name.

    Called with no name, returns the whole registry dict. Raises KeyError for
    an unknown name.
    """
    return tasks if name is None else tasks[name.lower()]
nlu/DeBERTa/data/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .example import ExampleInstance,ExampleSet,example_to_feature
2
+ from .dataloader import SequentialDataLoader
3
+ from .dynamic_dataset import *
4
+ from .data_sampler import *
5
+ from .async_data import *
nlu/DeBERTa/data/async_data.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # This source code is licensed under the MIT license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+ #
5
+ # Author: Pengcheng He (penhe@microsoft.com)
6
+ # Date: 05/15/2019
7
+ #
8
+
9
+ from queue import Queue,Empty
10
+ from threading import Thread
11
class AsyncDataLoader(object):
    """Wraps a data loader and prefetches batches on a background thread.

    A producer thread pulls batches from the wrapped loader into a bounded
    queue (at most ``buffer_size`` batches buffered) while the consumer
    iterates, overlapping data loading with computation.
    """

    def __init__(self, dataloader, buffer_size=100):
        self.buffer_size = buffer_size
        self.dataloader = dataloader

    def __iter__(self):
        queue = Queue(self.buffer_size)
        dl = iter(self.dataloader)

        def _worker():
            # Drain the wrapped iterator into the queue; None marks the end.
            while True:
                try:
                    queue.put(next(dl))
                except StopIteration:
                    break
            queue.put(None)

        t = Thread(target=_worker)
        # Daemonize so an abandoned iterator cannot block interpreter exit.
        t.daemon = True
        t.start()
        while True:
            d = queue.get()
            if d is None:
                break
            yield d
        # The sentinel has been consumed, so the worker has finished its loop
        # and this join returns promptly (the old `del t` never waited at all).
        t.join()

    def __len__(self):
        return len(self.dataloader)
38
+
nlu/DeBERTa/data/data_sampler.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # This source code is licensed under the MIT license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+ #
5
+ # Author: Pengcheng He (penhe@microsoft.com)
6
+ # Date: 05/15/2019
7
+ #
8
+
9
+ import os
10
+ import numpy as np
11
+ import math
12
+ import sys
13
+ from torch.utils.data import Sampler
14
+
15
+ __all__=['BatchSampler', 'DistributedBatchSampler', 'RandomSampler', 'SequentialSampler']
16
class BatchSampler(Sampler):
    """Groups indices from a wrapped sampler into lists of ``batch_size``.

    The final batch may be shorter when the sampler length is not an exact
    multiple of the batch size.
    """

    def __init__(self, sampler, batch_size):
        self.sampler = sampler
        self.batch_size = batch_size

    def __iter__(self):
        pending = []
        for index in self.sampler:
            pending.append(index)
            if len(pending) == self.batch_size:
                yield pending
                pending = []
        # Emit the trailing partial batch, if any.
        if pending:
            yield pending

    def __len__(self):
        # Ceiling division: a partial trailing batch still counts.
        return -(-len(self.sampler) // self.batch_size)
33
+
34
class DistributedBatchSampler(Sampler):
    """Shards every batch from a wrapped batch sampler across distributed ranks.

    When a batch does not divide evenly by ``world_size`` it is either dropped
    (``drop_last=True``, which stops iteration) or padded in place with copies
    of its first element so each rank receives an equal-sized slice.
    """

    def __init__(self, sampler, rank=0, world_size = 1, drop_last = False):
        self.sampler = sampler
        self.rank = rank
        self.world_size = world_size
        self.drop_last = drop_last

    def __iter__(self):
        for batch in self.sampler:
            remainder = len(batch) % self.world_size
            if remainder != 0:
                if self.drop_last:
                    # Uneven batch (typically the last): stop here.
                    break
                # Pad with repeats of the first index so the split is even.
                batch.extend(batch[0] for _ in range(self.world_size - remainder))
            per_rank = len(batch) // self.world_size
            start = self.rank * per_rank
            yield batch[start:start + per_rank]

    def __len__(self):
        return len(self.sampler)
53
+
54
class RandomSampler(Sampler):
    """Yields every index in ``[0, total_samples)`` in a seeded random order.

    The RandomState stream persists across epochs: each new iteration
    reshuffles the stored permutation with the same generator, so successive
    epochs see different (but reproducible) orders.
    """

    def __init__(self, total_samples:int, data_seed:int = 0):
        self.indices = np.arange(total_samples)
        self.rng = np.random.RandomState(data_seed)

    def __iter__(self):
        # In-place shuffle of the persistent index array.
        self.rng.shuffle(self.indices)
        yield from self.indices

    def __len__(self):
        return len(self.indices)
66
+
67
class SequentialSampler(Sampler):
    """Yields indices ``0 .. total_samples-1`` in ascending order."""

    def __init__(self, total_samples:int):
        self.indices = np.arange(total_samples)

    def __iter__(self):
        yield from self.indices

    def __len__(self):
        return len(self.indices)
nlu/DeBERTa/data/dataloader.py ADDED
@@ -0,0 +1,511 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import torch
3
+ import torch.multiprocessing as multiprocessing
4
+ from torch._C import _set_worker_signal_handlers, \
5
+ _remove_worker_pids, _error_if_any_worker_fails
6
+
7
+ from packaging import version
8
+
9
+ if version.Version(torch.__version__) >= version.Version('1.0.0'):
10
+ from torch._C import _set_worker_pids
11
+ else:
12
+ from torch._C import _update_worker_pids as _set_worker_pids
13
+
14
+ from torch.utils.data import SequentialSampler, RandomSampler, BatchSampler, Sampler
15
+ import signal
16
+ import functools
17
+ import collections.abc
18
+ import re
19
+ import sys
20
+ import threading
21
+ import traceback
22
+ import os
23
+ import time
24
+ # from torch._six import string_classes
25
+ string_classes = str
26
+
27
+ IS_WINDOWS = sys.platform == "win32"
28
+ if IS_WINDOWS:
29
+ import ctypes
30
+ from ctypes.wintypes import DWORD, BOOL, HANDLE
31
+
32
+ if sys.version_info[0] == 2:
33
+ import Queue as queue
34
+ else:
35
+ import queue
36
+
37
+ __all__ = ['SequentialDataLoader']
38
+
39
class ExceptionWrapper(object):
    r"""Wraps an exception plus traceback to communicate across threads"""

    def __init__(self, exc_info):
        # exc_info is the (type, value, traceback) triple from sys.exc_info().
        # Keep the exception type plus a fully formatted message so the error
        # can be re-raised with context in the consuming thread/process.
        exc_type, _, _ = exc_info
        self.exc_type = exc_type
        self.exc_msg = "".join(traceback.format_exception(*exc_info))
45
+
46
+
47
+ _use_shared_memory = False
48
+ r"""Whether to use shared memory in default_collate"""
49
+
50
+ MANAGER_STATUS_CHECK_INTERVAL = 5.0
51
+
52
if IS_WINDOWS:
    # On Windows, the parent ID of the worker process remains unchanged when the manager process
    # is gone, and the only way to check it through OS is to let the worker have a process handle
    # of the manager and ask if the process status has changed.
    class ManagerWatchdog(object):
        """Lets a worker process detect that the manager (parent) process died."""

        def __init__(self):
            self.manager_pid = os.getppid()

            # Resolve the Win32 APIs needed to wait on the parent process handle.
            self.kernel32 = ctypes.WinDLL('kernel32', use_last_error=True)
            self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)
            self.kernel32.OpenProcess.restype = HANDLE
            self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)
            self.kernel32.WaitForSingleObject.restype = DWORD

            # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx
            SYNCHRONIZE = 0x00100000
            self.manager_handle = self.kernel32.OpenProcess(SYNCHRONIZE, 0, self.manager_pid)

            if not self.manager_handle:
                raise ctypes.WinError(ctypes.get_last_error())

        def is_alive(self):
            # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
            # A zero-timeout wait returns WAIT_TIMEOUT (non-zero) while the
            # parent process is still running.
            return self.kernel32.WaitForSingleObject(self.manager_handle, 0) != 0
else:
    class ManagerWatchdog(object):
        """POSIX variant: once the parent dies the worker is re-parented, so
        getppid() stops matching the pid recorded at construction time."""

        def __init__(self):
            self.manager_pid = os.getppid()

        def is_alive(self):
            return os.getppid() == self.manager_pid
83
+
84
+
85
def _worker_loop(dataset, index_queue, data_queue, collate_fn, init_fn, worker_id):
    """Main loop of a DataLoader worker process.

    Pulls ``(idx, batch_indices)`` work items from ``index_queue``, collates
    the corresponding dataset samples, and puts ``(idx, batch)`` — or
    ``(idx, ExceptionWrapper)`` on failure — onto ``data_queue``.  Exits on a
    ``None`` sentinel or when the manager process is detected dead.
    """
    global _use_shared_memory
    # Enables the shared-memory fast path inside default_collate.
    _use_shared_memory = True

    # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
    # module's handlers are executed after Python returns from C low-level
    # handlers, likely when the same fatal signal happened again already.
    # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
    _set_worker_signal_handlers()

    # Keep each worker single-threaded to avoid oversubscribing CPU cores.
    torch.set_num_threads(1)

    if init_fn is not None:
        init_fn(worker_id)

    watchdog = ManagerWatchdog()

    while True:
        try:
            r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            # Periodic liveness check: keep waiting while the manager lives,
            # otherwise exit so we don't orphan the worker.
            if watchdog.is_alive():
                continue
            else:
                break
        if r is None:
            # Shutdown sentinel from the manager.
            break
        idx, batch_indices = r
        try:
            samples = collate_fn([dataset[i] for i in batch_indices])
        except Exception:
            # Ship the exception (with formatted traceback) back to the manager
            # instead of crashing the worker silently.
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))
            # Drop our reference promptly so the batch's memory can be reclaimed.
            del samples
120
+
121
+
122
def _worker_manager_loop(in_queue, out_queue, done_event, pin_memory, device_id):
    """Thread that forwards worker results, optionally pinning batch memory.

    Runs in the main process: moves ``(idx, batch)`` items from ``in_queue``
    to ``out_queue``, pinning CUDA page-locked memory when ``pin_memory`` is
    set.  Exits on a ``None`` sentinel, or silently once ``done_event`` is set
    and the queue starts erroring during shutdown.
    """
    if pin_memory:
        # Pin against the device the loader was created on.
        torch.cuda.set_device(device_id)

    while True:
        try:
            r = in_queue.get()
        except Exception:
            if done_event.is_set():
                # Shutdown in progress — queue errors are expected here.
                return
            raise
        if r is None:
            break
        if isinstance(r[1], ExceptionWrapper):
            # Propagate worker-side failures to the consumer unchanged.
            out_queue.put(r)
            continue
        idx, batch = r
        try:
            if pin_memory:
                batch = pin_memory_batch(batch)
        except Exception:
            out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            out_queue.put((idx, batch))
146
+
147
+ numpy_type_map = {
148
+ 'float64': torch.DoubleTensor,
149
+ 'float32': torch.FloatTensor,
150
+ 'float16': torch.HalfTensor,
151
+ 'int64': torch.LongTensor,
152
+ 'int32': torch.IntTensor,
153
+ 'int16': torch.ShortTensor,
154
+ 'int8': torch.CharTensor,
155
+ 'uint8': torch.ByteTensor,
156
+ }
157
+
158
+
159
def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""

    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    elem_type = type(batch[0])
    if isinstance(batch[0], torch.Tensor):
        out = None
        if _use_shared_memory:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = batch[0].storage()._new_shared(numel)
            out = batch[0].new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            # (dtype kinds S/a/U/O cannot be converted to tensors)
            if re.search('[SaUO]', elem.dtype.str) is not None:
                raise TypeError(error_msg.format(elem.dtype))

            return torch.stack([torch.from_numpy(b) for b in batch], 0)
        if elem.shape == ():  # scalars
            # Map the numpy scalar dtype onto the matching torch tensor type
            # via the module-level numpy_type_map table.
            py_type = float if elem.dtype.name.startswith('float') else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
    elif isinstance(batch[0], int):
        return torch.LongTensor(batch)
    elif isinstance(batch[0], float):
        return torch.DoubleTensor(batch)
    elif isinstance(batch[0], string_classes):
        # Strings pass through as a plain list; they are not tensorized.
        return batch
    elif isinstance(batch[0], collections.abc.Mapping):
        # Collate each key across the batch independently.
        return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
    elif isinstance(batch[0], collections.abc.Sequence):
        # Transpose list-of-sequences into per-field groups and recurse.
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError((error_msg.format(type(batch[0]))))
198
+
199
+
200
def pin_memory_batch(batch):
    """Recursively pin the memory of every tensor in a (possibly nested) batch.

    Tensors are pinned; mappings and sequences are rebuilt with their elements
    pinned; strings and any other type pass through unchanged.
    """
    if isinstance(batch, torch.Tensor):
        return batch.pin_memory()
    if isinstance(batch, string_classes):
        # Strings are Sequences too — return them before the sequence branch.
        return batch
    if isinstance(batch, collections.abc.Mapping):
        return {key: pin_memory_batch(value) for key, value in batch.items()}
    if isinstance(batch, collections.abc.Sequence):
        return [pin_memory_batch(item) for item in batch]
    return batch
211
+
212
+
213
+ _SIGCHLD_handler_set = False
214
+ r"""Whether SIGCHLD handler is set for DataLoader worker failures. Only one
215
+ handler needs to be set for all DataLoaders in a process."""
216
+
217
+
218
def _set_SIGCHLD_handler():
    """Install a SIGCHLD handler that raises if any DataLoader worker dies.

    Installed at most once per process (guarded by the module-global
    ``_SIGCHLD_handler_set``) and chains any previously installed handler.
    """
    # Windows doesn't support SIGCHLD handler
    if sys.platform == 'win32':
        return
    # can't set signal in child threads
    if not isinstance(threading.current_thread(), threading._MainThread):
        return
    global _SIGCHLD_handler_set
    if _SIGCHLD_handler_set:
        return
    previous_handler = signal.getsignal(signal.SIGCHLD)
    if not callable(previous_handler):
        # SIG_DFL / SIG_IGN / None — nothing to chain.
        previous_handler = None

    def handler(signum, frame):
        # This following call uses `waitid` with WNOHANG from C side. Therefore,
        # Python can still get and update the process status successfully.
        _error_if_any_worker_fails()
        if previous_handler is not None:
            previous_handler(signum, frame)

    signal.signal(signal.SIGCHLD, handler)
    _SIGCHLD_handler_set = True
241
+
242
+
243
+ class _SequentialDataLoaderIter(object):
244
+ r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
245
+
246
+ def __init__(self, loader):
247
+ self.dataset = loader.dataset
248
+ self.collate_fn = loader.collate_fn
249
+ self.batch_sampler = loader.batch_sampler
250
+ self.num_workers = loader.num_workers
251
+ self.pin_memory = loader.pin_memory and torch.cuda.is_available()
252
+ self.timeout = loader.timeout
253
+ self.done_event = threading.Event()
254
+
255
+ self.sample_iter = iter(self.batch_sampler)
256
+
257
+ if self.num_workers > 0:
258
+ self.worker_init_fn = loader.worker_init_fn
259
+ self.index_queues = [multiprocessing.Queue() for _ in range(self.num_workers)]
260
+ self.worker_queue_idx = 0
261
+ self.worker_result_queue = multiprocessing.SimpleQueue()
262
+ self.batches_outstanding = 0
263
+ self.worker_pids_set = False
264
+ self.shutdown = False
265
+ self.send_idx = 0
266
+ self.rcvd_idx = 0
267
+ self.reorder_dict = {}
268
+
269
+ self.workers = [
270
+ multiprocessing.Process(
271
+ target=_worker_loop,
272
+ args=(self.dataset, self.index_queues[i],
273
+ self.worker_result_queue, self.collate_fn, self.worker_init_fn, i))
274
+ for i in range(self.num_workers)]
275
+
276
+ if self.pin_memory or self.timeout > 0:
277
+ self.data_queue = queue.Queue()
278
+ if self.pin_memory:
279
+ maybe_device_id = torch.cuda.current_device()
280
+ else:
281
+ # do not initialize cuda context if not necessary
282
+ maybe_device_id = None
283
+ self.worker_manager_thread = threading.Thread(
284
+ target=_worker_manager_loop,
285
+ args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
286
+ maybe_device_id))
287
+ self.worker_manager_thread.daemon = True
288
+ self.worker_manager_thread.start()
289
+ else:
290
+ self.data_queue = self.worker_result_queue
291
+
292
+ for w in self.workers:
293
+ w.daemon = True # ensure that the worker exits on process exit
294
+ w.start()
295
+
296
+ _set_worker_pids(id(self), tuple(w.pid for w in self.workers))
297
+ _set_SIGCHLD_handler()
298
+ self.worker_pids_set = True
299
+
300
+ # prime the prefetch loop
301
+ for _ in range(2 * self.num_workers):
302
+ self._put_indices()
303
+
304
+ def __len__(self):
305
+ return len(self.batch_sampler)
306
+
307
+ def _get_batch(self):
308
+ if self.timeout > 0:
309
+ try:
310
+ return self.data_queue.get(timeout=self.timeout)
311
+ except queue.Empty:
312
+ raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
313
+ else:
314
+ return self.data_queue.get()
315
+
316
+ def __next__(self):
317
+ if self.num_workers == 0: # same-process loading
318
+ indices = next(self.sample_iter) # may raise StopIteration
319
+ batch = self.collate_fn([self.dataset[i] for i in indices])
320
+ if self.pin_memory:
321
+ batch = pin_memory_batch(batch)
322
+ return batch
323
+
324
+ # check if the next sample has already been generated
325
+ if self.rcvd_idx in self.reorder_dict:
326
+ batch = self.reorder_dict.pop(self.rcvd_idx)
327
+ return self._process_next_batch(batch)
328
+
329
+ if self.batches_outstanding == 0:
330
+ self._shutdown_workers()
331
+ raise StopIteration
332
+
333
+ while True:
334
+ assert (not self.shutdown and self.batches_outstanding > 0)
335
+ idx, batch = self._get_batch()
336
+ self.batches_outstanding -= 1
337
+ if idx != self.rcvd_idx:
338
+ # store out-of-order samples
339
+ self.reorder_dict[idx] = batch
340
+ continue
341
+ return self._process_next_batch(batch)
342
+
343
+ next = __next__ # Python 2 compatibility
344
+
345
+ def __iter__(self):
346
+ return self
347
+
348
+ def _put_indices(self):
349
+ assert self.batches_outstanding < 2 * self.num_workers
350
+ indices = next(self.sample_iter, None)
351
+ if indices is None:
352
+ return
353
+ self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
354
+ self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
355
+ self.batches_outstanding += 1
356
+ self.send_idx += 1
357
+
358
+ def _process_next_batch(self, batch):
359
+ self.rcvd_idx += 1
360
+ self._put_indices()
361
+ if isinstance(batch, ExceptionWrapper):
362
+ raise batch.exc_type(batch.exc_msg)
363
+ return batch
364
+
365
+ def __getstate__(self):
366
+ # TODO: add limited pickling support for sharing an iterator
367
+ # across multiple threads for HOGWILD.
368
+ # Probably the best way to do this is by moving the sample pushing
369
+ # to a separate thread and then just sharing the data queue
370
+ # but signalling the end is tricky without a non-blocking API
371
+ raise NotImplementedError("_SequentialDataLoaderIter cannot be pickled")
372
+
373
def _shutdown_workers(self):
    """Tear down worker processes and helper queues exactly once.

    Safe to call repeatedly: the `self.shutdown` flag guards re-entry.
    Registered worker pids are always cleared, even if queue cleanup fails.
    """
    try:
        if not self.shutdown:
            self.shutdown = True
            self.done_event.set()
            # A `None` sentinel on each index queue tells the worker loop to exit.
            for q in self.index_queues:
                q.put(None)
            # if some workers are waiting to put, make place for them
            try:
                while not self.worker_result_queue.empty():
                    self.worker_result_queue.get()
            except (FileNotFoundError, ImportError):
                # Many weird errors can happen here due to Python
                # shutting down. These are more like obscure Python bugs.
                # FileNotFoundError can happen when we rebuild the fd
                # fetched from the queue but the socket is already closed
                # from the worker side.
                # ImportError can happen when the unpickler loads the
                # resource from `get`.
                pass
            # done_event should be sufficient to exit worker_manager_thread,
            # but be safe here and put another None
            self.worker_result_queue.put(None)
    finally:
        # removes pids no matter what
        if self.worker_pids_set:
            _remove_worker_pids(id(self))
            self.worker_pids_set = False
401
+
402
def __del__(self):
    """Best-effort cleanup when the iterator is garbage collected."""
    # Only multi-process loaders own worker processes that need teardown.
    if self.num_workers > 0:
        self._shutdown_workers()
405
+
406
+
407
class SequentialDataLoader(object):
    r"""
    Sequential Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.
    This is modified from Pytorch.DataLoader by disabling random state touching, since for
    sequential data loading we don't want it to touch any random state.
    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: 1).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: False).
        sampler (Sampler, optional): defines the strategy to draw samples from
            the dataset. If specified, ``shuffle`` must be False.
        batch_sampler (Sampler, optional): like sampler, but returns a batch of
            indices at a time. Mutually exclusive with batch_size, shuffle,
            sampler, and drop_last.
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means that the data will be loaded in the main process.
            (default: 0)
        collate_fn (callable, optional): merges a list of samples to form a mini-batch.
        pin_memory (bool, optional): If ``True``, the data loader will copy tensors
            into CUDA pinned memory before returning them.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: False)
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: 0)
        worker_init_fn (callable, optional): If not None, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
            input, after seeding and before data loading. (default: None)

    .. note:: By default, each worker will have its PyTorch seed set to
              ``base_seed + worker_id``, where ``base_seed`` is a long generated
              by main process using its RNG. However, seeds for other libraries
              may be duplicated upon initializing workers (e.g., NumPy), causing
              each worker to return identical random numbers. (See
              :ref:`dataloader-workers-random-seed` section in FAQ.) You may
              use ``torch.initial_seed()`` to access the PyTorch seed for each
              worker in :attr:`worker_init_fn`, and use it to set other seeds
              before data loading.

    .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
                 unpicklable object, e.g., a lambda function.
    """

    # Class-level default so __setattr__ can consult it while __init__ is still
    # running (name-mangled to _SequentialDataLoader__initialized).
    __initialized = False

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if batch_sampler is not None:
            # A batch_sampler fully determines batching; reject conflicting options.
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            # Neutralize the per-sample options when batch_sampler drives batching.
            self.batch_size = None
            self.drop_last = None

        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if self.num_workers < 0:
            raise ValueError('num_workers option cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if batch_sampler is None:
            # Build a default sampler, then wrap it into a BatchSampler.
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler
        # From here on, __setattr__ rejects changes to the batching attributes.
        self.__initialized = True

    def __setattr__(self, attr, val):
        # Freeze the attributes that the (possibly already created) iterator
        # depends on once construction has finished.
        if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
            raise ValueError('{} attribute should not be set after {} is '
                             'initialized'.format(attr, self.__class__.__name__))

        super(SequentialDataLoader, self).__setattr__(attr, val)

    def __iter__(self):
        return _SequentialDataLoaderIter(self)

    def __len__(self):
        # Number of batches, as determined by the batch sampler.
        return len(self.batch_sampler)
511
+
nlu/DeBERTa/data/dynamic_dataset.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft, Inc. 2020
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ #
6
+ # Author: penhe@microsoft.com
7
+ # Date: 05/15/2019
8
+ #
9
+
10
+ import pdb
11
+ from torch.utils.data import Dataset
12
+ import random
13
+ import mmap
14
+ import numpy as np
15
+ from bisect import bisect
16
+ from ..utils import get_logger
17
+ logger=get_logger()
18
+
19
+ __all__ = ['DynamicDataset']
20
+
21
class DynamicDataset(Dataset):
    """Dataset that converts corpus examples to features on the fly.

    The (optionally shuffled) index table is stored in an anonymous mmap
    buffer so it can be shared cheaply across dataloader worker processes.
    """

    def __init__(self, corpus, feature_fn, dataset_size=None, shuffle=False, **kwargs):
        # corpus must support len() and tuple indexing:
        # corpus[(example_idx, rng, ext_params)] — see __getitem__.
        self.corpus = corpus
        self.ds_len = len(self.corpus)
        logger.info(f'Total corpus examples: {self.ds_len}')
        # feature_fn(example, rng, ext_params=...) builds the model features.
        self.feature_fn = feature_fn

        if not dataset_size:
            self.dataset_size = self.ds_len
        else:
            # dataset_size may exceed ds_len; indices wrap modulo ds_len below.
            self.dataset_size = int(dataset_size)

        self.shuffle = shuffle
        # Anonymous shared mmap backing the index table; 8 bytes per entry
        # matches a 64-bit `int` dtype (NOTE(review): numpy's `int` is
        # platform dependent — confirm a 64-bit target).
        index_buf = mmap.mmap(-1, self.dataset_size*8)
        shuffle_idx = np.ndarray(shape=(self.dataset_size, ), buffer=index_buf, dtype=int)
        shuffle_idx[:] = np.arange(self.dataset_size)[:]
        if self.shuffle:
            #rng = np.random.RandomState(0)
            # Fixed seed: every worker process derives the identical permutation.
            rng = random.Random(0)
            rng.shuffle(shuffle_idx)
        self.shuffle_idx = shuffle_idx
        # index_offset shifts the logical index, e.g. to resume mid-epoch.
        self.index_offset = 0
        if 'index_offset' in kwargs:
            self.index_offset = kwargs['index_offset']

    def __len__(self):
        return self.dataset_size

    def __getitem__(self, idx):
        # idx may be (index, ext_params) when the sampler carries extra state.
        if isinstance(idx, tuple) or isinstance(idx, list):
            idx, ext_params = idx
        else:
            ext_params = None
        idx += self.index_offset
        # Per-example RNG seeded by the absolute index: feature generation is
        # deterministic for a given position.
        seed = idx
        rng = random.Random(seed)
        # get seq length
        example_idx = self.shuffle_idx[idx%self.dataset_size]%self.ds_len
        example = self.corpus[example_idx, rng, ext_params]
        return self.feature_fn(example, rng, ext_params = ext_params)
nlu/DeBERTa/data/example.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ from collections import OrderedDict
4
+ import numpy as np
5
+ import tempfile
6
+ import numpy as np
7
+ import mmap
8
+ import pickle
9
+ import signal
10
+ import sys
11
+ import pdb
12
+
13
+ from ..utils import xtqdm as tqdm
14
+
15
+ __all__=['ExampleInstance', 'example_to_feature', 'ExampleSet']
16
+
17
class ExampleInstance:
    """A single training example: a list of text segments plus an optional label.

    Any extra keyword arguments are attached to the instance as attributes.
    """

    def __init__(self, segments, label=None, **attrs):
        self.segments = segments
        self.label = label
        self.__dict__.update(attrs)

    def __repr__(self):
        return 'segments: {}\nlabel: {}'.format(self.segments, self.label)

    def __getitem__(self, i):
        # Index straight into the underlying segment list.
        return self.segments[i]

    def __len__(self):
        return len(self.segments)
31
+
32
class ExampleSet:
    """Stores examples as pickled blobs inside a numpy array.

    Pickling up front keeps per-example Python object overhead low; items are
    deserialized lazily on access.
    """

    def __init__(self, pairs):
        serialized = [pickle.dumps(item) for item in pairs]
        self._data = np.array(serialized)
        self.total = len(self._data)

    def __getitem__(self, idx):
        """Return the example at `idx`; an (idx, rng, ext_params) tuple is also accepted."""
        if isinstance(idx, tuple):
            idx, _rng, _ext_params = idx
        return pickle.loads(self._data[idx])

    def __len__(self):
        return self.total

    def __iter__(self):
        return (self[i] for i in range(self.total))
55
+
56
def _truncate_segments(segments, max_num_tokens, rng):
    """Truncate `segments` in place until their total length fits `max_num_tokens`.

    Follows the original BERT implementation:
    https://github.com/google-research/bert/blob/master/create_pretraining_data.py#L391
    — repeatedly remove one token from the currently longest segment, randomly
    from the front or the back, to avoid positional bias.

    Bug fix: the previous code re-sorted `segments` by length on every
    iteration, which permanently reordered the returned segments whenever
    truncation occurred (corrupting segment order, and hence type ids, for the
    caller). The longest segment is now selected with `max`, preserving the
    caller's ordering; ties resolve to the earliest segment, exactly as the
    old stable sort did.

    Args:
        segments: list of mutable token lists; mutated in place.
        max_num_tokens: maximum allowed total token count.
        rng: random source with a `random()` method.

    Returns:
        The same `segments` list, in original order, total length <= max_num_tokens.
    """
    while sum(len(s) for s in segments) > max_num_tokens:
        trunc_tokens = max(segments, key=len)
        assert len(trunc_tokens) >= 1
        # Drop from the front or the back with equal probability.
        if rng.random() < 0.5:
            trunc_tokens.pop(0)
        else:
            trunc_tokens.pop()
    return segments
75
+
76
def example_to_feature(tokenizer, example, max_seq_len=512, rng=None, mask_generator = None, ext_params=None, label_type='int', **kwargs):
    """Convert an ExampleInstance into padded model-input features.

    Args:
        tokenizer: provides `tokenize(str)` and `convert_tokens_to_ids(list)`.
        example: object with `segments` (list of strings) and optional `label`.
        max_seq_len: total sequence length including [CLS] and one [SEP] per segment.
        rng: random source for truncation/masking; defaults to the stdlib
            `random` module.
        mask_generator: optional MLM masker; when given, adds 'lm_labels'.
        ext_params: unused here; accepted for interface compatibility.
        label_type: 'int' or 'float' — dtype of the 'labels' tensor.

    Returns:
        OrderedDict of int tensors (input_ids, type_ids, position_ids,
        input_mask, optionally lm_labels), each zero-padded to max_seq_len,
        plus 'labels' when the example carries a label.
    """
    if rng is None:
        # Bug fix: this module never imported `random`, so the old fallback
        # (`rng = random`) raised NameError. Import locally instead.
        import random
        rng = random
    # Reserve room for [CLS] plus one [SEP] per segment.
    max_num_tokens = max_seq_len - len(example.segments) - 1
    segments = _truncate_segments([tokenizer.tokenize(s) for s in example.segments], max_num_tokens, rng)
    tokens = ['[CLS]']
    type_ids = [0]
    for i, s in enumerate(segments):
        tokens.extend(s)
        tokens.append('[SEP]')
        # Segment id i covers the segment's tokens and its trailing [SEP].
        type_ids.extend([i] * (len(s) + 1))
    if mask_generator:
        tokens, lm_labels = mask_generator.mask_tokens(tokens, rng)
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    pos_ids = list(range(len(token_ids)))
    input_mask = [1] * len(token_ids)
    features = OrderedDict(input_ids = token_ids,
        type_ids = type_ids,
        position_ids = pos_ids,
        input_mask = input_mask)
    if mask_generator:
        features['lm_labels'] = lm_labels
    # Zero-pad every sequence feature to max_seq_len, then tensorize.
    padding_size = max(0, max_seq_len - len(token_ids))
    for f in features:
        features[f].extend([0] * padding_size)
        features[f] = torch.tensor(features[f], dtype=torch.int)
    # Use a separate local for the tensor dtype instead of shadowing the
    # `label_type` argument (clarity fix; behavior unchanged).
    tensor_label_type = torch.int if label_type == 'int' else torch.float
    if example.label is not None:
        features['labels'] = torch.tensor(example.label, dtype=tensor_label_type)
    return features
nlu/DeBERTa/deberta/__init__.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Author: penhe@microsoft.com
3
+ # Date: 04/25/2019
4
+ #
5
+
6
+ """ Components for NN
7
+ """
8
+
9
+ from __future__ import absolute_import
10
+ from __future__ import division
11
+ from __future__ import print_function
12
+
13
+ from .tokenizers import *
14
+ from .pooling import *
15
+ from .mlm import MLMPredictionHead
16
+ from .nnmodule import NNModule
17
+ from .deberta import *
18
+ from .disentangled_attention import *
19
+ from .ops import *
20
+ from .bert import *
21
+ from .config import *
22
+ from .cache_utils import *
nlu/DeBERTa/deberta/bert.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
2
+ # Copyright (c) Microsoft, Inc. 2020
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # This piece of code is modified based on https://github.com/huggingface/transformers
8
+
9
+ import copy
10
+ import torch
11
+ from torch import nn
12
+ from collections.abc import Sequence
13
+ from packaging import version
14
+ import numpy as np
15
+ import math
16
+ import os
17
+ import pdb
18
+
19
+ import json
20
+ from .ops import *
21
+ from .disentangled_attention import *
22
+ from .da_utils import *
23
+
24
+ from adapterlib import adapter_dict
25
+
26
+ __all__ = ['BertEncoder', 'BertEmbeddings', 'ACT2FN', 'LayerNorm', 'BertLMPredictionHead']
27
+
28
class BertSelfOutput(nn.Module):
    """Self-attention output block: projection, dropout, residual add, LayerNorm."""

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        if config.inject_adapter == 'linear':
            self.dense = nn.Linear(hidden, hidden)
        else:
            # An adapter module replaces the plain linear projection.
            self.dense = adapter_dict[config.inject_adapter](hidden, hidden, config=config)

        self.LayerNorm = LayerNorm(hidden, config.layer_norm_eps)
        self.dropout = StableDropout(config.hidden_dropout_prob)
        self.config = config

    def forward(self, hidden_states, input_states, mask=None):
        out = self.dropout(self.dense(hidden_states))
        out += input_states
        return MaskedLayerNorm(self.LayerNorm, out)
47
+
48
class BertAttention(nn.Module):
    """Disentangled self-attention followed by the output projection/residual block."""

    def __init__(self, config):
        super().__init__()
        self.self = DisentangledSelfAttention(config)
        self.output = BertSelfOutput(config)
        self.config = config

    def forward(self, hidden_states, attention_mask, return_att=False, query_states=None, relative_pos=None, rel_embeddings=None):
        # DisentangledSelfAttention returns a dict; unpack its three outputs.
        output = self.self(hidden_states, attention_mask, return_att, query_states=query_states, relative_pos=relative_pos, rel_embeddings=rel_embeddings)
        self_output, att_matrix, att_logits_ = output['hidden_states'], output['attention_probs'], output['attention_logits']
        if query_states is None:
            query_states = hidden_states
        # The residual connection in BertSelfOutput uses the query stream.
        attention_output = self.output(self_output, query_states, attention_mask)

        if return_att:
            return (attention_output, att_matrix)
        else:
            return attention_output
66
+
67
class BertIntermediate(nn.Module):
    """Position-wise feed-forward expansion: hidden_size -> intermediate_size, then activation."""

    def __init__(self, config):
        super().__init__()
        if config.inject_adapter == 'linear':
            self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        else:
            # An adapter module replaces the plain linear projection.
            self.dense = adapter_dict[config.inject_adapter](config.hidden_size, config.intermediate_size, config=config)

        act = config.hidden_act
        self.intermediate_act_fn = ACT2FN[act] if isinstance(act, str) else act

    def forward(self, hidden_states):
        return self.intermediate_act_fn(self.dense(hidden_states))
83
+
84
class BertOutput(nn.Module):
    """FFN output block: project back to hidden_size, dropout, residual add, LayerNorm."""

    def __init__(self, config):
        super(BertOutput, self).__init__()
        if config.inject_adapter == 'linear':
            self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        else:
            # An adapter module replaces the plain linear projection.
            self.dense = adapter_dict[config.inject_adapter](config.intermediate_size, config.hidden_size, config=config)

        self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
        self.dropout = StableDropout(config.hidden_dropout_prob)
        self.config = config

    def forward(self, hidden_states, input_states, mask=None):
        out = self.dropout(self.dense(hidden_states))
        out += input_states
        return MaskedLayerNorm(self.LayerNorm, out)
103
+
104
class BertLayer(nn.Module):
    """One transformer block: attention -> intermediate FFN -> output projection."""

    def __init__(self, config):
        super(BertLayer, self).__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_states, attention_mask, return_att=False, query_states=None, relative_pos=None, rel_embeddings=None):
        attention_output = self.attention(hidden_states, attention_mask, return_att=return_att, \
            query_states=query_states, relative_pos=relative_pos, rel_embeddings=rel_embeddings)
        if return_att:
            # When attention matrices are requested, attention returns a pair.
            attention_output, att_matrix = attention_output
        intermediate_output = self.intermediate(attention_output)
        # The FFN residual connects back to the attention output.
        layer_output = self.output(intermediate_output, attention_output, attention_mask)
        if return_att:
            return (layer_output, att_matrix)
        else:
            return layer_output
122
+
123
class ConvLayer(nn.Module):
    """1D convolution branch applied after the first encoder layer.

    Output = MaskedLayerNorm(residual + act(dropout(conv(x)))), with padded
    positions zeroed out before the activation.
    """

    def __init__(self, config):
        super().__init__()
        kernel_size = getattr(config, 'conv_kernel_size', 3)
        groups = getattr(config, 'conv_groups', 1)
        self.conv_act = getattr(config, 'conv_act', 'tanh')
        # (kernel_size-1)//2 padding keeps the sequence length unchanged for odd kernels.
        self.conv = torch.nn.Conv1d(config.hidden_size, config.hidden_size, kernel_size, padding = (kernel_size-1)//2, groups = groups)
        self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
        self.dropout = StableDropout(config.hidden_dropout_prob)
        self.config = config

    def forward(self, hidden_states, residual_states, input_mask):
        # Conv1d expects (B, C, L): permute in, convolve, permute back.
        out = self.conv(hidden_states.permute(0,2,1).contiguous()).permute(0,2,1).contiguous()
        # masked_fill_ requires a bool mask on torch >= 1.2; byte on older versions.
        if version.Version(torch.__version__) >= version.Version('1.2.0a'):
            rmask = (1-input_mask).bool()
        else:
            rmask = (1-input_mask).byte()
        # Zero out padded positions before applying the activation.
        out.masked_fill_(rmask.unsqueeze(-1).expand(out.size()), 0)
        out = ACT2FN[self.conv_act](self.dropout(out))
        output_states = MaskedLayerNorm(self.LayerNorm, residual_states + out, input_mask)

        return output_states
145
+
146
class BertEncoder(nn.Module):
    """ Modified BertEncoder with relative position bias support
    """
    def __init__(self, config):
        super().__init__()
        #layer = BertLayer(config)
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
        self.relative_attention = getattr(config, 'relative_attention', False)
        if self.relative_attention:
            self.max_relative_positions = getattr(config, 'max_relative_positions', -1)
            if self.max_relative_positions <1:
                self.max_relative_positions = config.max_position_embeddings
            self.position_buckets = getattr(config, 'position_buckets', -1)
            # Embedding table covers both directions: 2 * (max positions or buckets).
            pos_ebd_size = self.max_relative_positions*2
            if self.position_buckets>0:
                pos_ebd_size = self.position_buckets*2
            self.rel_embeddings = nn.Embedding(pos_ebd_size, config.hidden_size)

        # '|'-separated options, e.g. 'layer_norm' normalizes the rel-pos embeddings.
        self.norm_rel_ebd = [x.strip() for x in getattr(config, 'norm_rel_ebd', 'none').lower().split('|')]
        if 'layer_norm' in self.norm_rel_ebd:
            self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine = True)
        # A positive conv_kernel_size enables the ConvLayer branch after layer 0.
        kernel_size = getattr(config, 'conv_kernel_size', 0)
        self.with_conv = False
        if kernel_size > 0:
            self.with_conv = True
            self.conv = ConvLayer(config)

    def get_rel_embedding(self):
        """Relative-position embedding weights (LayerNormed when configured); None when disabled."""
        rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None
        if rel_embeddings is not None and ('layer_norm' in self.norm_rel_ebd):
            rel_embeddings = self.LayerNorm(rel_embeddings)
        return rel_embeddings

    def get_attention_mask(self, attention_mask):
        """Expand a 2D (B x S) or 3D (B x S x S) mask to 4D (B x 1 x S x S)."""
        if attention_mask.dim()<=2:
            extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            # Outer product of the per-token mask with itself: pairwise visibility.
            attention_mask = extended_attention_mask*extended_attention_mask.squeeze(-2).unsqueeze(-1)
            attention_mask = attention_mask.byte()
        elif attention_mask.dim()==3:
            attention_mask = attention_mask.unsqueeze(1)

        return attention_mask

    def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None):
        """Build the relative-position matrix lazily when relative attention is enabled."""
        if self.relative_attention and relative_pos is None:
            q = query_states.size(-2) if query_states is not None else hidden_states.size(-2)
            relative_pos = build_relative_position(q, hidden_states.size(-2), bucket_size = self.position_buckets, \
                max_position=self.max_relative_positions, device = hidden_states.device)
        return relative_pos

    def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, return_att=False, query_states = None, relative_pos=None):
        """Run all layers; returns {'hidden_states': [...], 'attention_matrices': [...]}."""
        # Derive the per-token mask from a pairwise mask when one was provided.
        if attention_mask.dim()<=2:
            input_mask = attention_mask
        else:
            input_mask = (attention_mask.sum(-2)>0).byte()
        attention_mask = self.get_attention_mask(attention_mask)
        relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)

        all_encoder_layers = []
        att_matrices = []
        # A Sequence of hidden states supplies per-layer key/value inputs.
        if isinstance(hidden_states, Sequence):
            next_kv = hidden_states[0]
        else:
            next_kv = hidden_states
        rel_embeddings = self.get_rel_embedding()
        for i, layer_module in enumerate(self.layer):
            output_states = layer_module(next_kv, attention_mask, return_att, query_states = query_states, relative_pos=relative_pos, rel_embeddings=rel_embeddings)
            if return_att:
                output_states, att_m = output_states

            # The optional conv branch mixes the raw input back in after layer 0.
            if i == 0 and self.with_conv:
                prenorm = output_states #output['prenorm_states']
                output_states = self.conv(hidden_states, prenorm, input_mask)

            if query_states is not None:
                query_states = output_states
            if isinstance(hidden_states, Sequence):
                next_kv = hidden_states[i+1] if i+1 < len(self.layer) else None
            else:
                next_kv = output_states

            if output_all_encoded_layers:
                all_encoder_layers.append(output_states)
                if return_att:
                    att_matrices.append(att_m)
        # Otherwise only the final layer's output is returned.
        if not output_all_encoded_layers:
            all_encoder_layers.append(output_states)
            if return_att:
                att_matrices.append(att_m)
        return {
            'hidden_states': all_encoder_layers,
            'attention_matrices': att_matrices
            }
239
+
240
class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.
    """
    def __init__(self, config):
        super(BertEmbeddings, self).__init__()
        padding_idx = getattr(config, 'padding_idx', 0)
        # Embedding width may differ from hidden_size; projected below when it does.
        self.embedding_size = getattr(config, 'embedding_size', config.hidden_size)
        self.word_embeddings = nn.Embedding(config.vocab_size, self.embedding_size, padding_idx = padding_idx)
        # When False, position embeddings are still computed and returned, but
        # not added into the token embeddings.
        self.position_biased_input = getattr(config, 'position_biased_input', True)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.embedding_size)

        if config.type_vocab_size>0:
            self.token_type_embeddings = nn.Embedding(config.type_vocab_size, self.embedding_size)

        if self.embedding_size != config.hidden_size:
            self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False)
        self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
        self.dropout = StableDropout(config.hidden_dropout_prob)
        self.output_to_half = False
        self.config = config

    def forward(self, input_ids, token_type_ids=None, position_ids=None, mask = None):
        """Return {'embeddings': B x S x hidden, 'position_embeddings': B x S x emb}."""
        seq_length = input_ids.size(1)
        if position_ids is None:
            # Default positions: 0..seq_length-1, broadcast over the batch.
            position_ids = torch.arange(0, seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids.long())

        embeddings = words_embeddings
        if self.config.type_vocab_size>0:
            token_type_embeddings = self.token_type_embeddings(token_type_ids)
            embeddings += token_type_embeddings

        if self.position_biased_input:
            embeddings += position_embeddings

        # Project to hidden_size when the embedding width differs.
        if self.embedding_size != self.config.hidden_size:
            embeddings = self.embed_proj(embeddings)
        embeddings = MaskedLayerNorm(self.LayerNorm, embeddings, mask)
        embeddings = self.dropout(embeddings)
        return {
            'embeddings': embeddings,
            'position_embeddings': position_embeddings}
287
+
288
class BertLMPredictionHead(nn.Module):
    """MLM output head: transform hidden states and score against the embedding matrix."""

    def __init__(self, config, vocab_size):
        super().__init__()
        self.embedding_size = getattr(config, 'embedding_size', config.hidden_size)
        # Project hidden states back to the embedding width so weights can be tied.
        self.dense = nn.Linear(config.hidden_size, self.embedding_size)
        self.transform_act_fn = ACT2FN[config.hidden_act] \
            if isinstance(config.hidden_act, str) else config.hidden_act

        self.LayerNorm = LayerNorm(self.embedding_size, config.layer_norm_eps, elementwise_affine=True)

        # Per-token output bias (the tied embedding matrix carries no bias).
        self.bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, hidden_states, embeding_weight):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        # b x s x d
        hidden_states = MaskedLayerNorm(self.LayerNorm, hidden_states)

        # b x s x v — weight tying: logits come from the transposed embedding matrix.
        logits = torch.matmul(hidden_states, embeding_weight.t().to(hidden_states)) + self.bias
        return logits
nlu/DeBERTa/deberta/cache_utils.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft, Inc. 2020
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ #
6
+ # Author: penhe@microsoft.com
7
+ # Date: 05/15/2020
8
+ #
9
+
10
+ import pdb
11
+ import torch
12
+ import os
13
+ import requests
14
+ from .config import ModelConfig
15
+ import pathlib
16
+ from ..utils import xtqdm as tqdm
17
+ from zipfile import ZipFile
18
+ from ..utils import get_logger
19
+ logger = get_logger()
20
+
21
+ __all__ = ['pretrained_models', 'load_model_state', 'load_vocab']
22
+
23
class PretrainedModel:
    """Descriptor for a released DeBERTa checkpoint hosted under huggingface.co/microsoft.

    Builds the download URLs for the model weights, config and vocabulary from
    the checkpoint name; extra keyword arguments become instance attributes.
    """

    def __init__(self, name, vocab, vocab_type, model='pytorch_model.bin', config='config.json', **kwargs):
        self.__dict__.update(kwargs)
        base = f'https://huggingface.co/microsoft/{name}/resolve/main/'
        self.name = name
        self.vocab_type = vocab_type
        self.model_url = f'{base}{model}'
        self.config_url = f'{base}{config}'
        self.vocab_url = f'{base}{vocab}'
32
+
33
# Registry of released checkpoints, keyed by the short id accepted by
# load_model_state / load_vocab. v1 checkpoints use the GPT-2 BPE vocabulary
# ('bpe_encoder.bin'); v2/v3 checkpoints use a SentencePiece model ('spm.model').
pretrained_models= {
    'base': PretrainedModel('deberta-base', 'bpe_encoder.bin', 'gpt2'),
    'large': PretrainedModel('deberta-large', 'bpe_encoder.bin', 'gpt2'),
    'xlarge': PretrainedModel('deberta-xlarge', 'bpe_encoder.bin', 'gpt2'),
    'base-mnli': PretrainedModel('deberta-base-mnli', 'bpe_encoder.bin', 'gpt2'),
    'large-mnli': PretrainedModel('deberta-large-mnli', 'bpe_encoder.bin', 'gpt2'),
    'xlarge-mnli': PretrainedModel('deberta-xlarge-mnli', 'bpe_encoder.bin', 'gpt2'),
    'xlarge-v2': PretrainedModel('deberta-v2-xlarge', 'spm.model', 'spm'),
    'xxlarge-v2': PretrainedModel('deberta-v2-xxlarge', 'spm.model', 'spm'),
    'xlarge-v2-mnli': PretrainedModel('deberta-v2-xlarge-mnli', 'spm.model', 'spm'),
    'xxlarge-v2-mnli': PretrainedModel('deberta-v2-xxlarge-mnli', 'spm.model', 'spm'),
    'deberta-v3-small': PretrainedModel('deberta-v3-small', 'spm.model', 'spm'),
    'deberta-v3-base': PretrainedModel('deberta-v3-base', 'spm.model', 'spm'),
    'deberta-v3-large': PretrainedModel('deberta-v3-large', 'spm.model', 'spm'),
    'mdeberta-v3-base': PretrainedModel('mdeberta-v3-base', 'spm.model', 'spm'),
    'deberta-v3-xsmall': PretrainedModel('deberta-v3-xsmall', 'spm.model', 'spm'),
}
50
+
51
def download_asset(url, name, tag=None, no_cache=False, cache_dir=None):
    """Download `url` into the local asset cache and return the cached file path.

    Args:
        url: full URL of the asset to fetch.
        name: file name to store the asset under inside `cache_dir`.
        tag: cache namespace; defaults to 'latest' when None.
        no_cache: when True, re-download even if a cached copy exists.
        cache_dir: cache directory; defaults to ~/.~DeBERTa/assets/<tag>/.

    Raises:
        Exception: when the server responds with a non-200 status.
    """
    _tag = tag
    if _tag is None:
        _tag = 'latest'
    if not cache_dir:
        cache_dir = os.path.join(pathlib.Path.home(), f'.~DeBERTa/assets/{_tag}/')
    os.makedirs(cache_dir, exist_ok=True)
    output = os.path.join(cache_dir, name)
    if os.path.exists(output) and (not no_cache):
        # Cache hit.
        return output

    headers = {}
    headers['Accept'] = 'application/octet-stream'
    resp = requests.get(url, stream=True, headers=headers)
    if resp.status_code != 200:
        raise Exception(f'Request for {url} return {resp.status_code}, {resp.text}')

    try:
        with open(output, 'wb') as fs:
            total = int(resp.headers['Content-Length']) if 'Content-Length' in resp.headers else -1
            progress = tqdm(total=total, ncols=80, desc=f'Downloading {name}')
            try:
                for c in resp.iter_content(chunk_size=1024*1024):
                    fs.write(c)
                    progress.update(len(c))
            finally:
                # Fix: close the progress bar even when the download fails mid-stream.
                progress.close()
    except BaseException:
        # Remove the partial file so a later call doesn't treat it as a cache hit;
        # explicit BaseException (not a bare except) keeps the re-raise intent clear.
        os.remove(output)
        raise

    return output
81
+
82
def load_model_state(path_or_pretrained_id, tag=None, no_cache=False, cache_dir=None):
    """Resolve a model path or pretrained id and load its state dict and config.

    Args:
        path_or_pretrained_id: local checkpoint file/directory, or a key of
            `pretrained_models` (the checkpoint is then downloaded and cached).
        tag: cache namespace; defaults to 'latest'.
        no_cache: force re-download of cached assets.
        cache_dir: override the default cache directory.

    Returns:
        (model_state, model_config); model_config is None when no config can be
        found, and (None, None) when no path/id was given at all.
    """
    model_path = path_or_pretrained_id
    if model_path and (not os.path.exists(model_path)) and (path_or_pretrained_id.lower() in pretrained_models):
        _tag = tag
        # NOTE(review): any id containing 'deberta-v3-base' is force-mapped to
        # the base v3 checkpoint — preserved from the original code; confirm intent.
        if 'deberta-v3-base' in path_or_pretrained_id:
            pretrained = pretrained_models['deberta-v3-base']
        else:
            pretrained = pretrained_models[path_or_pretrained_id.lower()]
        if _tag is None:
            _tag = 'latest'
        if not cache_dir:
            cache_dir = os.path.join(pathlib.Path.home(), f'.~DeBERTa/assets/{_tag}/{pretrained.name}')
        os.makedirs(cache_dir, exist_ok=True)
        model_path = os.path.join(cache_dir, 'pytorch_model.bin')
        if (not os.path.exists(model_path)) or no_cache:
            asset = download_asset(pretrained.model_url, 'pytorch_model.bin', tag=tag, no_cache=no_cache, cache_dir=cache_dir)
            asset = download_asset(pretrained.config_url, 'model_config.json', tag=tag, no_cache=no_cache, cache_dir=cache_dir)
    elif not model_path:
        return None,None

    # Bug fix: only append the file name when a directory was supplied. The
    # previous unconditional join turned the pretrained-cache path into
    # '<cache>/pytorch_model.bin/pytorch_model.bin', which can never exist.
    if os.path.isdir(model_path):
        model_path = os.path.join(model_path, 'pytorch_model.bin')
    config_path = os.path.join(os.path.dirname(model_path), 'model_config.json')
    model_state = torch.load(model_path, map_location='cpu')
    logger.info("Loaded pretrained model file {}".format(model_path))
    # Prefer the config embedded in the checkpoint; fall back to the sidecar file.
    if 'config' in model_state:
        model_config = ModelConfig.from_dict(model_state['config'])
    elif os.path.exists(config_path):
        model_config = ModelConfig.from_json_file(config_path)
    else:
        model_config = None
    return model_state, model_config
113
+
114
def load_vocab(vocab_path=None, vocab_type=None, pretrained_id=None, tag=None, no_cache=False, cache_dir=None):
    """Resolve (vocab_path, vocab_type), downloading the vocab for a known pretrained id.

    When `pretrained_id` matches an entry in `pretrained_models`, the vocab file
    is fetched into the asset cache and both path and type come from that entry;
    otherwise the arguments are passed through. A missing type defaults to 'spm'.
    """
    if pretrained_id and (pretrained_id.lower() in pretrained_models):
        resolved_tag = 'latest' if tag is None else tag

        entry = pretrained_models[pretrained_id.lower()]
        if not cache_dir:
            cache_dir = os.path.join(pathlib.Path.home(), f'.~DeBERTa/assets/{resolved_tag}/{entry.name}')
        os.makedirs(cache_dir, exist_ok=True)
        vocab_type = entry.vocab_type
        url = entry.vocab_url
        filename = os.path.basename(url)
        vocab_path = os.path.join(cache_dir, filename)
        if no_cache or not os.path.exists(vocab_path):
            download_asset(url, filename, tag=tag, no_cache=no_cache, cache_dir=cache_dir)
    if vocab_type is None:
        vocab_type = 'spm'
    return vocab_path, vocab_type
133
+
134
def test_download():
    """Smoke test: resolving the default vocabulary should not raise."""
    vocab = load_vocab()
nlu/DeBERTa/deberta/config.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import copy
3
+
4
+ __all__=['AbsModelConfig', 'ModelConfig']
5
+
6
class AbsModelConfig(object):
  """Base class for model configurations.

  Attributes are populated dynamically from a dictionary or JSON file;
  nested dictionaries become nested `AbsModelConfig` instances. Instances
  serialize back to JSON via `to_json_string`.
  """

  def __init__(self):
    pass

  @classmethod
  def from_dict(cls, json_object):
    """Constructs a `ModelConfig` from a Python dictionary of parameters."""
    config = cls()
    for key in json_object:
      value = json_object[key]
      # Nested dicts are wrapped so they support attribute access too.
      if isinstance(value, dict):
        value = AbsModelConfig.from_dict(value)
      setattr(config, key, value)
    return config

  @classmethod
  def from_json_file(cls, json_file):
    """Constructs a `ModelConfig` from a json file of parameters."""
    with open(json_file, "r", encoding='utf-8') as reader:
      return cls.from_dict(json.load(reader))

  def __repr__(self):
    return str(self.to_json_string())

  def to_dict(self):
    """Serializes this instance to a Python dictionary."""
    return copy.deepcopy(self.__dict__)

  def to_json_string(self):
    """Serializes this instance to a JSON string (ending with a newline)."""
    def _encode(obj):
      # Nested configs serialize through their attribute dictionary.
      if isinstance(obj, AbsModelConfig):
        return obj.__dict__
    return json.dumps(self.__dict__, indent=2, sort_keys=True, default=_encode) + "\n"
41
+
42
class ModelConfig(AbsModelConfig):
  """Configuration for a :class:`~DeBERTa.deberta.DeBERTa` model.

  Attributes:
    hidden_size (int): Size of the encoder layers and the pooler layer, default: `768`.
    num_hidden_layers (int): Number of hidden layers in the Transformer encoder, default: `12`.
    num_attention_heads (int): Number of attention heads per attention layer, default: `12`.
    intermediate_size (int): Size of the "intermediate" (feed-forward) layer, default: `3072`.
    hidden_act (str): The non-linear activation function in the encoder and pooler;
      "gelu", "relu" and "swish" are supported, default: `gelu`.
    hidden_dropout_prob (float): Dropout probability for all fully connected layers
      in the embeddings, encoder, and pooler, default: `0.1`.
    attention_probs_dropout_prob (float): Dropout ratio for attention probabilities, default: `0.1`.
    max_position_embeddings (int): The maximum sequence length the model may be used with, default: `512`.
    type_vocab_size (int): Vocabulary size of `token_type_ids`, default: `0`.
    initializer_range (float): Stddev of the normal initializer for weight matrices, default: `0.02`.
    layer_norm_eps (float): Epsilon used by layer normalization, default: `1e-7`.
    padding_idx (int): The value used to pad input_ids, default: `0`.
    vocab_size (int): Token vocabulary size; `-1` means unset, default: `-1`.

  Additional fields such as `relative_attention`, `max_relative_positions`,
  `position_biased_input` and `pos_att_type` may be supplied through
  `from_dict`/`from_json_file`.
  """

  def __init__(self):
    """Initialize every hyper-parameter to its DeBERTa-base default."""
    self.__dict__.update({
        'hidden_size': 768,
        'num_hidden_layers': 12,
        'num_attention_heads': 12,
        'hidden_act': "gelu",
        'intermediate_size': 3072,
        'hidden_dropout_prob': 0.1,
        'attention_probs_dropout_prob': 0.1,
        'max_position_embeddings': 512,
        'type_vocab_size': 0,
        'initializer_range': 0.02,
        'layer_norm_eps': 1e-7,
        'padding_idx': 0,
        'vocab_size': -1,
    })
nlu/DeBERTa/deberta/da_utils.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import pdb
3
+ from functools import lru_cache
4
+ import numpy as np
5
+ import math
6
+
7
+ __all__=['build_relative_position', 'make_log_bucket_position']
8
+
9
@lru_cache(maxsize=128)
def make_log_bucket_dict(bucket_size, max_position, device=None):
  """Build a lookup table mapping relative positions to log-bucket ids.

  Index `i` of the returned long tensor corresponds to relative position
  `i - max_position`, covering [-max_position, max_position). Distances
  within the linear region keep their value; farther distances are
  compressed logarithmically, preserving sign. Results are memoized per
  (bucket_size, max_position, device).
  """
  rel = torch.arange(-max_position, max_position, device=device)
  sign = torch.sign(rel)
  mid = bucket_size // 2
  # Replace in-range distances by a dummy (mid-1) so the log below stays
  # finite; those entries take the linear branch of the final where anyway.
  dist = torch.where((rel < mid) & (rel > -mid), torch.tensor(mid - 1).to(rel), rel.abs())
  log_pos = torch.ceil(torch.log(dist / mid) / math.log((max_position - 1) / mid) * (mid - 1)) + mid
  return torch.where(dist <= mid, rel, (log_pos * sign).to(rel)).to(torch.long)
18
+
19
# Faster version
def make_log_bucket_position(relative_pos, bucket_size, max_position):
  """Map a tensor of relative positions to log-bucket ids via a cached table."""
  shifted = torch.clamp(relative_pos, -max_position + 1, max_position - 1) + max_position
  table = make_log_bucket_dict(bucket_size, max_position, shifted.device)
  # Give the 1-D lookup table one leading axis per batch-like dim of the input
  # so it can be broadcast-expanded and gathered along the last dim.
  while table.dim() < shifted.dim():
    table = table.unsqueeze(0)
  expanded = table.expand(list(shifted.size())[:-1] + [table.size(-1)])
  return torch.gather(expanded, index=shifted.long(), dim=-1)
27
+
28
@lru_cache(maxsize=128)
def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1, device=None):
  """Return relative position ids of shape [1, query_size, key_size].

  Entry (q, k) holds q - k. When both `bucket_size` and `max_position` are
  positive, raw distances are compressed into log buckets. Results are
  memoized, so all arguments must be hashable.
  """
  q_ids = torch.arange(0, query_size)
  k_ids = torch.arange(0, key_size)
  if device is not None:
    q_ids = q_ids.to(device)
    k_ids = k_ids.to(device)
  rel_pos_ids = q_ids[:, None] - k_ids[None, :]
  if bucket_size > 0 and max_position > 0:
    rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
  # Add a leading batch axis expected by the attention code.
  return rel_pos_ids[:query_size, :].unsqueeze(0)
43
+
44
def build_relative_position_from_abs(query_pos, key_pos, bucket_size=-1, max_position=-1, device=None):
  """Compute relative position ids from explicit absolute positions.

  `query_pos`/`key_pos` may be tensors or tuples of ints. The result gains a
  trailing key axis: entry (..., q, k) = query_pos[..., q] - key_pos[..., k].
  Log bucketing applies when both `bucket_size` and `max_position` are
  positive.
  """
  q_ids = torch.tensor(query_pos) if isinstance(query_pos, tuple) else query_pos
  k_ids = torch.tensor(key_pos) if isinstance(key_pos, tuple) else key_pos

  if device is not None:
    q_ids = q_ids.to(device)
    k_ids = k_ids.to(device)
  rel_pos_ids = q_ids.unsqueeze(-1) - k_ids.unsqueeze(-2)
  if bucket_size > 0 and max_position > 0:
    rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
  return rel_pos_ids
63
+
64
def test_log_bucket():
  """Manual sanity check for log-bucketed relative positions.

  Bug fix: the original built `x` with `np.arange`, but
  `make_log_bucket_position` uses torch ops (`torch.clamp`, `torch.gather`)
  which raise TypeError on numpy input; build a torch tensor instead. Drops
  into pdb so the resulting buckets can be inspected interactively.
  """
  x = torch.arange(-511, 511)
  y = make_log_bucket_position(x, 128, 512)
  pdb.set_trace()
68
+
nlu/DeBERTa/deberta/deberta.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft, Inc. 2020
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ #
6
+ # Author: penhe@microsoft.com
7
+ # Date: 01/15/2020
8
+ #
9
+
10
+ import copy
11
+ import torch
12
+ import os
13
+
14
+ import json
15
+ from .ops import *
16
+ from .bert import *
17
+ from .config import ModelConfig
18
+ from .cache_utils import load_model_state
19
+ import pdb
20
+
21
+ __all__ = ['DeBERTa']
22
+
23
class DeBERTa(torch.nn.Module):
  """ DeBERTa encoder
  This module is composed of the input embedding layer with stacked transformer layers with disentangled attention.

  Parameters:
    config:
      A model config class instance with the configuration to build a new model. The schema is similar to `BertConfig`, \
          for more details, please refer :class:`~DeBERTa.deberta.ModelConfig`

    pre_trained:
      The pre-trained DeBERTa model, it can be a physical path of a pre-trained DeBERTa model or a released configurations, \
          i.e. [**base, large, base_mnli, large_mnli**]

  """

  def __init__(self, config=None, pre_trained=None):
    super().__init__()
    state = None
    if pre_trained is not None:
      state, model_config = load_model_state(pre_trained)
      # Merge configs: keep the checkpoint's structural sizes (which must
      # match the stored weights) and let the caller-supplied config
      # override every non-structural field.
      if config is not None and model_config is not None:
        for k in config.__dict__:
          if k not in ['hidden_size',
            'intermediate_size',
            'num_attention_heads',
            'num_hidden_layers',
            'vocab_size',
            'max_position_embeddings']:
            model_config.__dict__[k] = config.__dict__[k]
      config = copy.copy(model_config)
    self.embeddings = BertEmbeddings(config)
    self.encoder = BertEncoder(config)
    self.config = config
    self.pre_trained = pre_trained
    # Initialize weights from the loaded checkpoint; no-op when no
    # pretrained state is available.
    self.apply_state(state)

  def forward(self, input_ids, attention_mask=None, token_type_ids=None, output_all_encoded_layers=True, position_ids = None, return_att = False):
    """
    Args:
      input_ids:
        a torch.LongTensor of shape [batch_size, sequence_length] \
      with the word token indices in the vocabulary

      attention_mask:
        an optional parameter for input mask or attention mask.

        - If it's an input mask, then it will be torch.LongTensor of shape [batch_size, sequence_length] with indices \
      selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max \
      input sequence length in the current batch. It's the mask that we typically use for attention when \
      a batch has varying length sentences.

        - If it's an attention mask then it will be torch.LongTensor of shape [batch_size, sequence_length, sequence_length]. \
      In this case, it's a mask indicate which tokens in the sequence should be attended by other tokens in the sequence.

      token_type_ids:
        an optional torch.LongTensor of shape [batch_size, sequence_length] with the token \
      types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to \
      a `sentence B` token (see BERT paper for more details).

      output_all_encoded_layers:
        whether to output results of all encoder layers, default, True

    Returns:

      - The output of the stacked transformer layers if `output_all_encoded_layers=True`, else \
      the last layer of stacked transformer layers

      - Attention matrix of self-attention layers if `return_att=True`


    Example::

      # Batch of wordPiece token ids.
      # Each sample was padded with zero to the maxium length of the batch
      input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
      # Mask of valid input ids
      attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])

      # DeBERTa model initialized with pretrained base model
      bert = DeBERTa(pre_trained='base')

      encoder_layers = bert(input_ids, attention_mask=attention_mask)

    """

    if attention_mask is None:
      # Default: every position is a valid token.
      attention_mask = torch.ones_like(input_ids)
    if token_type_ids is None:
      # Default: single-segment input (all type 0).
      token_type_ids = torch.zeros_like(input_ids)

    ebd_output = self.embeddings(input_ids.to(torch.long), token_type_ids.to(torch.long), position_ids, attention_mask)
    embedding_output = ebd_output['embeddings']
    encoder_output = self.encoder(embedding_output,
                  attention_mask,
                  output_all_encoded_layers=output_all_encoded_layers, return_att = return_att)
    # Expose embedding-layer outputs alongside encoder outputs in one dict.
    encoder_output.update(ebd_output)
    return encoder_output

  def apply_state(self, state = None):
    """ Load state from previous loaded model state dictionary.

    Args:
      state (:obj:`dict`, optional): State dictionary as the state returned by torch.module.state_dict(), default: `None`. \
          If it's `None`, then will use the pre-trained state loaded via the constructor to re-initialize \
          the `DeBERTa` model
    """
    if self.pre_trained is None and state is None:
      return
    if state is None:
      state, config = load_model_state(self.pre_trained)
      self.config = config

    # Checkpoints may prefix parameter names (e.g. 'bert.'); detect the
    # prefix from the first embedding key so checkpoint keys line up with
    # this module's parameter names.
    prefix = ''
    for k in state:
      if 'embeddings.' in k:
        if not k.startswith('embeddings.'):
          prefix = k[:k.index('embeddings.')]
        break

    missing_keys = []
    unexpected_keys = []
    error_msgs = []
    self._load_from_state_dict(state, prefix = prefix, local_metadata=None, strict=True, missing_keys=missing_keys, unexpected_keys=unexpected_keys, error_msgs=error_msgs)
nlu/DeBERTa/deberta/disentangled_attention.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft, Inc. 2020
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ #
6
+ # Author: penhe@microsoft.com
7
+ # Date: 01/15/2020
8
+ #
9
+
10
+ """
11
+ Disentangled SelfAttention module
12
+ """
13
+
14
+ import numpy as np
15
+ import math
16
+ import torch
17
+ from torch import nn
18
+ import functools
19
+ import pdb
20
+
21
+ from .ops import *
22
+ from .da_utils import build_relative_position
23
+
24
+ from ..utils import get_logger
25
+ logger=get_logger()
26
+
27
+ from adapterlib import adapter_dict
28
+
29
+ __all__=['DisentangledSelfAttention']
30
class DisentangledSelfAttention(nn.Module):
  """Disentangled self-attention (DeBERTa).

  Attention scores are the sum of content-to-content attention and optional
  relative-position terms selected by `pos_att_type` ('c2p', 'p2c', 'p2p'),
  each computed from separate projections of content states and relative
  position embeddings. The Q/K/V projections may be wrapped by adapter
  layers chosen via `config.inject_adapter`.
  """
  def __init__(self, config):
    super().__init__()
    self.num_attention_heads = config.num_attention_heads
    _attention_head_size = int(config.hidden_size / config.num_attention_heads)
    self.attention_head_size = getattr(config, 'attention_head_size', _attention_head_size)
    self.all_head_size = self.num_attention_heads * self.attention_head_size

    # -----------------------------------------------------------------------------------------------------------------------
    # Q/K/V projections: 'linear' keeps a plain nn.Linear; any other value
    # looks up an adapter-wrapped projection in adapter_dict.
    if config.inject_adapter != 'linear':
      self.query_proj = adapter_dict[config.inject_adapter](config.hidden_size, self.all_head_size, config=config)
    else:
      self.query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)

    # self.key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
    if config.inject_adapter != 'linear':
      self.key_proj = adapter_dict[config.inject_adapter](config.hidden_size, self.all_head_size, config=config)
    else:
      self.key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)

    if config.inject_adapter != 'linear':
      self.value_proj = adapter_dict[config.inject_adapter](config.hidden_size, self.all_head_size, config=config)
    else:
      self.value_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)

    # -----------------------------------------------------------------------------------------------------------------------

    # self.query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
    # self.key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
    # self.value_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)

    self.share_att_key = getattr(config, 'share_att_key', False)
    self.pos_att_type = [x.strip() for x in getattr(config, 'pos_att_type', 'c2p').lower().split('|')] # c2p|p2c
    self.relative_attention = getattr(config, 'relative_attention', False)

    if self.relative_attention:
      self.position_buckets = getattr(config, 'position_buckets', -1)
      self.max_relative_positions = getattr(config, 'max_relative_positions', -1)
      if self.max_relative_positions <1:
        self.max_relative_positions = config.max_position_embeddings
      # Size of the relative-position embedding table half-window: bucketed
      # size when log buckets are enabled, raw distance range otherwise.
      self.pos_ebd_size = self.max_relative_positions
      if self.position_buckets>0:
        self.pos_ebd_size = self.position_buckets
        # For backward compatibility

      self.pos_dropout = StableDropout(config.hidden_dropout_prob)

      # Without shared attention keys, position embeddings get their own
      # projections (only the ones the enabled score terms need).
      if (not self.share_att_key):
        if 'c2p' in self.pos_att_type or 'p2p' in self.pos_att_type:
          self.pos_key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
        if 'p2c' in self.pos_att_type or 'p2p' in self.pos_att_type:
          self.pos_query_proj = nn.Linear(config.hidden_size, self.all_head_size)

    self.dropout = StableDropout(config.attention_probs_dropout_prob)
    # Remap DeBERTa-v1 checkpoint keys (in_proj/q_bias/...) on load.
    self._register_load_state_dict_pre_hook(self._pre_load_hook)

  def transpose_for_scores(self, x, attention_heads):
    # [B, L, H*D] -> [B*H, L, D] so heads can be processed with bmm.
    new_x_shape = x.size()[:-1] + (attention_heads, -1)
    x = x.view(*new_x_shape)
    return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1))

  def forward(self, hidden_states, attention_mask, return_att=False, query_states=None, relative_pos=None, rel_embeddings=None):
    """Compute disentangled attention.

    Returns a dict with 'hidden_states' (context), 'attention_probs'
    (post-softmax, pre-dropout) and 'attention_logits' (masked raw scores).
    """
    if query_states is None:
      query_states = hidden_states
    query_layer = self.transpose_for_scores(self.query_proj(query_states), self.num_attention_heads).float()
    key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads).float()
    value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads)

    rel_att = None
    # Take the dot product between "query" and "key" to get the raw attention scores.
    # The 1/sqrt(d * scale_factor) scaling accounts for each enabled score
    # term contributing to the total (content + each relative term).
    scale_factor = 1
    if 'c2p' in self.pos_att_type:
      scale_factor += 1
    if 'p2c' in self.pos_att_type:
      scale_factor += 1
    if 'p2p' in self.pos_att_type:
      scale_factor += 1
    scale = 1/math.sqrt(query_layer.size(-1)*scale_factor)
    attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)*scale)
    if self.relative_attention:
      rel_embeddings = self.pos_dropout(rel_embeddings)
      rel_att = self.disentangled_attention_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor)

    if rel_att is not None:
      attention_scores = (attention_scores + rel_att)
    # Subtract the (detached) row max for numerical stability before softmax.
    attention_scores = (attention_scores - attention_scores.max(dim=-1, keepdim=True).values.detach()).to(hidden_states)
    attention_scores = attention_scores.view(-1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1))

    # bxhxlxd
    _attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1)
    attention_probs = self.dropout(_attention_probs)
    context_layer = torch.bmm(attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer)
    context_layer = context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), context_layer.size(-1)).permute(0, 2, 1, 3).contiguous()
    new_context_layer_shape = context_layer.size()[:-2] + (-1,)
    context_layer = context_layer.view(*new_context_layer_shape)

    return {
      'hidden_states': context_layer,
      'attention_probs': _attention_probs,
      'attention_logits': attention_scores
      }

  def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor):
    """Compute the relative-position score terms (c2p/p2c/p2p) to add to the
    content-to-content attention scores."""
    if relative_pos is None:
      q = query_layer.size(-2)
      relative_pos = build_relative_position(q, key_layer.size(-2), bucket_size = self.position_buckets, \
          max_position = self.max_relative_positions, device=query_layer.device)
    if relative_pos.dim()==2:
      relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
    elif relative_pos.dim()==3:
      relative_pos = relative_pos.unsqueeze(1)
    # bxhxqxk
    elif relative_pos.dim()!=4:
      raise ValueError(f'Relative postion ids must be of dim 2 or 3 or 4. {relative_pos.dim()}')

    att_span = self.pos_ebd_size
    relative_pos = relative_pos.long().to(query_layer.device)

    # NOTE(review): with att_span == pos_ebd_size this slice is
    # rel_embeddings[0 : 2*pos_ebd_size] — confirm the embedding table has
    # 2*pos_ebd_size rows.
    rel_embeddings = rel_embeddings[self.pos_ebd_size - att_span:self.pos_ebd_size + att_span, :].unsqueeze(0) #.repeat(query_layer.size(0)//self.num_attention_heads, 1, 1)
    if self.share_att_key:
      # Reuse the content Q/K projections for the position embeddings.
      pos_query_layer = self.transpose_for_scores(self.query_proj(rel_embeddings), self.num_attention_heads)\
          .repeat(query_layer.size(0)//self.num_attention_heads, 1, 1) #.split(self.all_head_size, dim=-1)
      pos_key_layer = self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads)\
          .repeat(query_layer.size(0)//self.num_attention_heads, 1, 1) #.split(self.all_head_size, dim=-1)
    else:
      if 'c2p' in self.pos_att_type or 'p2p' in self.pos_att_type:
        pos_key_layer = self.transpose_for_scores(self.pos_key_proj(rel_embeddings), self.num_attention_heads)\
            .repeat(query_layer.size(0)//self.num_attention_heads, 1, 1) #.split(self.all_head_size, dim=-1)
      if 'p2c' in self.pos_att_type or 'p2p' in self.pos_att_type:
        pos_query_layer = self.transpose_for_scores(self.pos_query_proj(rel_embeddings), self.num_attention_heads)\
            .repeat(query_layer.size(0)//self.num_attention_heads, 1, 1) #.split(self.all_head_size, dim=-1)

    score = 0
    # content->position
    if 'c2p' in self.pos_att_type:
      scale = 1/math.sqrt(pos_key_layer.size(-1)*scale_factor)
      c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2).to(query_layer)*scale)
      # Shift relative positions into [0, 2*att_span) to index the table.
      c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span*2-1).squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)])
      c2p_att = torch.gather(c2p_att, dim=-1, index=c2p_pos)
      score += c2p_att

    # position->content
    if 'p2c' in self.pos_att_type or 'p2p' in self.pos_att_type:
      scale = 1/math.sqrt(pos_query_layer.size(-1)*scale_factor)

      if 'p2c' in self.pos_att_type:
        p2c_att = torch.bmm(pos_query_layer.to(key_layer)*scale, key_layer.transpose(-1, -2))
        # NOTE(review): this reuses c2p_pos from the c2p branch; if
        # pos_att_type ever contains 'p2c' without 'c2p', c2p_pos is
        # undefined here (NameError) — confirm supported configs.
        p2c_att = torch.gather(p2c_att, dim=-2, index=c2p_pos)
        score += p2c_att

    # position->position
    if 'p2p' in self.pos_att_type:
      pos_query = pos_query_layer[:,:,att_span:,:]
      p2p_att = torch.matmul(pos_query, pos_key_layer.transpose(-1, -2))
      p2p_att = p2p_att.expand(query_layer.size()[:2] + p2p_att.size()[2:])
      if query_layer.size(-2) != key_layer.size(-2):
        # NOTE(review): pos_index is not defined in this method/class, so
        # this path (query/key length mismatch with p2p) would raise
        # NameError — confirm whether it is ever exercised.
        p2p_att = torch.gather(p2p_att, dim=-2, index=pos_index.expand(query_layer.size()[:2] + (pos_index.size(-2), p2p_att.size(-1))))
      p2p_att = torch.gather(p2p_att, dim=-1, index=c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)]))
      score += p2p_att

    return score

  def _pre_load_hook(self, state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs):
    """Rewrite DeBERTa-v1 checkpoint keys to the v2 projection layout.

    v1 stored a fused QKV weight ('in_proj.weight') plus separate q/v biases;
    this hook splits them into query/key/value projections and renames the
    position projections, then deletes the old keys.
    """
    self_state = self.state_dict()
    if ((prefix + 'query_proj.weight') not in state_dict) and ((prefix + 'in_proj.weight') in state_dict):
      # Split the fused weight per head into Q, K, V chunks.
      v1_proj = state_dict[prefix+'in_proj.weight']
      v1_proj = v1_proj.unsqueeze(0).reshape(self.num_attention_heads, -1, v1_proj.size(-1))
      q,k,v=v1_proj.chunk(3, dim=1)
      state_dict[prefix + 'query_proj.weight'] = q.reshape(-1, v1_proj.size(-1))
      state_dict[prefix + 'key_proj.weight'] = k.reshape(-1, v1_proj.size(-1))
      # v1 had no key bias; keep this module's freshly-initialized one.
      state_dict[prefix + 'key_proj.bias'] = self_state['key_proj.bias']
      state_dict[prefix + 'value_proj.weight'] = v.reshape(-1, v1_proj.size(-1))
      v1_query_bias = state_dict[prefix + 'q_bias']
      state_dict[prefix + 'query_proj.bias'] = v1_query_bias
      v1_value_bias = state_dict[prefix +'v_bias']
      state_dict[prefix + 'value_proj.bias'] = v1_value_bias

      v1_pos_key_proj = state_dict[prefix + 'pos_proj.weight']
      state_dict[prefix + 'pos_key_proj.weight'] = v1_pos_key_proj
      v1_pos_query_proj = state_dict[prefix + 'pos_q_proj.weight']
      state_dict[prefix + 'pos_query_proj.weight'] = v1_pos_query_proj
      v1_pos_query_proj_bias = state_dict[prefix + 'pos_q_proj.bias']
      state_dict[prefix + 'pos_query_proj.bias'] = v1_pos_query_proj_bias
      state_dict[prefix + 'pos_key_proj.bias'] = self_state['pos_key_proj.bias']

      del state_dict[prefix + 'in_proj.weight']
      del state_dict[prefix + 'q_bias']
      del state_dict[prefix + 'v_bias']
      del state_dict[prefix + 'pos_proj.weight']
      del state_dict[prefix + 'pos_q_proj.weight']
      del state_dict[prefix + 'pos_q_proj.bias']
nlu/DeBERTa/deberta/gpt2_bpe_utils.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Byte pair encoding utilities from GPT-2.
3
+
4
+ Original source: https://github.com/openai/gpt-2/blob/master/src/encoder.py
5
+ Original license: MIT
6
+ """
7
+
8
+ from functools import lru_cache
9
+ import json
10
+ import random
11
+ import unicodedata
12
+
13
+ try:
14
+ import regex as re
15
+ except ImportError:
16
+ raise ImportError('Please install regex with: pip install regex')
17
+
18
@lru_cache()
def bytes_to_unicode():
  """Return a dict mapping each of the 256 byte values to a unicode character.

  Printable, non-whitespace latin-1 bytes map to themselves; every remaining
  byte gets a codepoint >= 256, so the reversible BPE alphabet never
  contains whitespace or control characters (which the BPE code chokes on).
  The mapping is a bijection over all 256 byte values.
  """
  keep = (
      list(range(ord("!"), ord("~") + 1))
      + list(range(ord("¡"), ord("¬") + 1))
      + list(range(ord("®"), ord("ÿ") + 1))
  )
  byte_vals = list(keep)
  char_vals = list(keep)
  offset = 0
  for b in range(2 ** 8):
    if b not in keep:
      byte_vals.append(b)
      # Shift the remaining bytes above the latin-1 range.
      char_vals.append(2 ** 8 + offset)
      offset += 1
  return {b: chr(c) for b, c in zip(byte_vals, char_vals)}
39
+
40
def get_pairs(word):
  """Return the set of adjacent symbol pairs in a word.

  Word is represented as a tuple of symbols (symbols being variable-length
  strings). Robustness fix: an empty word now yields an empty set instead of
  raising IndexError on `word[0]`.
  """
  return set(zip(word, word[1:]))
50
+
51
class Encoder:
  """GPT-2 byte-pair encoder.

  Maps text to BPE token ids and back. `encoder` maps BPE token strings to
  ids; `bpe_merges` is an ordered list of merge pairs (earlier entries have
  higher merge priority).
  """

  def __init__(self, encoder, bpe_merges, errors='replace'):
    self.encoder = encoder
    self.decoder = {v:k for k,v in self.encoder.items()}
    self.errors = errors # how to handle errors in decoding
    self.byte_encoder = bytes_to_unicode()
    self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
    # Earlier merges get lower ranks, i.e. higher priority in the BPE loop.
    self.bpe_ranks = dict(zip([tuple(k) for k in bpe_merges], range(len(bpe_merges))))
    self.cache = {}
    self.random = random.Random(0)

    # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
    self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

  def bpe(self, token):
    """Apply BPE merges to one pre-tokenized word; returns space-joined symbols."""
    if token in self.cache:
      return self.cache[token]
    word = tuple(token)
    pairs = get_pairs(word)

    if not pairs:
      return token

    while True:
      # Merge the highest-priority (lowest-rank) adjacent pair present.
      bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
      if bigram not in self.bpe_ranks:
        break
      first, second = bigram
      new_word = []
      i = 0
      while i < len(word):
        # Bug fix: catch only ValueError from tuple.index — the original
        # bare `except:` also swallowed KeyboardInterrupt/SystemExit.
        try:
          j = word.index(first, i)
          new_word.extend(word[i:j])
          i = j
        except ValueError:
          # `first` no longer occurs; copy the tail and stop.
          new_word.extend(word[i:])
          break

        if word[i] == first and i < len(word)-1 and word[i+1] == second:
          new_word.append(first+second)
          i += 2
        else:
          new_word.append(word[i])
          i += 1
      new_word = tuple(new_word)
      word = new_word
      if len(word) == 1:
        break
      else:
        pairs = get_pairs(word)
    word = ' '.join(word)
    self.cache[token] = word
    return word

  def split_to_words(self, text):
    """Split raw text into GPT-2 pre-tokenization chunks."""
    return list(re.findall(self.pat, text))

  def encode(self, text):
    """Encode `text` into a list of BPE token ids."""
    bpe_tokens = []
    for token in self.split_to_words(text):
      # Map raw bytes into the reversible unicode alphabet before BPE.
      token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
      bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
    return bpe_tokens

  def decode(self, tokens):
    """Decode a list of BPE token ids back into text."""
    text = ''.join([self.decoder[token] for token in tokens])
    text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
    return text
121
+
122
def get_encoder(encoder, vocab):
  """Build an `Encoder` from a token-to-id map and a list of BPE merges."""
  return Encoder(encoder=encoder, bpe_merges=vocab)
127
+
128
+ def _is_whitespace(char):
129
+ """Checks whether `chars` is a whitespace character."""
130
+ # \t, \n, and \r are technically contorl characters but we treat them
131
+ # as whitespace since they are generally considered as such.
132
+ if char == " " or char == "\t" or char == "\n" or char == "\r":
133
+ return True
134
+ cat = unicodedata.category(char)
135
+ if cat == "Zs":
136
+ return True
137
+ return False
138
+
139
+ def _is_control(char):
140
+ """Checks whether `chars` is a control character."""
141
+ # These are technically control characters but we count them as whitespace
142
+ # characters.
143
+ if char == "\t" or char == "\n" or char == "\r":
144
+ return False
145
+ cat = unicodedata.category(char)
146
+ if cat.startswith("C"):
147
+ return True
148
+ return False
149
+
150
+ def _is_punctuation(char):
151
+ """Checks whether `chars` is a punctuation character."""
152
+ cp = ord(char)
153
+ # We treat all non-letter/number ASCII as punctuation.
154
+ # Characters such as "^", "$", and "`" are not in the Unicode
155
+ # Punctuation class but we treat them as punctuation anyways, for
156
+ # consistency.
157
+ if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
158
+ (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
159
+ return True
160
+ cat = unicodedata.category(char)
161
+ if cat.startswith("P"):
162
+ return True
163
+ return False
nlu/DeBERTa/deberta/gpt2_tokenizer.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Copyright (c) Microsoft, Inc. 2020
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ # Author: penhe@microsoft.com
8
+ # Date: 01/15/2020
9
+ #
10
+
11
+ # This piece of code is derived from https://github.com/pytorch/fairseq/blob/master/fairseq/data/encoders/gpt2_bpe.py
12
+
13
+ import torch
14
+ import unicodedata
15
+ import os
16
+ from .gpt2_bpe_utils import get_encoder,_is_control,_is_whitespace,_is_punctuation
17
+ from .cache_utils import load_vocab
18
+
19
+ __all__ = ['GPT2Tokenizer']
20
+
21
+ class GPT2Tokenizer(object):
22
+ """ A wrapper of GPT2 tokenizer with similar interface as BERT tokenizer
23
+
24
+ Args:
25
+
26
+ vocab_file (:obj:`str`, optional):
27
+ The local path of vocabulary package or the release name of vocabulary in `DeBERTa GitHub releases <https://github.com/microsoft/DeBERTa/releases>`_, \
28
+ e.g. "bpe_encoder", default: `None`.
29
+
30
+ If it's `None`, then it will download the vocabulary in the latest release from GitHub. The vocabulary file is a \
31
+ state dictionary with three items, "dict_map", "vocab", "encoder" which correspond to three files used in `RoBERTa`, i.e. `dict.txt`, `vocab.txt` and `encoder.json`. \
32
+
33
+ The difference between our wrapped GPT2 tokenizer and RoBERTa wrapped tokenizer are,
34
+
35
+ - Special tokens, unlike `RoBERTa` which use `<s>`, `</s>` as the `start` token and `end` token of a sentence. We use `[CLS]` and `[SEP]` as the `start` and `end`\
36
+ token of input sentence which is the same as `BERT`.
37
+
38
+ - We remapped the token ids in our dictionary with regarding to the new special tokens, `[PAD]` => 0, `[CLS]` => 1, `[SEP]` => 2, `[UNK]` => 3, `[MASK]` => 50264
39
+
40
+ do_lower_case (:obj:`bool`, optional):
41
+ Whether to convert inputs to lower case. **Not used in GPT2 tokenizer**.
42
+
43
+ special_tokens (:obj:`list`, optional):
44
+ List of special tokens to be added to the end of the vocabulary.
45
+
46
+
47
+ """
48
+ def __init__(self, vocab_file=None, do_lower_case=True, special_tokens=None):
49
+ self.pad_token='[PAD]'
50
+ self.sep_token='[SEP]'
51
+ self.unk_token='[UNK]'
52
+ self.cls_token='[CLS]'
53
+
54
+ self.symbols = []
55
+ self.count = []
56
+ self.indices = {}
57
+ self.pad_token_id = self.add_symbol(self.pad_token)
58
+ self.cls_token_id = self.add_symbol(self.cls_token)
59
+ self.sep_token_id = self.add_symbol(self.sep_token)
60
+ self.unk_token_id = self.add_symbol(self.unk_token)
61
+
62
+ self.gpt2_encoder = torch.load(vocab_file)
63
+ self.bpe = get_encoder(self.gpt2_encoder['encoder'], self.gpt2_encoder['vocab'])
64
+ for w,n in self.gpt2_encoder['dict_map']:
65
+ self.add_symbol(w, n)
66
+
67
+ self.mask_token='[MASK]'
68
+ self.mask_id = self.add_symbol(self.mask_token)
69
+ self.special_tokens = ['[MASK]', '[SEP]', '[PAD]', '[UNK]', '[CLS]']
70
+ if special_tokens is not None:
71
+ for t in special_tokens:
72
+ self.add_special_token(t)
73
+
74
+ self.vocab = self.indices
75
+ self.ids_to_tokens = self.symbols
76
+
77
+ def tokenize(self, text):
78
+ """ Convert an input text to tokens.
79
+
80
+ Args:
81
+
82
+ text (:obj:`str`): input text to be tokenized.
83
+
84
+ Returns:
85
+ A list of byte tokens where each token represent the byte id in GPT2 byte dictionary
86
+
87
+ Example::
88
+
89
+ >>> tokenizer = GPT2Tokenizer()
90
+ >>> text = "Hello world!"
91
+ >>> tokens = tokenizer.tokenize(text)
92
+ >>> print(tokens)
93
+ ['15496', '995', '0']
94
+
95
+ """
96
+ bpe = self._encode(text)
97
+
98
+ return [t for t in bpe.split(' ') if t]
99
+
100
+ def convert_tokens_to_ids(self, tokens):
101
+ """ Convert list of tokens to ids.
102
+
103
+ Args:
104
+
105
+ tokens (:obj:`list<str>`): list of tokens
106
+
107
+ Returns:
108
+
109
+ List of ids
110
+ """
111
+
112
+ return [self.vocab[t] for t in tokens]
113
+
114
+ def convert_ids_to_tokens(self, ids):
115
+ """ Convert list of ids to tokens.
116
+
117
+ Args:
118
+
119
+ ids (:obj:`list<int>`): list of ids
120
+
121
+ Returns:
122
+
123
+ List of tokens
124
+ """
125
+
126
+ tokens = []
127
+ for i in ids:
128
+ tokens.append(self.ids_to_tokens[i])
129
+ return tokens
130
+
131
+ def split_to_words(self, text):
132
+ return self.bpe.split_to_words(text)
133
+
134
+ def decode(self, tokens):
135
+ """ Decode list of tokens to text strings.
136
+
137
+ Args:
138
+
139
+ tokens (:obj:`list<str>`): list of tokens.
140
+
141
+ Returns:
142
+
143
+ Text string corresponds to the input tokens.
144
+
145
+ Example::
146
+
147
+ >>> tokenizer = GPT2Tokenizer()
148
+ >>> text = "Hello world!"
149
+ >>> tokens = tokenizer.tokenize(text)
150
+ >>> print(tokens)
151
+ ['15496', '995', '0']
152
+
153
+ >>> tokenizer.decode(tokens)
154
+ 'Hello world!'
155
+
156
+ """
157
+ return self.bpe.decode([int(t) for t in tokens if t not in self.special_tokens])
158
+
159
+ def add_special_token(self, token):
160
+ """Adds a special token to the dictionary.
161
+
162
+ Args:
163
+ token (:obj:`str`): Tthe new token/word to be added to the vocabulary.
164
+
165
+ Returns:
166
+ The id of new token in the vocabulary.
167
+
168
+ """
169
+ self.special_tokens.append(token)
170
+ return self.add_symbol(token)
171
+
172
+ def part_of_whole_word(self, token, is_bos=False):
173
+ if is_bos:
174
+ return True
175
+ s = self._decode(token)
176
+ if (len(s)==1 and (_is_whitespace(list(s)[0]) or _is_control(list(s)[0]) or _is_punctuation(list(s)[0]))):
177
+ return False
178
+
179
+ return not s.startswith(' ')
180
+
181
+ def sym(self, id):
182
+ return self.ids_to_tokens[id]
183
+
184
+ def id(self, sym):
185
+ return self.vocab[sym]
186
+
187
+ def _encode(self, x: str) -> str:
188
+ return ' '.join(map(str, self.bpe.encode(x)))
189
+
190
+ def _decode(self, x: str) -> str:
191
+ return self.bpe.decode(map(int, x.split()))
192
+
193
+ def add_symbol(self, word, n=1):
194
+ """Adds a word to the dictionary.
195
+
196
+ Args:
197
+ word (:obj:`str`): Tthe new token/word to be added to the vocabulary.
198
+ n (int, optional): The frequency of the word.
199
+
200
+ Returns:
201
+ The id of the new word.
202
+
203
+ """
204
+ if word in self.indices:
205
+ idx = self.indices[word]
206
+ self.count[idx] = self.count[idx] + n
207
+ return idx
208
+ else:
209
+ idx = len(self.symbols)
210
+ self.indices[word] = idx
211
+ self.symbols.append(word)
212
+ self.count.append(n)
213
+ return idx
214
+
215
+ def save_pretrained(self, path: str):
216
+ torch.save(self.gpt2_encoder, path)
nlu/DeBERTa/deberta/mlm.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
2
+ # Copyright (c) Microsoft, Inc. 2020
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # This piece of code is modified based on https://github.com/huggingface/transformers
8
+
9
+ import torch
10
+ from torch import nn
11
+ import pdb
12
+
13
+ from .bert import LayerNorm,ACT2FN
14
+
15
+ __all__ = ['MLMPredictionHead']
16
+
17
+ class MLMPredictionHead(nn.Module):
18
+ def __init__(self, config, vocab_size):
19
+ super().__init__()
20
+ self.embedding_size = getattr(config, 'embedding_size', config.hidden_size)
21
+ self.dense = nn.Linear(config.hidden_size, self.embedding_size)
22
+ self.transform_act_fn = ACT2FN[config.hidden_act] \
23
+ if isinstance(config.hidden_act, str) else config.hidden_act
24
+
25
+ self.LayerNorm = LayerNorm(self.embedding_size, config.layer_norm_eps)
26
+ self.bias = nn.Parameter(torch.zeros(vocab_size))
27
+ self.pre_norm = PreLayerNorm(config)
28
+
29
+ def forward(self, hidden_states, embeding_weight):
30
+ hidden_states = self.pre_norm(hidden_states)
31
+ hidden_states = self.dense(hidden_states)
32
+ hidden_states = self.transform_act_fn(hidden_states)
33
+ # b x s x d
34
+ hidden_states = MaskedLayerNorm(self.LayerNorm, hidden_states)
35
+
36
+ # b x s x v
37
+ logits = torch.matmul(hidden_states, embeding_weight.t().to(hidden_states)) + self.bias
38
+ return logits
nlu/DeBERTa/deberta/nnmodule.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pdb
2
+ import os
3
+ import torch
4
+ import copy
5
+ from torch import nn
6
+ from .config import ModelConfig
7
+ from ..utils import xtqdm as tqdm
8
+ from .cache_utils import load_model_state
9
+
10
+ from ..utils import get_logger
11
+ logger = get_logger()
12
+
13
+ __all__ = ['NNModule']
14
+
15
+ class NNModule(nn.Module):
16
+ """ An abstract class to handle weights initialization and \
17
+ a simple interface for dowloading and loading pretrained models.
18
+
19
+ Args:
20
+
21
+ config (:obj:`~DeBERTa.deberta.ModelConfig`): The model config to the module
22
+
23
+ """
24
+
25
+ def __init__(self, config, *inputs, **kwargs):
26
+ super().__init__()
27
+ self.config = config
28
+
29
+ def init_weights(self, module):
30
+ """ Apply Gaussian(mean=0, std=`config.initializer_range`) initialization to the module.
31
+
32
+ Args:
33
+
34
+ module (:obj:`torch.nn.Module`): The module to apply the initialization.
35
+
36
+ Example::
37
+
38
+ class MyModule(NNModule):
39
+ def __init__(self, config):
40
+ # Add construction instructions
41
+ self.bert = DeBERTa(config)
42
+
43
+ # Add other modules
44
+ ...
45
+
46
+ # Apply initialization
47
+ self.apply(self.init_weights)
48
+
49
+ """
50
+ if isinstance(module, (nn.Linear, nn.Embedding)):
51
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
52
+ if isinstance(module, nn.Linear) and module.bias is not None:
53
+ module.bias.data.zero_()
54
+
55
+ def export_onnx(self, onnx_path, input):
56
+ raise NotImplementedError
57
+
58
+ @classmethod
59
+ def load_model(cls, model_path, model_config=None, tag=None, no_cache=False, cache_dir=None , *inputs, **kwargs):
60
+ """ Instantiate a sub-class of NNModule from a pre-trained model file.
61
+
62
+ Args:
63
+
64
+ model_path (:obj:`str`): Path or name of the pre-trained model which can be either,
65
+
66
+ - The path of pre-trained model
67
+
68
+ - The pre-trained DeBERTa model name in `DeBERTa GitHub releases <https://github.com/microsoft/DeBERTa/releases>`_, i.e. [**base, base_mnli, large, large_mnli**].
69
+
70
+ If `model_path` is `None` or `-`, then the method will create a new sub-class without initialing from pre-trained models.
71
+
72
+ model_config (:obj:`str`): The path of model config file. If it's `None`, then the method will try to find the the config in order:
73
+
74
+ 1. ['config'] in the model state dictionary.
75
+
76
+ 2. `model_config.json` aside the `model_path`.
77
+
78
+ If it failed to find a config the method will fail.
79
+
80
+ tag (:obj:`str`, optional): The release tag of DeBERTa, default: `None`.
81
+
82
+ no_cache (:obj:`bool`, optional): Disable local cache of downloaded models, default: `False`.
83
+
84
+ cache_dir (:obj:`str`, optional): The cache directory used to save the downloaded models, default: `None`. If it's `None`, then the models will be saved at `$HOME/.~DeBERTa`
85
+
86
+ Return:
87
+
88
+ :obj:`NNModule` : The sub-class object.
89
+
90
+ """
91
+ # Load config
92
+ if model_config:
93
+ config = ModelConfig.from_json_file(model_config)
94
+ else:
95
+ config = None
96
+ model_config = None
97
+ model_state = None
98
+ if (model_path is not None) and (model_path.strip() == '-' or model_path.strip()==''):
99
+ model_path = None
100
+ try:
101
+ model_state, model_config = load_model_state(model_path, tag=tag, no_cache=no_cache, cache_dir=cache_dir)
102
+ except Exception as exp:
103
+ raise Exception(f'Failed to get model {model_path}. Exception: {exp}')
104
+
105
+ if config is not None and model_config is not None:
106
+ for k in config.__dict__:
107
+ if k not in ['hidden_size',
108
+ 'intermediate_size',
109
+ 'num_attention_heads',
110
+ 'num_hidden_layers',
111
+ 'vocab_size',
112
+ 'max_position_embeddings'] or (k not in model_config.__dict__) or (model_config.__dict__[k] < 0):
113
+ model_config.__dict__[k] = config.__dict__[k]
114
+ if model_config is not None:
115
+ config = copy.copy(model_config)
116
+ vocab_size = config.vocab_size
117
+ # Instantiate model.
118
+ model = cls(config, *inputs, **kwargs)
119
+ if not model_state:
120
+ return model
121
+ # copy state_dict so _load_from_state_dict can modify it
122
+ state_dict = model_state.copy()
123
+
124
+ missing_keys = []
125
+ unexpected_keys = []
126
+ error_msgs = []
127
+ metadata = getattr(state_dict, '_metadata', None)
128
+ def load(module, prefix=''):
129
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
130
+ module._load_from_state_dict(
131
+ state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
132
+ for name, child in module._modules.items():
133
+ if child is not None:
134
+ load(child, prefix + name + '.')
135
+ load(model)
136
+ logger.warning(f'Missing keys: {missing_keys}, unexpected_keys: {unexpected_keys}, error_msgs: {error_msgs}')
137
+ return model
nlu/DeBERTa/deberta/ops.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft, Inc. 2020
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ #
6
+ # Author: penhe@microsoft.com
7
+ # Date: 01/15/2020
8
+ #
9
+
10
+ import pdb
11
+ import math
12
+ from packaging import version
13
+ import torch
14
+ from torch.nn import LayerNorm
15
+ from ..utils.jit_tracing import traceable
16
+
17
+ if version.Version(torch.__version__) >= version.Version('1.0.0'):
18
+ from torch import _softmax_backward_data as _softmax_backward_data
19
+ else:
20
+ from torch import softmax_backward_data as _softmax_backward_data
21
+
22
+ __all__ = ['StableDropout', 'MaskedLayerNorm', 'XSoftmax', 'ACT2FN', 'LayerNorm']
23
+
24
+ @traceable
25
+ class XSoftmax(torch.autograd.Function):
26
+ """ Masked Softmax which is optimized for saving memory
27
+
28
+ Args:
29
+
30
+ input (:obj:`torch.tensor`): The input tensor that will apply softmax.
31
+ mask (:obj:`torch.IntTensor`): The mask matrix where 0 indicate that element will be ignored in the softmax caculation.
32
+ dim (int): The dimenssion that will apply softmax.
33
+
34
+ Example::
35
+
36
+ import torch
37
+ from DeBERTa.deberta import XSoftmax
38
+ # Make a tensor
39
+ x = torch.randn([4,20,100])
40
+ # Create a mask
41
+ mask = (x>0).int()
42
+ y = XSoftmax.apply(x, mask, dim=-1)
43
+
44
+ """
45
+
46
+ @staticmethod
47
+ def forward(self, input, mask, dim):
48
+ """
49
+ """
50
+
51
+ self.dim = dim
52
+ if version.Version(torch.__version__) >= version.Version('1.2.0a'):
53
+ rmask = ~(mask.bool())
54
+ else:
55
+ rmask = (1-mask).byte() # This line is not supported by Onnx tracing.
56
+
57
+ output = input.masked_fill(rmask, float('-inf'))
58
+ output = torch.softmax(output, self.dim)
59
+ output.masked_fill_(rmask, 0)
60
+ self.save_for_backward(output)
61
+ return output
62
+
63
+ @staticmethod
64
+ def backward(self, grad_output):
65
+ """
66
+ """
67
+
68
+ output, = self.saved_tensors
69
+ if version.Version(torch.__version__) >= version.Version('1.11.0a'):
70
+ inputGrad = _softmax_backward_data(grad_output, output, self.dim, output.dtype)
71
+ else:
72
+ inputGrad = _softmax_backward_data(grad_output, output, self.dim, output)
73
+ return inputGrad, None, None
74
+
75
+ @staticmethod
76
+ def symbolic(g, self, mask, dim):
77
+ import torch.onnx.symbolic_helper as sym_help
78
+ from torch.onnx.symbolic_opset9 import masked_fill, softmax
79
+
80
+ mask_cast_value = g.op("Cast", mask, to_i=sym_help.cast_pytorch_to_onnx['Long'])
81
+ r_mask = g.op("Cast", g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value), to_i=sym_help.cast_pytorch_to_onnx['Byte'])
82
+ output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(float('-inf'))))
83
+ output = softmax(g, output, dim)
84
+ return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
85
+
86
+ class DropoutContext(object):
87
+ def __init__(self):
88
+ self.dropout = 0
89
+ self.mask = None
90
+ self.scale = 1
91
+ self.reuse_mask = True
92
+
93
+ def get_mask(input, local_context):
94
+ if not isinstance(local_context, DropoutContext):
95
+ dropout = local_context
96
+ mask = None
97
+ else:
98
+ dropout = local_context.dropout
99
+ dropout *= local_context.scale
100
+ mask = local_context.mask if local_context.reuse_mask else None
101
+
102
+ if dropout>0 and mask is None:
103
+ if version.Version(torch.__version__) >= version.Version('1.2.0a'):
104
+ mask=(1-torch.empty_like(input).bernoulli_(1-dropout)).bool()
105
+ else:
106
+ mask=(1-torch.empty_like(input).bernoulli_(1-dropout)).byte()
107
+
108
+ if isinstance(local_context, DropoutContext):
109
+ if local_context.mask is None:
110
+ local_context.mask = mask
111
+
112
+ return mask, dropout
113
+
114
+ @traceable
115
+ class XDropout(torch.autograd.Function):
116
+ @staticmethod
117
+ def forward(ctx, input, local_ctx):
118
+ mask, dropout = get_mask(input, local_ctx)
119
+ ctx.scale=1.0/(1-dropout)
120
+ if dropout>0:
121
+ ctx.save_for_backward(mask)
122
+ return input.masked_fill(mask, 0)*ctx.scale
123
+ else:
124
+ return input
125
+
126
+ @staticmethod
127
+ def backward(ctx, grad_output):
128
+ if ctx.scale > 1:
129
+ mask, = ctx.saved_tensors
130
+ return grad_output.masked_fill(mask, 0)*ctx.scale, None
131
+ else:
132
+ return grad_output, None
133
+
134
+ class StableDropout(torch.nn.Module):
135
+ """ Optimized dropout module for stabilizing the training
136
+
137
+ Args:
138
+
139
+ drop_prob (float): the dropout probabilities
140
+
141
+ """
142
+
143
+ def __init__(self, drop_prob):
144
+ super().__init__()
145
+ self.drop_prob = drop_prob
146
+ self.count = 0
147
+ self.context_stack = None
148
+
149
+ def forward(self, x):
150
+ """ Call the module
151
+
152
+ Args:
153
+
154
+ x (:obj:`torch.tensor`): The input tensor to apply dropout
155
+
156
+
157
+ """
158
+ if self.training and self.drop_prob>0:
159
+ return XDropout.apply(x, self.get_context())
160
+ return x
161
+
162
+ def clear_context(self):
163
+ self.count = 0
164
+ self.context_stack = None
165
+
166
+ def init_context(self, reuse_mask=True, scale = 1):
167
+ if self.context_stack is None:
168
+ self.context_stack = []
169
+ self.count = 0
170
+ for c in self.context_stack:
171
+ c.reuse_mask = reuse_mask
172
+ c.scale = scale
173
+
174
+ def get_context(self):
175
+ if self.context_stack is not None:
176
+ if self.count >= len(self.context_stack):
177
+ self.context_stack.append(DropoutContext())
178
+ ctx = self.context_stack[self.count]
179
+ ctx.dropout = self.drop_prob
180
+ self.count += 1
181
+ return ctx
182
+ else:
183
+ return self.drop_prob
184
+
185
+ def MaskedLayerNorm(layerNorm, input, mask = None):
186
+ """ Masked LayerNorm which will apply mask over the output of LayerNorm to avoid inaccurate updatings to the LayerNorm module.
187
+
188
+ Args:
189
+ layernorm (:obj:`~DeBERTa.deberta.LayerNorm`): LayerNorm module or function
190
+ input (:obj:`torch.tensor`): The input tensor
191
+ mask (:obj:`torch.IntTensor`): The mask to applied on the output of LayerNorm where `0` indicate the output of that element will be ignored, i.e. set to `0`
192
+
193
+ Example::
194
+
195
+ # Create a tensor b x n x d
196
+ x = torch.randn([1,10,100])
197
+ m = torch.tensor([[1,1,1,0,0,0,0,0,0,0]], dtype=torch.int)
198
+ LayerNorm = DeBERTa.deberta.LayerNorm(100)
199
+ y = MaskedLayerNorm(LayerNorm, x, m)
200
+
201
+ """
202
+ output = layerNorm(input).to(input)
203
+ if mask is None:
204
+ return output
205
+ if mask.dim()!=input.dim():
206
+ if mask.dim()==4:
207
+ mask=mask.squeeze(1).squeeze(1)
208
+ mask = mask.unsqueeze(2)
209
+ mask = mask.to(output.dtype)
210
+ return output*mask
211
+
212
+ def gelu(x):
213
+ """Implementation of the gelu activation function.
214
+ For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
215
+ 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
216
+ """
217
+ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
218
+
219
+
220
+ def swish(x):
221
+ return x * torch.sigmoid(x)
222
+
223
+ def linear_act(x):
224
+ return x
225
+
226
+ ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish, "tanh": torch.tanh, "linear": linear_act, 'sigmoid': torch.sigmoid}
227
+
228
+
nlu/DeBERTa/deberta/pooling.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Author: penhe@microsoft.com
3
+ # Date: 01/25/2019
4
+ #
5
+ """
6
+ Pooling functions
7
+ """
8
+
9
+ from torch import nn
10
+ import copy
11
+ import json
12
+ import pdb
13
+ from .bert import ACT2FN
14
+ from .ops import StableDropout
15
+ from .config import AbsModelConfig
16
+
17
+ __all__ = ['PoolConfig', 'ContextPooler']
18
+
19
+ class PoolConfig(AbsModelConfig):
20
+ """Configuration class to store the configuration of `pool layer`.
21
+
22
+ Parameters:
23
+
24
+ config (:class:`~DeBERTa.deberta.ModelConfig`): The model config. The field of pool config will be initalized with the `pooling` field in model config.
25
+
26
+ Attributes:
27
+
28
+ hidden_size (int): Size of the encoder layers and the pooler layer, default: `768`.
29
+
30
+ dropout (float): The dropout rate applied on the output of `[CLS]` token,
31
+
32
+ hidden_act (:obj:`str`): The activation function of the projection layer, it can be one of ['gelu', 'tanh'].
33
+
34
+ Example::
35
+
36
+ # Here is the content of an exmple model config file in json format
37
+
38
+ {
39
+ "hidden_size": 768,
40
+ "num_hidden_layers" 12,
41
+ "num_attention_heads": 12,
42
+ "intermediate_size": 3072,
43
+ ...
44
+ "pooling": {
45
+ "hidden_size": 768,
46
+ "hidden_act": "gelu",
47
+ "dropout": 0.1
48
+ }
49
+ }
50
+
51
+ """
52
+ def __init__(self, config=None):
53
+ """Constructs PoolConfig.
54
+
55
+ Args:
56
+ `config`: the config of the model. The field of pool config will be initalized with the 'pooling' field in model config.
57
+ """
58
+
59
+ self.hidden_size = 768
60
+ self.dropout = 0
61
+ self.hidden_act = 'gelu'
62
+ if config:
63
+ pool_config = getattr(config, 'pooling', config)
64
+ if isinstance(pool_config, dict):
65
+ pool_config = AbsModelConfig.from_dict(pool_config)
66
+ self.hidden_size = getattr(pool_config, 'hidden_size', config.hidden_size)
67
+ self.dropout = getattr(pool_config, 'dropout', 0)
68
+ self.hidden_act = getattr(pool_config, 'hidden_act', 'gelu')
69
+
70
+ class ContextPooler(nn.Module):
71
+ def __init__(self, config):
72
+ super().__init__()
73
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
74
+ self.dropout = StableDropout(config.dropout)
75
+ self.config = config
76
+
77
+ def forward(self, hidden_states, mask = None):
78
+ # We "pool" the model by simply taking the hidden state corresponding
79
+ # to the first token.
80
+
81
+ context_token = hidden_states[:, 0]
82
+ context_token = self.dropout(context_token)
83
+ pooled_output = self.dense(context_token)
84
+ pooled_output = ACT2FN[self.config.hidden_act](pooled_output)
85
+ return pooled_output
86
+
87
+ def output_dim(self):
88
+ return self.config.hidden_size
nlu/DeBERTa/deberta/pretrained_models.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+
2
+
nlu/DeBERTa/deberta/spm_tokenizer.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft, Inc. 2020
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ #
6
+ # Author: penhe@microsoft.com
7
+ # Date: 11/15/2020
8
+ #
9
+
10
+
11
+ import sentencepiece as sp
12
+ import six
13
+ import unicodedata
14
+ import os
15
+ import regex as re
16
+ from .cache_utils import load_vocab
17
+ from ..utils import get_logger
18
+ logger=get_logger()
19
+
20
+
21
+ import pdb
22
+
23
+ __all__ = ['SPMTokenizer']
24
+
25
+ class SPMTokenizer:
26
+ def __init__(self, vocab_file, do_lower_case=False, special_tokens=None, bpe_dropout=0, split_by_punct=False):
27
+ self.split_by_punct = split_by_punct
28
+ spm = sp.SentencePieceProcessor()
29
+ assert os.path.exists(vocab_file)
30
+ spm.load(vocab_file)
31
+ bpe_vocab_size = spm.GetPieceSize()
32
+ # Token map
33
+ # <unk> 0+1
34
+ # <s> 1+1
35
+ # </s> 2+1
36
+ self.vocab = {spm.IdToPiece(i):i for i in range(bpe_vocab_size)}
37
+ self.id_to_tokens = [spm.IdToPiece(i) for i in range(bpe_vocab_size)]
38
+ #self.vocab['[PAD]'] = 0
39
+ #self.vocab['[CLS]'] = 1
40
+ #self.vocab['[SEP]'] = 2
41
+ #self.vocab['[UNK]'] = 3
42
+
43
+ _special_tokens = ['[MASK]', '[SEP]', '[PAD]', '[UNK]', '[CLS]']
44
+ self.special_tokens = []
45
+ if special_tokens is not None:
46
+ _special_tokens.extend(special_tokens)
47
+ for t in _special_tokens:
48
+ self.add_special_token(t)
49
+
50
+ self.spm = spm
51
+ self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
52
+
53
+ def tokenize(self, text):
54
+ pieces = self._encode_as_pieces(text)
55
+ def _norm(x):
56
+ if x not in self.vocab or x=='<unk>':
57
+ return '[UNK]'
58
+ else:
59
+ return x
60
+ pieces = [_norm(p) for p in pieces]
61
+ return pieces
62
+
63
+ def convert_tokens_to_ids(self, tokens):
64
+ return [self.vocab[t] if t in self.vocab else 1 for t in tokens]
65
+
66
+ def convert_ids_to_tokens(self, ids):
67
+ tokens = []
68
+ for i in ids:
69
+ tokens.append(self.ids_to_tokens[i])
70
+ return tokens
71
+
72
+ def decode(self, tokens, start=-1, end=-1, raw_text=None):
73
+ if raw_text is None:
74
+ return self.spm.decode_pieces([t for t in tokens if t not in self.special_tokens])
75
+ else:
76
+ words = self.split_to_words(raw_text)
77
+ word_tokens = [self.tokenize(w) for w in words]
78
+ wt = [w for t in word_tokens for w in t]
79
+ #assert tokens == wt, f'{tokens} || {wt}'
80
+ if wt!=tokens:
81
+ for a,b in zip(wt, tokens):
82
+ if a!=b:
83
+ pdb.set_trace()
84
+ token2words = [0]*len(tokens)
85
+ tid = 0
86
+ for i,w in enumerate(word_tokens):
87
+ for k,t in enumerate(w):
88
+ token2words[tid] = i
89
+ tid += 1
90
+ word_start = token2words[start]
91
+ word_end = token2words[end] if end <len(tokens) else len(words)
92
+ text = ''.join(words[word_start:word_end])
93
+ return text
94
+
95
+ def add_special_token(self, token):
96
+ if token not in self.special_tokens:
97
+ self.special_tokens.append(token)
98
+ if token not in self.vocab:
99
+ self.vocab[token] = len(self.vocab)
100
+ self.id_to_tokens.append(token)
101
+ return self.id(token)
102
+
103
+ def part_of_whole_word(self, token, is_bos=False):
104
+ if is_bos:
105
+ return True
106
+ if (len(token)==1 and (_is_whitespace(list(token)[0]) or _is_control(list(token)[0]) or _is_punctuation(list(token)[0]))) or token in self.special_tokens:
107
+ return False
108
+
109
+ word_start = b'\xe2\x96\x81'.decode('utf-8')
110
+ return not token.startswith(word_start)
111
+
112
+ def pad(self):
113
+ return '[PAD]'
114
+
115
+ def bos(self):
116
+ return '[CLS]'
117
+
118
+ def eos(self):
119
+ return '[SEP]'
120
+
121
+ def unk(self):
122
+ return '[UNK]'
123
+
124
+ def mask(self):
125
+ return '[MASK]'
126
+
127
+ def sym(self, id):
128
+ return self.ids_to_tokens[id]
129
+
130
+ def id(self, sym):
131
+ return self.vocab[sym] if sym in self.vocab else 1
132
+
133
+ def _encode_as_pieces(self, text):
134
+ text = convert_to_unicode(text)
135
+ if self.split_by_punct:
136
+ words = self._run_split_on_punc(text)
137
+ pieces = [self.spm.encode_as_pieces(w) for w in words]
138
+ return [p for w in pieces for p in w]
139
+ else:
140
+ return self.spm.encode_as_pieces(text)
141
+
142
+ def split_to_words(self, text):
143
+ pieces = self._encode_as_pieces(text)
144
+ word_start = b'\xe2\x96\x81'.decode('utf-8')
145
+ words = []
146
+ offset = 0
147
+ prev_end = 0
148
+ for i,p in enumerate(pieces):
149
+ if p.startswith(word_start):
150
+ if offset>prev_end:
151
+ words.append(text[prev_end:offset])
152
+ prev_end = offset
153
+ w = p.replace(word_start, '')
154
+ else:
155
+ w = p
156
+ try:
157
+ s = text.index(w, offset)
158
+ pn = ""
159
+ k = i+1
160
+ while k < len(pieces):
161
+ pn = pieces[k].replace(word_start, '')
162
+ if len(pn)>0:
163
+ break
164
+ k += 1
165
+
166
+ if len(pn)>0 and pn in text[offset:s]:
167
+ offset = offset + 1
168
+ else:
169
+ offset = s + len(w)
170
+ except:
171
+ offset = offset + 1
172
+
173
+ if prev_end< offset:
174
+ words.append(text[prev_end:offset])
175
+
176
+ return words
177
+
178
+ def _run_strip_accents(self, text):
179
+ """Strips accents from a piece of text."""
180
+ text = unicodedata.normalize("NFD", text)
181
+ output = []
182
+ for char in text:
183
+ cat = unicodedata.category(char)
184
+ if cat == "Mn":
185
+ continue
186
+ output.append(char)
187
+ return "".join(output)
188
+
189
+ def _run_split_on_punc(self, text):
190
+ """Splits punctuation on a piece of text."""
191
+ #words = list(re.findall(self.pat, text))
192
+ chars = list(text)
193
+ i = 0
194
+ start_new_word = True
195
+ output = []
196
+ while i < len(chars):
197
+ char = chars[i]
198
+ if _is_punctuation(char):
199
+ output.append([char])
200
+ start_new_word = True
201
+ else:
202
+ if start_new_word:
203
+ output.append([])
204
+ start_new_word = False
205
+ output[-1].append(char)
206
+ i += 1
207
+
208
+ return ["".join(x) for x in output]
209
+
210
+ def _tokenize_chinese_chars(self, text):
211
+ """Adds whitespace around any CJK character."""
212
+ output = []
213
+ for char in text:
214
+ cp = ord(char)
215
+ if self._is_chinese_char(cp):
216
+ output.append(" ")
217
+ output.append(char)
218
+ output.append(" ")
219
+ else:
220
+ output.append(char)
221
+ return "".join(output)
222
+
223
+ def _is_chinese_char(self, cp):
224
+ """Checks whether CP is the codepoint of a CJK character."""
225
+ # This defines a "chinese character" as anything in the CJK Unicode block:
226
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
227
+ #
228
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
229
+ # despite its name. The modern Korean Hangul alphabet is a different block,
230
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
231
+ # space-separated words, so they are not treated specially and handled
232
+ # like the all of the other languages.
233
+ if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
234
+ (cp >= 0x3400 and cp <= 0x4DBF) or #
235
+ (cp >= 0x20000 and cp <= 0x2A6DF) or #
236
+ (cp >= 0x2A700 and cp <= 0x2B73F) or #
237
+ (cp >= 0x2B740 and cp <= 0x2B81F) or #
238
+ (cp >= 0x2B820 and cp <= 0x2CEAF) or
239
+ (cp >= 0xF900 and cp <= 0xFAFF) or #
240
+ (cp >= 0x2F800 and cp <= 0x2FA1F)): #
241
+ return True
242
+
243
+ return False
244
+
245
+ def _clean_text(self, text):
246
+ """Performs invalid character removal and whitespace cleanup on text."""
247
+ output = []
248
+ for char in text:
249
+ cp = ord(char)
250
+ if cp == 0 or cp == 0xfffd or _is_control(char):
251
+ continue
252
+ if _is_whitespace(char):
253
+ output.append(" ")
254
+ else:
255
+ output.append(char)
256
+ return "".join(output)
257
+
258
+
259
+ def _is_whitespace(char):
260
+ """Checks whether `chars` is a whitespace character."""
261
+ # \t, \n, and \r are technically contorl characters but we treat them
262
+ # as whitespace since they are generally considered as such.
263
+ if char == " " or char == "\t" or char == "\n" or char == "\r":
264
+ return True
265
+ cat = unicodedata.category(char)
266
+ if cat == "Zs":
267
+ return True
268
+ return False
269
+
270
+ def _is_control(char):
271
+ """Checks whether `chars` is a control character."""
272
+ # These are technically control characters but we count them as whitespace
273
+ # characters.
274
+ if char == "\t" or char == "\n" or char == "\r":
275
+ return False
276
+ cat = unicodedata.category(char)
277
+ if cat.startswith("C"):
278
+ return True
279
+ return False
280
+
281
+ def _is_punctuation(char):
282
+ """Checks whether `chars` is a punctuation character."""
283
+ cp = ord(char)
284
+ # We treat all non-letter/number ASCII as punctuation.
285
+ # Characters such as "^", "$", and "`" are not in the Unicode
286
+ # Punctuation class but we treat them as punctuation anyways, for
287
+ # consistency.
288
+ if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
289
+ (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
290
+ return True
291
+ cat = unicodedata.category(char)
292
+ if cat.startswith("P"):
293
+ return True
294
+ return False
295
+
296
+ def whitespace_tokenize(text):
297
+ """Runs basic whitespace cleaning and splitting on a peice of text."""
298
+ text = text.strip()
299
+ if not text:
300
+ return []
301
+ tokens = text.split()
302
+ return tokens
303
+
304
+ def convert_to_unicode(text):
305
+ """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
306
+ if six.PY3:
307
+ if isinstance(text, str):
308
+ return text
309
+ elif isinstance(text, bytes):
310
+ return text.decode("utf-8", "ignore")
311
+ else:
312
+ raise ValueError("Unsupported string type: %s" % (type(text)))
313
+ elif six.PY2:
314
+ if isinstance(text, str):
315
+ return text.decode("utf-8", "ignore")
316
+ elif isinstance(text, unicode):
317
+ return text
318
+ else:
319
+ raise ValueError("Unsupported string type: %s" % (type(text)))
320
+ else:
321
+ raise ValueError("Not running on Python2 or Python 3?")
322
+
nlu/DeBERTa/deberta/tokenizers.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#
# Author: penhe@microsoft.com
# Date: 04/25/2019
#

""" tokenizers
"""

from .spm_tokenizer import *
from .gpt2_tokenizer import GPT2Tokenizer

# Only the registry dict is part of the declared public API; the tokenizer
# classes themselves are re-exported by the star import above.
__all__ = ['tokenizers']
# Registry mapping vocab-type key -> tokenizer class.
# NOTE(review): SPMTokenizer is presumably provided by spm_tokenizer's star
# import — confirm it is listed in that module's __all__.
tokenizers={
  'gpt2': GPT2Tokenizer,
  'spm': SPMTokenizer
  }
nlu/DeBERTa/optims/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Author: Pengcheng He (penhe@microsoft.com)
# Date: 05/15/2019
#

""" optimizers
"""

# Public surface of the optims package: the XAdam optimizer, the fp16
# optimizer wrappers (whatever fp16_optimizer exports via *), the LR
# schedule table and the argparse builder for optimizer hyper-parameters.
from .xadam import XAdam
from .fp16_optimizer import *
from .lr_schedulers import SCHEDULES
from .args import get_args
16
+
nlu/DeBERTa/optims/args.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Author: Pengcheng He (penhe@microsoft.com)
# Date: 05/15/2019
#

""" Arguments for optimizer
"""
import argparse
from ..utils import boolean_string

__all__ = ['get_args']
def get_args():
  """Return an ``argparse.ArgumentParser`` carrying the optimizer argument group.

  The parser is built with ``add_help=False`` so it can be used as a
  ``parents=`` parser by the application-level argument parser.
  """
  parser=argparse.ArgumentParser(add_help=False, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  group = parser.add_argument_group(title='Optimizer', description='Parameters for the distributed optimizer')
  group.add_argument('--fp16',
      default=False,
      type=boolean_string,
      help="Whether to use 16-bit float precision instead of 32-bit")

  group.add_argument('--loss_scale',
      type=float, default=16384,
      help='Loss scaling, positive power of 2 values can improve fp16 convergence.')

  group.add_argument('--scale_steps',
      type=int, default=250,
      help='The steps to wait to increase the loss scale.')

  group.add_argument('--lookahead_k',
      default=-1,
      type=int,
      help="lookahead k parameter")

  group.add_argument('--lookahead_alpha',
      default=0.5,
      type=float,
      help="lookahead alpha parameter")

  group.add_argument('--with_radam',
      default=False,
      type=boolean_string,
      help="whether to use RAdam")

  group.add_argument('--opt_type',
      type=str.lower,
      default='adam',
      choices=['adam', 'admax'],
      help="The optimizer to be used.")

  group.add_argument("--warmup_proportion",
      default=0.1,
      type=float,
      help="Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")

  group.add_argument("--lr_schedule_ends",
      default=0,
      type=float,
      help="The ended learning rate scale for learning rate scheduling")

  # BUG FIX: help text said "traning".
  group.add_argument("--lr_schedule",
      default='warmup_linear',
      type=str,
      help="The learning rate scheduler used for training. " +
        "E.g. warmup_linear, warmup_linear_shift, warmup_cosine, warmup_constant. Default, warmup_linear")

  group.add_argument("--max_grad_norm",
      default=1,
      type=float,
      help="The clip threshold of global gradient norm")

  group.add_argument("--learning_rate",
      default=5e-5,
      type=float,
      help="The initial learning rate for Adam.")

  group.add_argument("--epsilon",
      default=1e-6,
      type=float,
      help="epsilon setting for Adam.")

  group.add_argument("--adam_beta1",
      default=0.9,
      type=float,
      help="The beta1 parameter for Adam.")

  group.add_argument("--adam_beta2",
      default=0.999,
      type=float,
      help="The beta2 parameter for Adam.")

  group.add_argument('--weight_decay',
      type=float,
      default=0.01,
      help="The weight decay rate")

  return parser
100
+
nlu/DeBERTa/optims/fp16_optimizer.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # This source code is licensed under the MIT license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+ #
5
+ # Author: Pengcheng He (penhe@microsoft.com)
6
+ # Date: 05/15/2019
7
+ #
8
+
9
+ """ FP16 optimizer wrapper
10
+ """
11
+
12
+ from collections import defaultdict
13
+ import numpy as np
14
+ import math
15
+ import torch
16
+ import pdb
17
+ import torch.distributed as dist
18
+ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
19
+ import ctypes
20
+
21
+ from ..utils import get_logger,boolean_string
22
+ logger=get_logger()
23
+
24
+ __all__ = ['Fp16Optimizer', 'ExpLossScaler', 'get_world_size']
25
+
26
def get_world_size():
  """Return the distributed world size, or 1 in single-process runs.

  `dist.get_world_size()` raises when the default process group is not
  initialized; that case is mapped to a world size of 1.
  """
  try:
    return dist.get_world_size()
  except Exception:
    # BUG FIX: was a bare `except:`, which also swallowed
    # SystemExit/KeyboardInterrupt. `except Exception` keeps the
    # best-effort fallback without hiding interpreter-exit signals.
    return 1
32
+
33
def fused_norm(input):
  """Return the L2 norm of `input`, accumulated in float32 regardless of dtype."""
  return input.norm(p=2, dtype=torch.float32)
35
+
36
class OptParameter(torch.Tensor):
  """Tensor subclass used as the optimizer-facing "master" parameter.

  Bundles the master data with an optional `out_data` tensor (set by
  Fp16Optimizer to the flattened fp16 storage backing the model's real
  parameters) and a name, and overrides `grad` with a plain attribute so
  a flattened gradient tensor can be attached directly.
  """
  def __new__(cls, data, out_data=None, grad=None, name=None):
    # _make_subclass re-wraps `data`'s storage as an OptParameter without
    # copying it.
    param = torch.Tensor._make_subclass(cls, data)
    param._xgrad = grad
    param.out_data = out_data
    param._name = name
    return param

  @property
  def name(self):
    # Read-only identifier (e.g. 'master').
    return self._name

  @property
  def grad(self):
    # Shadows torch.Tensor.grad: gradients are managed explicitly by
    # Fp16Optimizer._grad_scale rather than by autograd.
    return self._xgrad

  @grad.setter
  def grad(self, grad):
    self._xgrad = grad
55
+
56
class Fp16Optimizer(object):
  """Mixed-precision wrapper around a base optimizer.

  Each incoming parameter group is flattened into one contiguous tensor;
  for fp16 groups an fp32 "master" copy is created and handed to the
  wrapped optimizer (as an OptParameter whose `out_data` is the fp16
  storage), while fp32 groups are passed through flattened. Also performs
  bucketed gradient all-reduce, global-norm clipping, dynamic loss-scale
  bookkeeping and optional "lookahead" slow/fast weight averaging.
  """

  def __init__(self, param_groups, optimizer_fn, loss_scaler=None, grad_clip_norm = 1.0, lookahead_k = -1, lookahead_alpha = 0.5, rank=-1, distributed=False):
    # all parameters should be on the same device
    groups = []
    original_groups = []
    self.rank = rank
    self.distributed = distributed
    if self.rank<0:
      # A negative rank means single-process mode regardless of the flag.
      self.distributed = False
    for group in param_groups:
      if 'offset' not in group:
        group['offset'] = None
      if ('rank' not in group) or (not self.distributed):
        group['rank'] = -1
        assert group['offset'] is None, f"{group['names']}: {group['offset']}"
      group_rank = group['rank']
      params = group['params'] # parameter
      if len(params) > 1:
        # Flatten the group into one tensor and re-point each parameter's
        # .data at the matching view, so model and optimizer share storage.
        flattened_params = _flatten_dense_tensors([p.data for p in params])
        unflattend_params = _unflatten_dense_tensors(flattened_params, [p.data for p in params])
        for uf,p in zip(unflattend_params, params):
          p.data = uf
      else:
        flattened_params = params[0].data.view(-1)
        if group['offset'] is not None:
          # (start, length): the shard of this tensor owned by this group.
          start, length = group['offset']
          flattened_params = flattened_params.narrow(0, start, length)

      if params[0].dtype==torch.half:
        if self.rank == group_rank or (not self.distributed):
          master_params = flattened_params.clone().to(torch.float).detach_().to(flattened_params.device)
        else:
          # Groups owned by another rank keep their fp32 master on CPU.
          master_params = flattened_params.clone().to(torch.float).detach_().cpu()
        group['params'] = [OptParameter(master_params, flattened_params, name='master')]
      else:
        group['params'] = [OptParameter(flattened_params, None, name='master')]

      # Keep the original (unflattened) group around for gradient
      # collection and zero_grad.
      o_group = defaultdict(list)
      o_group['names'] = group['names']
      o_group['params'] = params
      o_group['rank'] = group_rank
      o_group['offset'] = group['offset']

      group['names'] = ['master']

      original_groups.append(o_group)
      groups.append(group)
    self.param_groups = groups
    self.loss_scaler = loss_scaler
    self.optimizer = optimizer_fn(self.param_groups)
    self.original_param_groups = original_groups
    self.max_grad_norm = grad_clip_norm
    self.lookahead_k = lookahead_k
    self.lookahead_alpha = lookahead_alpha

  def backward(self, loss):
    """Scale `loss` (when a loss scaler is set) and run backward.

    Returns (loss_scale, step_loss) where step_loss is the unscaled
    python-float loss value.
    """
    if self.loss_scaler:
      loss_scale, loss, step_loss = self.loss_scaler.scale(loss)
    else:
      loss_scale = 1
      step_loss = loss.item()

    loss.backward()
    return loss_scale, step_loss

  def step(self, lr_scale, loss_scale = 1):
    """Reduce/clip gradients and apply one optimizer step.

    Returns False (and skips the update, backing off the loss scale)
    when the global gradient norm is non-finite.
    """
    grad_scale = self._grad_scale(loss_scale)
    if grad_scale is None or math.isinf(grad_scale):
      self.loss_scaler.update(False)
      return False

    if self.lookahead_k > 0:
      for p in self.param_groups:
        if 'la_count' not in p:
          # First use: snapshot the "slow" lookahead weights.
          p['la_count'] = 0
          p['slow_params'] = [x.data.detach().clone().requires_grad_(False) for x in p['params']]
    self.optimizer.step(grad_scale, lr_scale)

    # NOTE(review): the master->fp16 copy below appears to be handled
    # inside the wrapped optimizer's step(); confirm before re-enabling.
    # for group in self.param_groups:
    #   for p in group['params']:
    #     # p.data : master fp32
    #     # p.out_data : fp16 tensor backing model nn.Parameters
    #     if hasattr(p, 'out_data') and p.out_data is not None:
    #       p.out_data.copy_(p.data, non_blocking=True)

    if self.lookahead_k > 0:
      for p in self.param_groups:
        p['la_count'] += 1
        if p['la_count'] == self.lookahead_k:
          # Every k steps: slow = (1-alpha)*slow + alpha*fast; fast = slow.
          p['la_count'] = 0
          for s,f in zip(p['slow_params'], p['params']):
            s.mul_(1-self.lookahead_alpha)
            s.add_(f.data.detach()*self.lookahead_alpha)
            f.data.copy_(s, non_blocking=True)
            if hasattr(f, 'out_data') and f.out_data is not None:
              f.out_data.copy_(f.data, non_blocking=True)

    if self.loss_scaler:
      self.loss_scaler.update(True)
    return True

  def zero_grad(self):
    """Clear gradients on both the master and the original parameters."""
    for group, o_group in zip(self.param_groups, self.original_param_groups):
      for p in group['params']:
        p.grad = None
      for p in o_group['params']:
        p.grad = None

  def _grad_scale(self, loss_scale = 1):
    """Bucketed all-reduce of gradients, global-norm computation and
    attachment of flattened gradients to the master parameters.

    Returns the combined scale to divide gradients by (loss scale times
    any clipping factor), or None when the norm is non-finite.
    """
    named_params = {}
    named_grads = {}
    for g in self.original_param_groups:
      for n,p in zip(g['names'], g['params']):
        named_params[n] = p
        # Parameters that received no gradient contribute zeros.
        named_grads[n] = p.grad if p.grad is not None else torch.zeros_like(p.data)

    wd = get_world_size()
    def _reduce(group):
      # Flatten one bucket of gradients and kick off an async all-reduce.
      grads = [named_grads[n] for n in group]
      if len(grads)>1:
        flattened_grads = _flatten_dense_tensors(grads)
      else:
        # BUG FIX: was `grads[0],view(-1)` (comma instead of dot), which
        # made a tuple and raised NameError on `view` in the
        # single-tensor bucket path.
        flattened_grads = grads[0].view(-1)

      if wd > 1:
        # Pre-divide so the all-reduce sum yields the mean.
        flattened_grads /= wd
        handle = dist.all_reduce(flattened_grads, async_op=True)
      else:
        handle = None
      return flattened_grads, handle

    def _process_grad(group, flattened_grads, max_grad, norm):
      # Accumulate the squared norm and scatter the reduced bucket back
      # into named_grads.
      grads = [named_grads[n] for n in group]
      norm = norm.to(flattened_grads.device)
      norm = norm + fused_norm(flattened_grads)**2

      if len(grads) > 1:
        unflattend_grads = _unflatten_dense_tensors(flattened_grads, grads)
      else:
        unflattend_grads = [flattened_grads]

      for n,ug in zip(group, unflattend_grads):
        named_grads[n] = ug #.to(named_params[n].data)

      return max_grad, norm

    group_size = 0
    group = []
    max_size = 32*1024*1024  # elements per all-reduce bucket
    norm = torch.zeros(1, dtype=torch.float)
    max_grad = 0

    all_grads = []
    # Sorted iteration keeps bucket composition identical across ranks;
    # the key normalizes 'deberta.' prefixes to 'bert.' so both layouts
    # bucket in the same order.
    for name in sorted(named_params.keys(), key=lambda x:x.replace('deberta.', 'bert.')):
      group.append(name)
      group_size += named_params[name].data.numel()
      if group_size>=max_size:
        flatten, handle = _reduce(group)
        all_grads.append([handle, flatten, group])
        group = []
        group_size = 0
    if group_size>0:
      flatten, handle = _reduce(group)
      all_grads.append([handle, flatten, group])
      group = []
      group_size = 0
    for h,fg,group in all_grads:
      if h is not None:
        h.wait()
      max_grad, norm = _process_grad(group, fg, max_grad, norm)

    norm = norm**0.5
    if torch.isnan(norm) or torch.isinf(norm) :#in ['-inf', 'inf', 'nan']:
      return None

    # (removed unused local `scaled_norm = norm.item()/loss_scale`)
    grad_scale = loss_scale

    if self.max_grad_norm>0:
      # Grow the divisor when the unscaled norm exceeds the clip
      # threshold, i.e. clip by global norm.
      scale = norm/(loss_scale*self.max_grad_norm)
      if scale>1:
        grad_scale *= scale

    # Attach the flattened (optionally sharded) gradients to the master
    # parameter of each group this rank owns.
    for group, o_g in zip(self.param_groups, self.original_param_groups):
      grads = [named_grads[n] for n in o_g['names']]

      if len(grads) > 1:
        flattened_grads = _flatten_dense_tensors(grads)
      else:
        flattened_grads = grads[0].view(-1)
      if group['offset'] is not None:
        start, length = group['offset']
        flattened_grads = flattened_grads.narrow(0, start, length)
      if group['rank'] == self.rank or (not self.distributed):
        group['params'][0].grad = flattened_grads

    return grad_scale
255
+
256
class ExpLossScaler:
  """Dynamic (exponential) loss scaler for fp16 training.

  The scale is halved after an invalid (overflowed) step and doubled
  once `scale_interval` consecutive steps pass without an adjustment.
  """
  def __init__(self, init_scale=2**16, scale_interval=1000):
    # BUG FIX: scale() asserts with `self.init_scale`, which was never
    # stored, so a failing assertion raised AttributeError instead of
    # showing the intended message. Keep the initial scale around.
    self.init_scale = init_scale
    self.cur_scale = init_scale
    self.scale_interval = scale_interval
    self.invalid_cnt = 0
    self.last_scale = 0
    self.steps = 0
    self.down_scale_smooth = 0  # invalid steps tolerated before down-scaling

  def scale(self, loss):
    """Return (loss_scale, scaled_loss, unscaled_step_loss).

    A zero or non-finite loss is passed through with scale 1 so the
    overflow is detected downstream instead of being amplified.
    """
    assert self.cur_scale > 0, self.init_scale
    step_loss = loss.float().detach().item()
    if step_loss != 0 and math.isfinite(step_loss):
      loss_scale = self.cur_scale
    else:
      loss_scale = 1
    loss = loss.float()*loss_scale
    return (loss_scale, loss, step_loss)

  def update(self, is_valid = True):
    """Record the outcome of a step and adjust the scale.

    Invalid steps halve the scale (floored at 1); `scale_interval`
    quiet steps double it.
    """
    if not is_valid:
      self.invalid_cnt += 1
      if self.invalid_cnt>self.down_scale_smooth:
        self.cur_scale /= 2
        self.cur_scale = max(self.cur_scale, 1)
        self.last_scale = self.steps
    else:
      self.invalid_cnt = 0
      if self.steps - self.last_scale>self.scale_interval:
        self.cur_scale *= 2
        self.last_scale = self.steps
    self.steps += 1

  def state_dict(self):
    """Return the scaler state for checkpointing."""
    state = defaultdict(float)
    state['steps'] = self.steps
    state['invalid_cnt'] = self.invalid_cnt
    state['cur_scale'] = self.cur_scale
    state['last_scale'] = self.last_scale
    return state

  def load_state_dict(self, state):
    """Restore the scaler state saved by `state_dict`."""
    self.steps = state['steps']
    self.invalid_cnt = state['invalid_cnt']
    self.cur_scale = state['cur_scale']
    self.last_scale = state['last_scale']
+ self.last_scale = state['last_scale']