camenduru commited on
Commit
e828767
·
1 Parent(s): 4f94549

thanks to NVIDIA ❤

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50) hide show
  1. .gitattributes +34 -0
  2. .github/ISSUE_TEMPLATE/bug_report.md +23 -0
  3. .gitignore +147 -0
  4. .gitmodules +7 -0
  5. .nojekyll +0 -0
  6. LICENSE +11 -0
  7. README.md +182 -0
  8. apex.egg-info/PKG-INFO +9 -0
  9. apex.egg-info/SOURCES.txt +229 -0
  10. apex.egg-info/dependency_links.txt +1 -0
  11. apex.egg-info/requires.txt +1 -0
  12. apex.egg-info/top_level.txt +15 -0
  13. apex/RNN/README.md +3 -0
  14. apex/RNN/RNNBackend.py +365 -0
  15. apex/RNN/__init__.py +3 -0
  16. apex/RNN/cells.py +84 -0
  17. apex/RNN/models.py +56 -0
  18. apex/__init__.py +68 -0
  19. apex/_autocast_utils.py +26 -0
  20. apex/amp/README.md +72 -0
  21. apex/amp/__init__.py +5 -0
  22. apex/amp/__version__.py +2 -0
  23. apex/amp/_amp_state.py +59 -0
  24. apex/amp/_initialize.py +265 -0
  25. apex/amp/_process_optimizer.py +489 -0
  26. apex/amp/amp.py +183 -0
  27. apex/amp/compat.py +46 -0
  28. apex/amp/frontend.py +446 -0
  29. apex/amp/handle.py +281 -0
  30. apex/amp/lists/__init__.py +0 -0
  31. apex/amp/lists/functional_overrides.py +80 -0
  32. apex/amp/lists/tensor_overrides.py +63 -0
  33. apex/amp/lists/torch_overrides.py +115 -0
  34. apex/amp/opt.py +103 -0
  35. apex/amp/rnn_compat.py +53 -0
  36. apex/amp/scaler.py +217 -0
  37. apex/amp/utils.py +210 -0
  38. apex/amp/wrap.py +276 -0
  39. apex/contrib/__init__.py +0 -0
  40. apex/contrib/bottleneck/__init__.py +2 -0
  41. apex/contrib/bottleneck/bottleneck.py +749 -0
  42. apex/contrib/bottleneck/halo_exchangers.py +180 -0
  43. apex/contrib/bottleneck/test.py +71 -0
  44. apex/contrib/clip_grad/__init__.py +1 -0
  45. apex/contrib/clip_grad/clip_grad.py +128 -0
  46. apex/contrib/conv_bias_relu/__init__.py +2 -0
  47. apex/contrib/conv_bias_relu/conv_bias_relu.py +81 -0
  48. apex/contrib/csrc/bottleneck/bottleneck.cpp +0 -0
  49. apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp +1639 -0
  50. apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp +131 -0
.gitattributes CHANGED
@@ -32,3 +32,37 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ build/lib.linux-x86_64-3.10/amp_C.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
36
+ build/lib.linux-x86_64-3.10/apex_C.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
37
+ build/lib.linux-x86_64-3.10/distributed_adam_cuda.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
38
+ build/lib.linux-x86_64-3.10/fast_layer_norm.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
39
+ build/lib.linux-x86_64-3.10/fused_adam_cuda.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
40
+ build/lib.linux-x86_64-3.10/fused_dense_cuda.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
41
+ build/lib.linux-x86_64-3.10/fused_layer_norm_cuda.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
42
+ build/lib.linux-x86_64-3.10/fused_weight_gradient_mlp_cuda.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
43
+ build/lib.linux-x86_64-3.10/generic_scaled_masked_softmax_cuda.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
44
+ build/lib.linux-x86_64-3.10/mlp_cuda.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
45
+ build/lib.linux-x86_64-3.10/scaled_masked_softmax_cuda.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
46
+ build/lib.linux-x86_64-3.10/scaled_softmax_cuda.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
47
+ build/lib.linux-x86_64-3.10/scaled_upper_triang_masked_softmax_cuda.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
48
+ build/lib.linux-x86_64-3.10/syncbn.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
49
+ build/temp.linux-x86_64-3.10/apex/contrib/csrc/layer_norm/ln_api.o filter=lfs diff=lfs merge=lfs -text
50
+ build/temp.linux-x86_64-3.10/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.o filter=lfs diff=lfs merge=lfs -text
51
+ build/temp.linux-x86_64-3.10/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.o filter=lfs diff=lfs merge=lfs -text
52
+ build/temp.linux-x86_64-3.10/apex/contrib/csrc/optimizers/fused_adam_cuda.o filter=lfs diff=lfs merge=lfs -text
53
+ build/temp.linux-x86_64-3.10/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.o filter=lfs diff=lfs merge=lfs -text
54
+ build/temp.linux-x86_64-3.10/csrc/amp_C_frontend.o filter=lfs diff=lfs merge=lfs -text
55
+ build/temp.linux-x86_64-3.10/csrc/flatten_unflatten.o filter=lfs diff=lfs merge=lfs -text
56
+ build/temp.linux-x86_64-3.10/csrc/fused_dense.o filter=lfs diff=lfs merge=lfs -text
57
+ build/temp.linux-x86_64-3.10/csrc/layer_norm_cuda.o filter=lfs diff=lfs merge=lfs -text
58
+ build/temp.linux-x86_64-3.10/csrc/layer_norm_cuda_kernel.o filter=lfs diff=lfs merge=lfs -text
59
+ build/temp.linux-x86_64-3.10/csrc/megatron/fused_weight_gradient_dense.o filter=lfs diff=lfs merge=lfs -text
60
+ build/temp.linux-x86_64-3.10/csrc/megatron/generic_scaled_masked_softmax.o filter=lfs diff=lfs merge=lfs -text
61
+ build/temp.linux-x86_64-3.10/csrc/megatron/scaled_masked_softmax.o filter=lfs diff=lfs merge=lfs -text
62
+ build/temp.linux-x86_64-3.10/csrc/megatron/scaled_masked_softmax_cuda.o filter=lfs diff=lfs merge=lfs -text
63
+ build/temp.linux-x86_64-3.10/csrc/megatron/scaled_softmax.o filter=lfs diff=lfs merge=lfs -text
64
+ build/temp.linux-x86_64-3.10/csrc/megatron/scaled_softmax_cuda.o filter=lfs diff=lfs merge=lfs -text
65
+ build/temp.linux-x86_64-3.10/csrc/megatron/scaled_upper_triang_masked_softmax.o filter=lfs diff=lfs merge=lfs -text
66
+ build/temp.linux-x86_64-3.10/csrc/megatron/scaled_upper_triang_masked_softmax_cuda.o filter=lfs diff=lfs merge=lfs -text
67
+ build/temp.linux-x86_64-3.10/csrc/mlp.o filter=lfs diff=lfs merge=lfs -text
68
+ build/temp.linux-x86_64-3.10/csrc/syncbn.o filter=lfs diff=lfs merge=lfs -text
.github/ISSUE_TEMPLATE/bug_report.md ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: Bug report
3
+ about: Create a report to help us improve apex
4
+ title: ''
5
+ labels: bug
6
+ assignees: ''
7
+
8
+ ---
9
+
10
+ **Describe the Bug**
11
+
12
+ **Minimal Steps/Code to Reproduce the Bug**
13
+ <!--
14
+ Please list the *minimal* steps or provide a code snippet for us to be able to reproduce the bug.
15
+
16
+ A helpful guide on how to craft a minimal bug report http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports.
17
+ -->
18
+
19
+ **Expected Behavior**
20
+ <!-- A clear and concise description of what you expected to happen. -->
21
+
22
+ **Environment**
23
+ <!-- OS, version of Python, CUDA, PyTorch; collect these via `python -m torch.utils.collect_env` -->
.gitignore ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ apex.egg-info
2
+ dist
3
+ build
4
+ docs/build
5
+ *~
6
+ __pycache__
7
+ .vscode
8
+
9
+ # Copied from https://raw.githubusercontent.com/github/gitignore/master/Python.gitignore
10
+ # Byte-compiled / optimized / DLL files
11
+ __pycache__/
12
+ *.py[cod]
13
+ *$py.class
14
+
15
+ # C extensions
16
+ *.so
17
+
18
+ # Distribution / packaging
19
+ .Python
20
+ build/
21
+ develop-eggs/
22
+ dist/
23
+ downloads/
24
+ eggs/
25
+ .eggs/
26
+ lib/
27
+ lib64/
28
+ parts/
29
+ sdist/
30
+ var/
31
+ wheels/
32
+ share/python-wheels/
33
+ *.egg-info/
34
+ .installed.cfg
35
+ *.egg
36
+ MANIFEST
37
+
38
+ # PyInstaller
39
+ # Usually these files are written by a python script from a template
40
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
41
+ *.manifest
42
+ *.spec
43
+
44
+ # Installer logs
45
+ pip-log.txt
46
+ pip-delete-this-directory.txt
47
+
48
+ # Unit test / coverage reports
49
+ htmlcov/
50
+ .tox/
51
+ .nox/
52
+ .coverage
53
+ .coverage.*
54
+ .cache
55
+ nosetests.xml
56
+ coverage.xml
57
+ *.cover
58
+ *.py,cover
59
+ .hypothesis/
60
+ .pytest_cache/
61
+ cover/
62
+
63
+ # Translations
64
+ *.mo
65
+ *.pot
66
+
67
+ # Django stuff:
68
+ *.log
69
+ local_settings.py
70
+ db.sqlite3
71
+ db.sqlite3-journal
72
+
73
+ # Flask stuff:
74
+ instance/
75
+ .webassets-cache
76
+
77
+ # Scrapy stuff:
78
+ .scrapy
79
+
80
+ # Sphinx documentation
81
+ docs/_build/
82
+
83
+ # PyBuilder
84
+ .pybuilder/
85
+ target/
86
+
87
+ # Jupyter Notebook
88
+ .ipynb_checkpoints
89
+
90
+ # IPython
91
+ profile_default/
92
+ ipython_config.py
93
+
94
+ # pyenv
95
+ # For a library or package, you might want to ignore these files since the code is
96
+ # intended to run in multiple environments; otherwise, check them in:
97
+ # .python-version
98
+
99
+ # pipenv
100
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
101
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
102
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
103
+ # install all needed dependencies.
104
+ #Pipfile.lock
105
+
106
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
107
+ __pypackages__/
108
+
109
+ # Celery stuff
110
+ celerybeat-schedule
111
+ celerybeat.pid
112
+
113
+ # SageMath parsed files
114
+ *.sage.py
115
+
116
+ # Environments
117
+ .env
118
+ .venv
119
+ env/
120
+ venv/
121
+ ENV/
122
+ env.bak/
123
+ venv.bak/
124
+
125
+ # Spyder project settings
126
+ .spyderproject
127
+ .spyproject
128
+
129
+ # Rope project settings
130
+ .ropeproject
131
+
132
+ # mkdocs documentation
133
+ /site
134
+
135
+ # mypy
136
+ .mypy_cache/
137
+ .dmypy.json
138
+ dmypy.json
139
+
140
+ # Pyre type checker
141
+ .pyre/
142
+
143
+ # pytype static type analyzer
144
+ .pytype/
145
+
146
+ # Cython debug symbols
147
+ cython_debug/
.gitmodules ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ [submodule "apex/contrib/csrc/multihead_attn/cutlass"]
2
+ path = apex/contrib/csrc/multihead_attn/cutlass
3
+ url = https://github.com/NVIDIA/cutlass.git
4
+ branch = v1.2.0
5
+ [submodule "apex/contrib/csrc/cudnn-frontend"]
6
+ path = apex/contrib/csrc/cudnn-frontend
7
+ url = https://github.com/NVIDIA/cudnn-frontend.git
.nojekyll ADDED
File without changes
LICENSE ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ All rights reserved.
2
+
3
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
4
+
5
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
6
+
7
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
8
+
9
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
10
+
11
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
README.md ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Introduction
2
+
3
+ This repository holds NVIDIA-maintained utilities to streamline mixed precision and distributed training in Pytorch.
4
+ Some of the code here will be included in upstream Pytorch eventually.
5
+ The intent of Apex is to make up-to-date utilities available to users as quickly as possible.
6
+
7
+ ## Full API Documentation: [https://nvidia.github.io/apex](https://nvidia.github.io/apex)
8
+
9
+ ## [GTC 2019](https://github.com/mcarilli/mixed_precision_references/tree/master/GTC_2019) and [Pytorch DevCon 2019](https://github.com/mcarilli/mixed_precision_references/tree/master/Pytorch_Devcon_2019) Slides
10
+
11
+ # Contents
12
+
13
+ ## 1. Amp: Automatic Mixed Precision
14
+
15
+ **Deprecated. Use [PyTorch AMP](https://pytorch.org/docs/stable/amp.html)**
16
+
17
+ `apex.amp` is a tool to enable mixed precision training by changing only 3 lines of your script.
18
+ Users can easily experiment with different pure and mixed precision training modes by supplying
19
+ different flags to `amp.initialize`.
20
+
21
+ [Webinar introducing Amp](https://info.nvidia.com/webinar-mixed-precision-with-pytorch-reg-page.html)
22
+ (The flag `cast_batchnorm` has been renamed to `keep_batchnorm_fp32`).
23
+
24
+ [API Documentation](https://nvidia.github.io/apex/amp.html)
25
+
26
+ [Comprehensive Imagenet example](https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
27
+
28
+ [DCGAN example coming soon...](https://github.com/NVIDIA/apex/tree/master/examples/dcgan)
29
+
30
+ [Moving to the new Amp API](https://nvidia.github.io/apex/amp.html#transition-guide-for-old-api-users) (for users of the deprecated "Amp" and "FP16_Optimizer" APIs)
31
+
32
+ ## 2. Distributed Training
33
+
34
+ **`apex.parallel.DistributedDataParallel` is deprecated. Use [`torch.nn.parallel.DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html?highlight=distributeddataparallel#torch.nn.parallel.DistributedDataParallel)**
35
+
36
+ `apex.parallel.DistributedDataParallel` is a module wrapper, similar to
37
+ `torch.nn.parallel.DistributedDataParallel`. It enables convenient multiprocess distributed training,
38
+ optimized for NVIDIA's NCCL communication library.
39
+
40
+ [API Documentation](https://nvidia.github.io/apex/parallel.html)
41
+
42
+ [Python Source](https://github.com/NVIDIA/apex/tree/master/apex/parallel)
43
+
44
+ [Example/Walkthrough](https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed)
45
+
46
+ The [Imagenet example](https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
47
+ shows use of `apex.parallel.DistributedDataParallel` along with `apex.amp`.
48
+
49
+ ### Synchronized Batch Normalization
50
+
51
+ **Deprecated. Use [`torch.nn.SyncBatchNorm`](https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html)**
52
+
53
+ `apex.parallel.SyncBatchNorm` extends `torch.nn.modules.batchnorm._BatchNorm` to
54
+ support synchronized BN.
55
+ It allreduces stats across processes during multiprocess (DistributedDataParallel) training.
56
+ Synchronous BN has been used in cases where only a small
57
+ local minibatch can fit on each GPU.
58
+ Allreduced stats increase the effective batch size for the BN layer to the
59
+ global batch size across all processes (which, technically, is the correct
60
+ formulation).
61
+ Synchronous BN has been observed to improve converged accuracy in some of our research models.
62
+
63
+ ### Checkpointing
64
+
65
+ To properly save and load your `amp` training, we introduce the `amp.state_dict()`, which contains all `loss_scalers` and their corresponding unskipped steps,
66
+ as well as `amp.load_state_dict()` to restore these attributes.
67
+
68
+ In order to get bitwise accuracy, we recommend the following workflow:
69
+ ```python
70
+ # Initialization
71
+ opt_level = 'O1'
72
+ model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
73
+
74
+ # Train your model
75
+ ...
76
+ with amp.scale_loss(loss, optimizer) as scaled_loss:
77
+ scaled_loss.backward()
78
+ ...
79
+
80
+ # Save checkpoint
81
+ checkpoint = {
82
+ 'model': model.state_dict(),
83
+ 'optimizer': optimizer.state_dict(),
84
+ 'amp': amp.state_dict()
85
+ }
86
+ torch.save(checkpoint, 'amp_checkpoint.pt')
87
+ ...
88
+
89
+ # Restore
90
+ model = ...
91
+ optimizer = ...
92
+ checkpoint = torch.load('amp_checkpoint.pt')
93
+
94
+ model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
95
+ model.load_state_dict(checkpoint['model'])
96
+ optimizer.load_state_dict(checkpoint['optimizer'])
97
+ amp.load_state_dict(checkpoint['amp'])
98
+
99
+ # Continue training
100
+ ...
101
+ ```
102
+
103
+ Note that we recommend restoring the model using the same `opt_level`. Also note that we recommend calling the `load_state_dict` methods after `amp.initialize`.
104
+
105
+ # Installation
106
+ Each [`apex.contrib`](./apex/contrib) module requires one or more install options other than `--cpp_ext` and `--cuda_ext`.
107
+ Note that contrib modules do not necessarily support stable PyTorch releases.
108
+
109
+ ## Containers
110
+ NVIDIA PyTorch Containers are available on NGC: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch.
111
+ The containers come with all the custom extensions available at the moment.
112
+
113
+ See [the NGC documentation](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html) for details such as:
114
+ - how to pull a container
115
+ - how to run a pulled container
116
+ - release notes
117
+
118
+ ## From Source
119
+
120
+ To install Apex from source, we recommend using the nightly Pytorch obtainable from https://github.com/pytorch/pytorch.
121
+
122
+ The latest stable release obtainable from https://pytorch.org should also work.
123
+
124
+ ### Linux
125
+ For performance and full functionality, we recommend installing Apex with
126
+ CUDA and C++ extensions via
127
+ ```bash
128
+ git clone https://github.com/NVIDIA/apex
129
+ cd apex
130
+ pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
131
+ ```
132
+
133
+ APEX also supports a Python-only build via
134
+ ```bash
135
+ pip install -v --disable-pip-version-check --no-cache-dir ./
136
+ ```
137
+ A Python-only build omits:
138
+ - Fused kernels required to use `apex.optimizers.FusedAdam`.
139
+ - Fused kernels required to use `apex.normalization.FusedLayerNorm` and `apex.normalization.FusedRMSNorm`.
140
+ - Fused kernels that improve the performance and numerical stability of `apex.parallel.SyncBatchNorm`.
141
+ - Fused kernels that improve the performance of `apex.parallel.DistributedDataParallel` and `apex.amp`.
142
+ `DistributedDataParallel`, `amp`, and `SyncBatchNorm` will still be usable, but they may be slower.
143
+
144
+
145
+ ### [Experimental] Windows
146
+ `pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .` may work if you were able to build Pytorch from source
147
+ on your system. A Python-only build via `pip install -v --no-cache-dir .` is more likely to work.
148
+ If you installed Pytorch in a Conda environment, make sure to install Apex in that same environment.
149
+
150
+
151
+ ## Custom C++/CUDA Extensions and Install Options
152
+
153
+ If a requirement of a module is not met, then it will not be built.
154
+
155
+ | Module Name | Install Option | Misc |
156
+ |---------------|------------------|--------|
157
+ | `apex_C` | `--cpp_ext` | |
158
+ | `amp_C` | `--cuda_ext` | |
159
+ | `syncbn` | `--cuda_ext` | |
160
+ | `fused_layer_norm_cuda` | `--cuda_ext` | [`apex.normalization`](./apex/normalization) |
161
+ | `mlp_cuda` | `--cuda_ext` | |
162
+ | `scaled_upper_triang_masked_softmax_cuda` | `--cuda_ext` | |
163
+ | `generic_scaled_masked_softmax_cuda` | `--cuda_ext` | |
164
+ | `scaled_masked_softmax_cuda` | `--cuda_ext` | |
165
+ | `fused_weight_gradient_mlp_cuda` | `--cuda_ext` | Requires CUDA>=11 |
166
+ | `permutation_search_cuda` | `--permutation_search` | [`apex.contrib.sparsity`](./apex/contrib/sparsity) |
167
+ | `bnp` | `--bnp` | [`apex.contrib.groupbn`](./apex/contrib/groupbn) |
168
+ | `xentropy` | `--xentropy` | [`apex.contrib.xentropy`](./apex/contrib/xentropy) |
169
+ | `focal_loss_cuda` | `--focal_loss` | [`apex.contrib.focal_loss`](./apex/contrib/focal_loss) |
170
+ | `fused_index_mul_2d` | `--index_mul_2d` | [`apex.contrib.index_mul_2d`](./apex/contrib/index_mul_2d) |
171
+ | `fused_adam_cuda` | `--deprecated_fused_adam` | [`apex.contrib.optimizers`](./apex/contrib/optimizers) |
172
+ | `fused_lamb_cuda` | `--deprecated_fused_lamb` | [`apex.contrib.optimizers`](./apex/contrib/optimizers) |
173
+ | `fast_layer_norm` | `--fast_layer_norm` | [`apex.contrib.layer_norm`](./apex/contrib/layer_norm). different from `fused_layer_norm` |
174
+ | `fmhalib` | `--fmha` | [`apex.contrib.fmha`](./apex/contrib/fmha) |
175
+ | `fast_multihead_attn` | `--fast_multihead_attn` | [`apex.contrib.multihead_attn`](./apex/contrib/multihead_attn) |
176
+ | `transducer_joint_cuda` | `--transducer` | [`apex.contrib.transducer`](./apex/contrib/transducer) |
177
+ | `transducer_loss_cuda` | `--transducer` | [`apex.contrib.transducer`](./apex/contrib/transducer) |
178
+ | `cudnn_gbn_lib` | `--cudnn_gbn` | Requires cuDNN>=8.5, [`apex.contrib.cudnn_gbn`](./apex/contrib/cudnn_gbn) |
179
+ | `peer_memory_cuda` | `--peer_memory` | [`apex.contrib.peer_memory`](./apex/contrib/peer_memory) |
180
+ | `nccl_p2p_cuda` | `--nccl_p2p` | Requires NCCL >= 2.10, [`apex.contrib.nccl_p2p`](./apex/contrib/nccl_p2p) |
181
+ | `fast_bottleneck` | `--fast_bottleneck` | Requires `peer_memory_cuda` and `nccl_p2p_cuda`, [`apex.contrib.bottleneck`](./apex/contrib/bottleneck) |
182
+ | `fused_conv_bias_relu` | `--fused_conv_bias_relu` | Requires cuDNN>=8.4, [`apex.contrib.conv_bias_relu`](./apex/contrib/conv_bias_relu) |
apex.egg-info/PKG-INFO ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.1
2
+ Name: apex
3
+ Version: 0.1
4
+ Summary: PyTorch Extensions written by NVIDIA
5
+ License: UNKNOWN
6
+ Platform: UNKNOWN
7
+ License-File: LICENSE
8
+
9
+ UNKNOWN
apex.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ LICENSE
2
+ README.md
3
+ setup.py
4
+ apex/__init__.py
5
+ apex/_autocast_utils.py
6
+ apex.egg-info/PKG-INFO
7
+ apex.egg-info/SOURCES.txt
8
+ apex.egg-info/dependency_links.txt
9
+ apex.egg-info/requires.txt
10
+ apex.egg-info/top_level.txt
11
+ apex/RNN/RNNBackend.py
12
+ apex/RNN/__init__.py
13
+ apex/RNN/cells.py
14
+ apex/RNN/models.py
15
+ apex/amp/__init__.py
16
+ apex/amp/__version__.py
17
+ apex/amp/_amp_state.py
18
+ apex/amp/_initialize.py
19
+ apex/amp/_process_optimizer.py
20
+ apex/amp/amp.py
21
+ apex/amp/compat.py
22
+ apex/amp/frontend.py
23
+ apex/amp/handle.py
24
+ apex/amp/opt.py
25
+ apex/amp/rnn_compat.py
26
+ apex/amp/scaler.py
27
+ apex/amp/utils.py
28
+ apex/amp/wrap.py
29
+ apex/amp/lists/__init__.py
30
+ apex/amp/lists/functional_overrides.py
31
+ apex/amp/lists/tensor_overrides.py
32
+ apex/amp/lists/torch_overrides.py
33
+ apex/contrib/__init__.py
34
+ apex/contrib/bottleneck/__init__.py
35
+ apex/contrib/bottleneck/bottleneck.py
36
+ apex/contrib/bottleneck/halo_exchangers.py
37
+ apex/contrib/bottleneck/test.py
38
+ apex/contrib/clip_grad/__init__.py
39
+ apex/contrib/clip_grad/clip_grad.py
40
+ apex/contrib/conv_bias_relu/__init__.py
41
+ apex/contrib/conv_bias_relu/conv_bias_relu.py
42
+ apex/contrib/csrc/layer_norm/ln_api.cpp
43
+ apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu
44
+ apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu
45
+ apex/contrib/csrc/optimizers/fused_adam_cuda.cpp
46
+ apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu
47
+ apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp
48
+ apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu
49
+ apex/contrib/cudnn_gbn/__init__.py
50
+ apex/contrib/cudnn_gbn/batch_norm.py
51
+ apex/contrib/fmha/__init__.py
52
+ apex/contrib/fmha/fmha.py
53
+ apex/contrib/focal_loss/__init__.py
54
+ apex/contrib/focal_loss/focal_loss.py
55
+ apex/contrib/groupbn/__init__.py
56
+ apex/contrib/groupbn/batch_norm.py
57
+ apex/contrib/index_mul_2d/__init__.py
58
+ apex/contrib/index_mul_2d/index_mul_2d.py
59
+ apex/contrib/layer_norm/__init__.py
60
+ apex/contrib/layer_norm/layer_norm.py
61
+ apex/contrib/multihead_attn/__init__.py
62
+ apex/contrib/multihead_attn/encdec_multihead_attn.py
63
+ apex/contrib/multihead_attn/encdec_multihead_attn_func.py
64
+ apex/contrib/multihead_attn/fast_encdec_multihead_attn_func.py
65
+ apex/contrib/multihead_attn/fast_encdec_multihead_attn_norm_add_func.py
66
+ apex/contrib/multihead_attn/fast_self_multihead_attn_func.py
67
+ apex/contrib/multihead_attn/fast_self_multihead_attn_norm_add_func.py
68
+ apex/contrib/multihead_attn/mask_softmax_dropout_func.py
69
+ apex/contrib/multihead_attn/self_multihead_attn.py
70
+ apex/contrib/multihead_attn/self_multihead_attn_func.py
71
+ apex/contrib/optimizers/__init__.py
72
+ apex/contrib/optimizers/distributed_fused_adam.py
73
+ apex/contrib/optimizers/distributed_fused_lamb.py
74
+ apex/contrib/optimizers/fp16_optimizer.py
75
+ apex/contrib/optimizers/fused_adam.py
76
+ apex/contrib/optimizers/fused_lamb.py
77
+ apex/contrib/optimizers/fused_sgd.py
78
+ apex/contrib/peer_memory/__init__.py
79
+ apex/contrib/peer_memory/peer_halo_exchanger_1d.py
80
+ apex/contrib/peer_memory/peer_memory.py
81
+ apex/contrib/sparsity/__init__.py
82
+ apex/contrib/sparsity/asp.py
83
+ apex/contrib/sparsity/permutation_lib.py
84
+ apex/contrib/sparsity/sparse_masklib.py
85
+ apex/contrib/sparsity/permutation_search_kernels/__init__.py
86
+ apex/contrib/sparsity/permutation_search_kernels/call_permutation_search_kernels.py
87
+ apex/contrib/sparsity/permutation_search_kernels/channel_swap.py
88
+ apex/contrib/sparsity/permutation_search_kernels/exhaustive_search.py
89
+ apex/contrib/sparsity/permutation_search_kernels/permutation_utilities.py
90
+ apex/contrib/test/__init__.py
91
+ apex/contrib/test/bottleneck/__init__.py
92
+ apex/contrib/test/bottleneck/test_bottleneck_module.py
93
+ apex/contrib/test/clip_grad/__init__.py
94
+ apex/contrib/test/clip_grad/test_clip_grad.py
95
+ apex/contrib/test/conv_bias_relu/__init__.py
96
+ apex/contrib/test/conv_bias_relu/test_conv_bias_relu.py
97
+ apex/contrib/test/cudnn_gbn/__init__.py
98
+ apex/contrib/test/cudnn_gbn/test_cudnn_gbn_with_two_gpus.py
99
+ apex/contrib/test/fmha/__init__.py
100
+ apex/contrib/test/fmha/test_fmha.py
101
+ apex/contrib/test/focal_loss/__init__.py
102
+ apex/contrib/test/focal_loss/test_focal_loss.py
103
+ apex/contrib/test/index_mul_2d/__init__.py
104
+ apex/contrib/test/index_mul_2d/test_index_mul_2d.py
105
+ apex/contrib/test/layer_norm/__init__.py
106
+ apex/contrib/test/layer_norm/test_fast_layer_norm.py
107
+ apex/contrib/test/multihead_attn/__init__.py
108
+ apex/contrib/test/multihead_attn/test_encdec_multihead_attn.py
109
+ apex/contrib/test/multihead_attn/test_encdec_multihead_attn_norm_add.py
110
+ apex/contrib/test/multihead_attn/test_fast_self_multihead_attn_bias.py
111
+ apex/contrib/test/multihead_attn/test_mha_fused_softmax.py
112
+ apex/contrib/test/multihead_attn/test_self_multihead_attn.py
113
+ apex/contrib/test/multihead_attn/test_self_multihead_attn_norm_add.py
114
+ apex/contrib/test/optimizers/__init__.py
115
+ apex/contrib/test/optimizers/test_dist_adam.py
116
+ apex/contrib/test/optimizers/test_distributed_fused_lamb.py
117
+ apex/contrib/test/peer_memory/__init__.py
118
+ apex/contrib/test/peer_memory/test_peer_halo_exchange_module.py
119
+ apex/contrib/test/transducer/__init__.py
120
+ apex/contrib/test/transducer/test_transducer_joint.py
121
+ apex/contrib/test/transducer/test_transducer_loss.py
122
+ apex/contrib/test/xentropy/__init__.py
123
+ apex/contrib/test/xentropy/test_label_smoothing.py
124
+ apex/contrib/transducer/__init__.py
125
+ apex/contrib/transducer/_transducer_ref.py
126
+ apex/contrib/transducer/transducer.py
127
+ apex/contrib/xentropy/__init__.py
128
+ apex/contrib/xentropy/softmax_xentropy.py
129
+ apex/fp16_utils/__init__.py
130
+ apex/fp16_utils/fp16_optimizer.py
131
+ apex/fp16_utils/fp16util.py
132
+ apex/fp16_utils/loss_scaler.py
133
+ apex/fused_dense/__init__.py
134
+ apex/fused_dense/fused_dense.py
135
+ apex/mlp/__init__.py
136
+ apex/mlp/mlp.py
137
+ apex/multi_tensor_apply/__init__.py
138
+ apex/multi_tensor_apply/multi_tensor_apply.py
139
+ apex/normalization/__init__.py
140
+ apex/normalization/fused_layer_norm.py
141
+ apex/optimizers/__init__.py
142
+ apex/optimizers/fused_adagrad.py
143
+ apex/optimizers/fused_adam.py
144
+ apex/optimizers/fused_lamb.py
145
+ apex/optimizers/fused_mixed_precision_lamb.py
146
+ apex/optimizers/fused_novograd.py
147
+ apex/optimizers/fused_sgd.py
148
+ apex/parallel/LARC.py
149
+ apex/parallel/__init__.py
150
+ apex/parallel/distributed.py
151
+ apex/parallel/multiproc.py
152
+ apex/parallel/optimized_sync_batchnorm.py
153
+ apex/parallel/optimized_sync_batchnorm_kernel.py
154
+ apex/parallel/sync_batchnorm.py
155
+ apex/parallel/sync_batchnorm_kernel.py
156
+ apex/transformer/__init__.py
157
+ apex/transformer/_ucc_util.py
158
+ apex/transformer/enums.py
159
+ apex/transformer/log_util.py
160
+ apex/transformer/microbatches.py
161
+ apex/transformer/parallel_state.py
162
+ apex/transformer/utils.py
163
+ apex/transformer/_data/__init__.py
164
+ apex/transformer/_data/_batchsampler.py
165
+ apex/transformer/amp/__init__.py
166
+ apex/transformer/amp/grad_scaler.py
167
+ apex/transformer/functional/__init__.py
168
+ apex/transformer/functional/fused_softmax.py
169
+ apex/transformer/layers/__init__.py
170
+ apex/transformer/layers/layer_norm.py
171
+ apex/transformer/pipeline_parallel/__init__.py
172
+ apex/transformer/pipeline_parallel/_timers.py
173
+ apex/transformer/pipeline_parallel/p2p_communication.py
174
+ apex/transformer/pipeline_parallel/utils.py
175
+ apex/transformer/pipeline_parallel/schedules/__init__.py
176
+ apex/transformer/pipeline_parallel/schedules/common.py
177
+ apex/transformer/pipeline_parallel/schedules/fwd_bwd_no_pipelining.py
178
+ apex/transformer/pipeline_parallel/schedules/fwd_bwd_pipelining_with_interleaving.py
179
+ apex/transformer/pipeline_parallel/schedules/fwd_bwd_pipelining_without_interleaving.py
180
+ apex/transformer/tensor_parallel/__init__.py
181
+ apex/transformer/tensor_parallel/cross_entropy.py
182
+ apex/transformer/tensor_parallel/data.py
183
+ apex/transformer/tensor_parallel/layers.py
184
+ apex/transformer/tensor_parallel/mappings.py
185
+ apex/transformer/tensor_parallel/memory.py
186
+ apex/transformer/tensor_parallel/random.py
187
+ apex/transformer/tensor_parallel/utils.py
188
+ apex/transformer/testing/__init__.py
189
+ apex/transformer/testing/arguments.py
190
+ apex/transformer/testing/commons.py
191
+ apex/transformer/testing/distributed_test_base.py
192
+ apex/transformer/testing/global_vars.py
193
+ apex/transformer/testing/standalone_bert.py
194
+ apex/transformer/testing/standalone_gpt.py
195
+ apex/transformer/testing/standalone_transformer_lm.py
196
+ csrc/amp_C_frontend.cpp
197
+ csrc/flatten_unflatten.cpp
198
+ csrc/fused_dense.cpp
199
+ csrc/fused_dense_cuda.cu
200
+ csrc/layer_norm_cuda.cpp
201
+ csrc/layer_norm_cuda_kernel.cu
202
+ csrc/mlp.cpp
203
+ csrc/mlp_cuda.cu
204
+ csrc/multi_tensor_adagrad.cu
205
+ csrc/multi_tensor_adam.cu
206
+ csrc/multi_tensor_axpby_kernel.cu
207
+ csrc/multi_tensor_l2norm_kernel.cu
208
+ csrc/multi_tensor_l2norm_kernel_mp.cu
209
+ csrc/multi_tensor_l2norm_scale_kernel.cu
210
+ csrc/multi_tensor_lamb.cu
211
+ csrc/multi_tensor_lamb_mp.cu
212
+ csrc/multi_tensor_lamb_stage_1.cu
213
+ csrc/multi_tensor_lamb_stage_2.cu
214
+ csrc/multi_tensor_novograd.cu
215
+ csrc/multi_tensor_scale_kernel.cu
216
+ csrc/multi_tensor_sgd_kernel.cu
217
+ csrc/syncbn.cpp
218
+ csrc/welford.cu
219
+ csrc/megatron/fused_weight_gradient_dense.cpp
220
+ csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu
221
+ csrc/megatron/fused_weight_gradient_dense_cuda.cu
222
+ csrc/megatron/generic_scaled_masked_softmax.cpp
223
+ csrc/megatron/generic_scaled_masked_softmax_cuda.cu
224
+ csrc/megatron/scaled_masked_softmax.cpp
225
+ csrc/megatron/scaled_masked_softmax_cuda.cu
226
+ csrc/megatron/scaled_softmax.cpp
227
+ csrc/megatron/scaled_softmax_cuda.cu
228
+ csrc/megatron/scaled_upper_triang_masked_softmax.cpp
229
+ csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu
apex.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
apex.egg-info/requires.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ packaging>20.6
apex.egg-info/top_level.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ amp_C
2
+ apex
3
+ apex_C
4
+ distributed_adam_cuda
5
+ fast_layer_norm
6
+ fused_adam_cuda
7
+ fused_dense_cuda
8
+ fused_layer_norm_cuda
9
+ fused_weight_gradient_mlp_cuda
10
+ generic_scaled_masked_softmax_cuda
11
+ mlp_cuda
12
+ scaled_masked_softmax_cuda
13
+ scaled_softmax_cuda
14
+ scaled_upper_triang_masked_softmax_cuda
15
+ syncbn
apex/RNN/README.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ **This module will be removed by the end of February 2023**
2
+
3
+ Under construction...
apex/RNN/RNNBackend.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.autograd import Variable
4
+
5
+ import torch.nn.functional as F
6
+
7
+ import math
8
+
9
+
10
+ def is_iterable(maybe_iterable):
11
+ return isinstance(maybe_iterable, list) or isinstance(maybe_iterable, tuple)
12
+
13
+
14
+ def flatten_list(tens_list):
15
+ """
16
+ flatten_list
17
+ """
18
+ if not is_iterable(tens_list):
19
+ return tens_list
20
+
21
+ return torch.cat(tens_list, dim=0).view(len(tens_list), *tens_list[0].size() )
22
+
23
+
24
+ #These modules always assumes batch_first
25
+ class bidirectionalRNN(nn.Module):
26
+ """
27
+ bidirectionalRNN
28
+ """
29
+ def __init__(self, inputRNN, num_layers=1, dropout = 0):
30
+ super(bidirectionalRNN, self).__init__()
31
+ self.dropout = dropout
32
+ self.fwd = stackedRNN(inputRNN, num_layers=num_layers, dropout = dropout)
33
+ self.bckwrd = stackedRNN(inputRNN.new_like(), num_layers=num_layers, dropout = dropout)
34
+ self.rnns = nn.ModuleList([self.fwd, self.bckwrd])
35
+
36
+ #collect hidden option will return all hidden/cell states from entire RNN
37
+ def forward(self, input, collect_hidden=False):
38
+ """
39
+ forward()
40
+ """
41
+ seq_len = input.size(0)
42
+ bsz = input.size(1)
43
+
44
+ fwd_out, fwd_hiddens = list(self.fwd(input, collect_hidden = collect_hidden))
45
+ bckwrd_out, bckwrd_hiddens = list(self.bckwrd(input, reverse=True, collect_hidden = collect_hidden))
46
+
47
+ output = torch.cat( [fwd_out, bckwrd_out], -1 )
48
+ hiddens = tuple( torch.cat(hidden, -1) for hidden in zip( fwd_hiddens, bckwrd_hiddens) )
49
+
50
+ return output, hiddens
51
+
52
+ def reset_parameters(self):
53
+ """
54
+ reset_parameters()
55
+ """
56
+ for rnn in self.rnns:
57
+ rnn.reset_parameters()
58
+
59
+ def init_hidden(self, bsz):
60
+ """
61
+ init_hidden()
62
+ """
63
+ for rnn in self.rnns:
64
+ rnn.init_hidden(bsz)
65
+
66
+ def detach_hidden(self):
67
+ """
68
+ detach_hidden()
69
+ """
70
+ for rnn in self.rnns:
71
+ rnn.detachHidden()
72
+
73
+ def reset_hidden(self, bsz):
74
+ """
75
+ reset_hidden()
76
+ """
77
+ for rnn in self.rnns:
78
+ rnn.reset_hidden(bsz)
79
+
80
+ def init_inference(self, bsz):
81
+ """
82
+ init_inference()
83
+ """
84
+ for rnn in self.rnns:
85
+ rnn.init_inference(bsz)
86
+
87
+
88
+ #assumes hidden_state[0] of inputRNN is output hidden state
89
+ #constructor either takes an RNNCell or list of RNN layers
90
+ class stackedRNN(nn.Module):
91
+ """
92
+ stackedRNN
93
+ """
94
+ def __init__(self, inputRNN, num_layers=1, dropout=0):
95
+ super(stackedRNN, self).__init__()
96
+
97
+ self.dropout = dropout
98
+
99
+ if isinstance(inputRNN, RNNCell):
100
+ self.rnns = [inputRNN]
101
+ for i in range(num_layers-1):
102
+ self.rnns.append(inputRNN.new_like(inputRNN.output_size))
103
+ elif isinstance(inputRNN, list):
104
+ assert len(inputRNN) == num_layers, "RNN list length must be equal to num_layers"
105
+ self.rnns=inputRNN
106
+ else:
107
+ raise RuntimeError()
108
+
109
+ self.nLayers = len(self.rnns)
110
+
111
+ self.rnns = nn.ModuleList(self.rnns)
112
+
113
+
114
+ '''
115
+ Returns output as hidden_state[0] Tensor([sequence steps][batch size][features])
116
+ If collect hidden will also return Tuple(
117
+ [n_hidden_states][sequence steps] Tensor([layer][batch size][features])
118
+ )
119
+ If not collect hidden will also return Tuple(
120
+ [n_hidden_states] Tensor([layer][batch size][features])
121
+ '''
122
+ def forward(self, input, collect_hidden=False, reverse=False):
123
+ """
124
+ forward()
125
+ """
126
+ seq_len = input.size(0)
127
+ bsz = input.size(1)
128
+ inp_iter = reversed(range(seq_len)) if reverse else range(seq_len)
129
+
130
+ hidden_states = [[] for i in range(self.nLayers)]
131
+ outputs = []
132
+
133
+ for seq in inp_iter:
134
+ for layer in range(self.nLayers):
135
+
136
+ if layer == 0:
137
+ prev_out = input[seq]
138
+
139
+ outs = self.rnns[layer](prev_out)
140
+
141
+ if collect_hidden:
142
+ hidden_states[layer].append(outs)
143
+ elif seq == seq_len-1:
144
+ hidden_states[layer].append(outs)
145
+
146
+ prev_out = outs[0]
147
+
148
+ outputs.append(prev_out)
149
+
150
+ if reverse:
151
+ outputs = list(reversed(outputs))
152
+ '''
153
+ At this point outputs is in format:
154
+ list( [seq_length] x Tensor([bsz][features]) )
155
+ need to convert it to:
156
+ list( Tensor([seq_length][bsz][features]) )
157
+ '''
158
+ output = flatten_list(outputs)
159
+
160
+ '''
161
+ hidden_states at this point is in format:
162
+ list( [layer][seq_length][hidden_states] x Tensor([bsz][features]) )
163
+ need to convert it to:
164
+ For not collect hidden:
165
+ list( [hidden_states] x Tensor([layer][bsz][features]) )
166
+ For collect hidden:
167
+ list( [hidden_states][seq_length] x Tensor([layer][bsz][features]) )
168
+ '''
169
+ if not collect_hidden:
170
+ seq_len = 1
171
+ n_hid = self.rnns[0].n_hidden_states
172
+ new_hidden = [ [ [ None for k in range(self.nLayers)] for j in range(seq_len) ] for i in range(n_hid) ]
173
+
174
+
175
+ for i in range(n_hid):
176
+ for j in range(seq_len):
177
+ for k in range(self.nLayers):
178
+ new_hidden[i][j][k] = hidden_states[k][j][i]
179
+
180
+ hidden_states = new_hidden
181
+ #Now in format list( [hidden_states][seq_length][layer] x Tensor([bsz][features]) )
182
+ #Reverse seq_length if reverse
183
+ if reverse:
184
+ hidden_states = list( list(reversed(list(entry))) for entry in hidden_states)
185
+
186
+ #flatten layer dimension into tensor
187
+ hiddens = list( list(
188
+ flatten_list(seq) for seq in hidden )
189
+ for hidden in hidden_states )
190
+
191
+ #Now in format list( [hidden_states][seq_length] x Tensor([layer][bsz][features]) )
192
+ #Remove seq_length dimension if not collect_hidden
193
+ if not collect_hidden:
194
+ hidden_states = list( entry[0] for entry in hidden_states)
195
+ return output, hidden_states
196
+
197
+ def reset_parameters(self):
198
+ """
199
+ reset_parameters()
200
+ """
201
+ for rnn in self.rnns:
202
+ rnn.reset_parameters()
203
+
204
+ def init_hidden(self, bsz):
205
+ """
206
+ init_hidden()
207
+ """
208
+ for rnn in self.rnns:
209
+ rnn.init_hidden(bsz)
210
+
211
+ def detach_hidden(self):
212
+ """
213
+ detach_hidden()
214
+ """
215
+ for rnn in self.rnns:
216
+ rnn.detach_hidden()
217
+
218
+ def reset_hidden(self, bsz):
219
+ """
220
+ reset_hidden()
221
+ """
222
+ for rnn in self.rnns:
223
+ rnn.reset_hidden(bsz)
224
+
225
+ def init_inference(self, bsz):
226
+ """
227
+ init_inference()
228
+ """
229
+ for rnn in self.rnns:
230
+ rnn.init_inference(bsz)
231
+
232
+ class RNNCell(nn.Module):
233
+ """
234
+ RNNCell
235
+ gate_multiplier is related to the architecture you're working with
236
+ For LSTM-like it will be 4 and GRU-like will be 3.
237
+ Always assumes input is NOT batch_first.
238
+ Output size that's not hidden size will use output projection
239
+ Hidden_states is number of hidden states that are needed for cell
240
+ if one will go directly to cell as tensor, if more will go as list
241
+ """
242
+ def __init__(self, gate_multiplier, input_size, hidden_size, cell, n_hidden_states = 2, bias = False, output_size = None):
243
+ super(RNNCell, self).__init__()
244
+
245
+ self.gate_multiplier = gate_multiplier
246
+ self.input_size = input_size
247
+ self.hidden_size = hidden_size
248
+ self.cell = cell
249
+ self.bias = bias
250
+ self.output_size = output_size
251
+ if output_size is None:
252
+ self.output_size = hidden_size
253
+
254
+ self.gate_size = gate_multiplier * self.hidden_size
255
+ self.n_hidden_states = n_hidden_states
256
+
257
+ self.w_ih = nn.Parameter(torch.empty(self.gate_size, self.input_size))
258
+ self.w_hh = nn.Parameter(torch.empty(self.gate_size, self.output_size))
259
+
260
+ #Check if there's recurrent projection
261
+ if(self.output_size != self.hidden_size):
262
+ self.w_ho = nn.Parameter(torch.empty(self.output_size, self.hidden_size))
263
+
264
+ self.b_ih = self.b_hh = None
265
+ if self.bias:
266
+ self.b_ih = nn.Parameter(torch.empty(self.gate_size))
267
+ self.b_hh = nn.Parameter(torch.empty(self.gate_size))
268
+
269
+ #hidden states for forward
270
+ self.hidden = [ None for states in range(self.n_hidden_states)]
271
+
272
+ self.reset_parameters()
273
+
274
+ def new_like(self, new_input_size=None):
275
+ """
276
+ new_like()
277
+ """
278
+ if new_input_size is None:
279
+ new_input_size = self.input_size
280
+
281
+ return type(self)(self.gate_multiplier,
282
+ new_input_size,
283
+ self.hidden_size,
284
+ self.cell,
285
+ self.n_hidden_states,
286
+ self.bias,
287
+ self.output_size)
288
+
289
+
290
+ #Use xavier where we can (weights), otherwise use uniform (bias)
291
+ def reset_parameters(self, gain=1):
292
+ """
293
+ reset_parameters()
294
+ """
295
+ stdev = 1.0 / math.sqrt(self.hidden_size)
296
+ for param in self.parameters():
297
+ param.data.uniform_(-stdev, stdev)
298
+ '''
299
+ Xavier reset:
300
+ def reset_parameters(self, gain=1):
301
+ stdv = 1.0 / math.sqrt(self.gate_size)
302
+
303
+ for param in self.parameters():
304
+ if (param.dim() > 1):
305
+ torch.nn.init.xavier_normal(param, gain)
306
+ else:
307
+ param.data.uniform_(-stdv, stdv)
308
+ '''
309
+ def init_hidden(self, bsz):
310
+ """
311
+ init_hidden()
312
+ """
313
+ for param in self.parameters():
314
+ if param is not None:
315
+ a_param = param
316
+ break
317
+
318
+ for i, _ in enumerate(self.hidden):
319
+ if(self.hidden[i] is None or self.hidden[i].data.size()[0] != bsz):
320
+
321
+ if i==0:
322
+ hidden_size = self.output_size
323
+ else:
324
+ hidden_size = self.hidden_size
325
+
326
+ tens = a_param.data.new(bsz, hidden_size).zero_()
327
+ self.hidden[i] = Variable(tens, requires_grad=False)
328
+
329
+
330
+ def reset_hidden(self, bsz):
331
+ """
332
+ reset_hidden()
333
+ """
334
+ for i, _ in enumerate(self.hidden):
335
+ self.hidden[i] = None
336
+ self.init_hidden(bsz)
337
+
338
+ def detach_hidden(self):
339
+ """
340
+ detach_hidden()
341
+ """
342
+ for i, _ in enumerate(self.hidden):
343
+ if self.hidden[i] is None:
344
+ raise RuntimeError("Must initialize hidden state before you can detach it")
345
+ for i, _ in enumerate(self.hidden):
346
+ self.hidden[i] = self.hidden[i].detach()
347
+
348
+ def forward(self, input):
349
+ """
350
+ forward()
351
+ if not inited or bsz has changed this will create hidden states
352
+ """
353
+ self.init_hidden(input.size()[0])
354
+
355
+ hidden_state = self.hidden[0] if self.n_hidden_states == 1 else self.hidden
356
+ self.hidden = self.cell(input, hidden_state, self.w_ih, self.w_hh, b_ih=self.b_ih, b_hh=self.b_hh)
357
+ if(self.n_hidden_states > 1):
358
+ self.hidden = list(self.hidden)
359
+ else:
360
+ self.hidden=[self.hidden]
361
+
362
+ if self.output_size != self.hidden_size:
363
+ self.hidden[0] = F.linear(self.hidden[0], self.w_ho)
364
+
365
+ return tuple(self.hidden)
apex/RNN/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .models import LSTM, GRU, ReLU, Tanh, mLSTM
2
+
3
+ __all__ = ['models']
apex/RNN/cells.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .RNNBackend import RNNCell
6
+
7
+ from torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend
8
+
9
+ import math
10
+
11
+
12
+ class mLSTMRNNCell(RNNCell):
13
+ """
14
+ mLSTMRNNCell
15
+ """
16
+
17
+ def __init__(self, input_size, hidden_size, bias = False, output_size = None):
18
+ gate_multiplier = 4
19
+ super(mLSTMRNNCell, self).__init__(gate_multiplier, input_size, hidden_size, mLSTMCell, n_hidden_states = 2, bias = bias, output_size = output_size)
20
+
21
+ self.w_mih = nn.Parameter(torch.empty(self.output_size, self.input_size))
22
+ self.w_mhh = nn.Parameter(torch.empty(self.output_size, self.output_size))
23
+
24
+ self.reset_parameters()
25
+
26
+ def forward(self, input):
27
+ """
28
+ mLSTMRNNCell.forward()
29
+ """
30
+ #if not inited or bsz has changed this will create hidden states
31
+ self.init_hidden(input.size()[0])
32
+
33
+ hidden_state = self.hidden[0] if self.n_hidden_states == 1 else self.hidden
34
+
35
+ self.hidden = list(
36
+ self.cell(input, hidden_state, self.w_ih, self.w_hh, self.w_mih, self.w_mhh,
37
+ b_ih=self.b_ih, b_hh=self.b_hh)
38
+ )
39
+
40
+ if self.output_size != self.hidden_size:
41
+ self.hidden[0] = F.linear(self.hidden[0], self.w_ho)
42
+ return tuple(self.hidden)
43
+
44
+
45
+ def new_like(self, new_input_size=None):
46
+ if new_input_size is None:
47
+ new_input_size = self.input_size
48
+
49
+ return type(self)(
50
+ new_input_size,
51
+ self.hidden_size,
52
+ self.bias,
53
+ self.output_size)
54
+
55
+ def mLSTMCell(input, hidden, w_ih, w_hh, w_mih, w_mhh, b_ih=None, b_hh=None):
56
+ """
57
+ mLSTMCell
58
+ """
59
+
60
+ if input.is_cuda:
61
+ igates = F.linear(input, w_ih)
62
+ m = F.linear(input, w_mih) * F.linear(hidden[0], w_mhh)
63
+ hgates = F.linear(m, w_hh)
64
+
65
+ state = fusedBackend.LSTMFused.apply
66
+ return state(igates, hgates, hidden[1], b_ih, b_hh)
67
+
68
+ hx, cx = hidden
69
+
70
+ m = F.linear(input, w_mih) * F.linear(hidden[0], w_mhh)
71
+ gates = F.linear(input, w_ih, b_ih) + F.linear(m, w_hh, b_hh)
72
+
73
+ ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
74
+
75
+ ingate = F.sigmoid(ingate)
76
+ forgetgate = F.sigmoid(forgetgate)
77
+ cellgate = F.tanh(cellgate)
78
+ outgate = F.sigmoid(outgate)
79
+
80
+ cy = (forgetgate * cx) + (ingate * cellgate)
81
+ hy = outgate * F.tanh(cy)
82
+
83
+ return hy, cy
84
+
apex/RNN/models.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torch.nn._functions.rnn import LSTMCell, RNNReLUCell, RNNTanhCell, GRUCell
4
+
5
+ from apex import deprecated_warning
6
+ from .RNNBackend import bidirectionalRNN, stackedRNN, RNNCell
7
+ from .cells import mLSTMRNNCell, mLSTMCell
8
+
9
+ def toRNNBackend(inputRNN, num_layers, bidirectional=False, dropout = 0):
10
+ """
11
+ :class:`toRNNBackend`
12
+ """
13
+
14
+ deprecated_warning("`apex.RNN` is deprecated and will be removed by the end of February 2023.")
15
+ if bidirectional:
16
+ return bidirectionalRNN(inputRNN, num_layers, dropout = dropout)
17
+ else:
18
+ return stackedRNN(inputRNN, num_layers, dropout = dropout)
19
+
20
+
21
+ def LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):
22
+ """
23
+ :class:`LSTM`
24
+ """
25
+ inputRNN = RNNCell(4, input_size, hidden_size, LSTMCell, 2, bias, output_size)
26
+ return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)
27
+
28
+ def GRU(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):
29
+ """
30
+ :class:`GRU`
31
+ """
32
+ inputRNN = RNNCell(3, input_size, hidden_size, GRUCell, 1, bias, output_size)
33
+ return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)
34
+
35
+ def ReLU(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):
36
+ """
37
+ :class:`ReLU`
38
+ """
39
+ inputRNN = RNNCell(1, input_size, hidden_size, RNNReLUCell, 1, bias, output_size)
40
+ return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)
41
+
42
+ def Tanh(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):
43
+ """
44
+ :class:`Tanh`
45
+ """
46
+ inputRNN = RNNCell(1, input_size, hidden_size, RNNTanhCell, 1, bias, output_size)
47
+ return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)
48
+
49
+ def mLSTM(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):
50
+ """
51
+ :class:`mLSTM`
52
+ """
53
+ inputRNN = mLSTMRNNCell(input_size, hidden_size, bias=bias, output_size=output_size)
54
+ return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)
55
+
56
+
apex/__init__.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import warnings
3
+
4
+ # May help avoid undefined symbol errors https://pytorch.org/cppdocs/notes/faq.html#undefined-symbol-errors-from-pytorch-aten
5
+ import torch
6
+
7
+
8
+ __all__ = ["amp", "fp16_utils", "optimizers", "normalization", "transformer"]
9
+
10
+
11
+ if torch.distributed.is_available():
12
+ from . import parallel
13
+ __all__.append("parallel")
14
+
15
+ from . import amp
16
+ from . import fp16_utils
17
+
18
+ # For optimizers and normalization there is no Python fallback.
19
+ # Absence of cuda backend is a hard error.
20
+ # I would like the errors from importing fused_adam_cuda or fused_layer_norm_cuda
21
+ # to be triggered lazily, because if someone has installed with --cpp_ext and --cuda_ext
22
+ # so they expect those backends to be available, but for some reason they actually aren't
23
+ # available (for example because they built improperly in a way that isn't revealed until
24
+ # load time) the error message is timely and visible.
25
+ from . import optimizers
26
+ from . import normalization
27
+ from . import transformer
28
+
29
+
30
+ # Logging utilities for apex.transformer module
31
+ class RankInfoFormatter(logging.Formatter):
32
+
33
+ def format(self, record):
34
+ from apex.transformer.parallel_state import get_rank_info
35
+ record.rank_info = get_rank_info()
36
+ return super().format(record)
37
+
38
+
39
+ _library_root_logger = logging.getLogger(__name__)
40
+ handler = logging.StreamHandler()
41
+ handler.setFormatter(RankInfoFormatter("%(asctime)s - PID:%(process)d - rank:%(rank_info)s - %(filename)s:%(lineno)d - %(levelname)s - %(message)s", "%y-%m-%d %H:%M:%S"))
42
+ _library_root_logger.addHandler(handler)
43
+ _library_root_logger.propagate = False
44
+
45
+
46
+ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int) -> bool:
47
+ cudnn_available = torch.backends.cudnn.is_available()
48
+ cudnn_version = torch.backends.cudnn.version() if cudnn_available else None
49
+ if not (cudnn_available and (cudnn_version >= required_cudnn_version)):
50
+ warnings.warn(
51
+ f"`{global_option}` depends on cuDNN {required_cudnn_version} or later, "
52
+ f"but {'cuDNN is not available' if not cudnn_available else cudnn_version}"
53
+ )
54
+ return False
55
+ return True
56
+
57
+
58
+ class DeprecatedFeatureWarning(FutureWarning):
59
+ pass
60
+
61
+
62
+ def deprecated_warning(msg: str) -> None:
63
+ if (
64
+ not torch.distributed.is_available
65
+ or not torch.distributed.is_initialized()
66
+ or (torch.distributed.is_initialized() and torch.distributed.get_rank() == 0)
67
+ ):
68
+ warnings.warn(msg, DeprecatedFeatureWarning)
apex/_autocast_utils.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Sequence
2
+
3
+ import torch
4
+
5
+
6
+ __all__ = ["_cast_if_autocast_enabled"]
7
+
8
+
9
+ def _get_autocast_dtypes() -> Sequence[torch.dtype]:
10
+ if torch.cuda.is_bf16_supported():
11
+ return [torch.half, torch.bfloat16]
12
+ return [torch.half]
13
+
14
+
15
+ def _get_current_dtype(dtype: Optional[torch.dtype] = None) -> torch.dtype:
16
+ if not torch.is_autocast_enabled():
17
+ return torch.float or dtype
18
+ else:
19
+ return torch.get_autocast_gpu_dtype()
20
+
21
+
22
+ def _cast_if_autocast_enabled(*args):
23
+ if not torch.is_autocast_enabled():
24
+ return args
25
+ else:
26
+ return torch.cuda.amp.autocast_mode._cast(args, torch.get_autocast_gpu_dtype())
apex/amp/README.md ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # amp: Automatic Mixed Precision
2
+
3
+ ## Annotating User Functions
4
+
5
+ Nearly all PyTorch user code needs nothing more than the two steps
6
+ above to use amp. After all, custom layers are built out of simpler
7
+ PyTorch components, and amp already can see those.
8
+
9
+ However, any custom C++ or CUDA code is outside of amp's (default)
10
+ view of things. For example, suppose I implemented a new recurrent
11
+ cell called a "forgetful recurrent unit" that calls directly into a
12
+ CUDA backend:
13
+
14
+ ```python
15
+ from backend import FRUBackend
16
+
17
+ def fru(input, hidden, weight, bias):
18
+ # call to CUDA code
19
+ FRUBackend(input, hidden, weight, bias)
20
+ ```
21
+
22
+ In this case, it is possible to get a runtime type mismatch. For
23
+ example, you might have `input` in fp16, and `weight` in fp32, and amp
24
+ doesn't have the visibility to insert an appropriate cast.
25
+
26
+ amp exposes two ways to handle "invisible" backend code: function
27
+ annotations and explicit registration.
28
+
29
+ #### Function annotation
30
+
31
+ The first way to handle backend code is a set of function annotations:
32
+
33
+ - `@amp.half_function`
34
+ - `@amp.float_function`
35
+ - `@amp.promote_function`
36
+
37
+ These correspond to:
38
+
39
+ - Cast all arguments to fp16
40
+ - Cast all arguments to fp32
41
+ - If there are any type mismatches, cast everything to the widest type
42
+
43
+ In our example, we believe that the FRU unit is fp16-safe and will get
44
+ performance gains from casting its arguments to fp16, so we write:
45
+
46
+ ```python
47
+ @amp.half_function
48
+ def fru(input, hidden, weight, bias):
49
+ #...
50
+ ```
51
+
52
+ #### Explicit registration
53
+
54
+ The other way to handle backend code is with explicit function
55
+ registration:
56
+
57
+ - `amp.register_half_function(module, function_name)`
58
+ - `amp.register_float_function(module, function_name)`
59
+ - `amp.register_promote_function(module, function_name)`
60
+
61
+ When using this API, `module` is the containing class or module for
62
+ the function, and `function_name` is the _string_ name of the
63
+ function. Note that the function must be registered before the call to
64
+ `amp.initialize()`.
65
+
66
+ For our FRU unit, we can register the backend function directly:
67
+
68
+ ```python
69
+ import backend
70
+
71
+ amp.register_half_function(backend, 'FRUBackend')
72
+ ```
apex/amp/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .amp import init, half_function, float_function, promote_function,\
2
+ register_half_function, register_float_function, register_promote_function
3
+ from .handle import scale_loss, disable_casts
4
+ from .frontend import initialize, state_dict, load_state_dict
5
+ from ._amp_state import master_params, _amp_state
apex/amp/__version__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ VERSION = (0, 1, 0)
2
+ __version__ = '.'.join(map(str, VERSION))
apex/amp/_amp_state.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This is a "header object" that allows different amp modules to communicate.
2
+ # I'm a C++ guy, not a python guy. I decided this approach because it seemed most C++-like.
3
+ # But apparently it's ok:
4
+ # http://effbot.org/pyfaq/how-do-i-share-global-variables-across-modules.htm
5
+ import torch
6
+
7
+
8
+ class AmpState(object):
9
+ def __init__(self):
10
+ self.hard_override=False
11
+ self.allow_incoming_model_not_fp32 = False
12
+ self.verbosity=1
13
+
14
+
15
+ # Attribute stash. Could also just stash things as global module attributes.
16
+ _amp_state = AmpState()
17
+
18
+
19
+ def warn_or_err(msg):
20
+ if _amp_state.hard_override:
21
+ print("Warning: " + msg)
22
+ else:
23
+ raise RuntimeError(msg)
24
+ # I'm not sure if allowing hard_override is a good idea.
25
+ # + " If you're sure you know what you're doing, supply " +
26
+ # "hard_override=True to amp.initialize.")
27
+
28
+
29
+ def maybe_print(msg, rank0=False):
30
+ distributed = torch.distributed.is_available() and \
31
+ torch.distributed.is_initialized() and \
32
+ torch.distributed.get_world_size() > 1
33
+ if _amp_state.verbosity > 0:
34
+ if rank0:
35
+ if distributed:
36
+ if torch.distributed.get_rank() == 0:
37
+ print(msg)
38
+ else:
39
+ print(msg)
40
+ else:
41
+ print(msg)
42
+
43
+
44
+ # def iter_params(param_groups):
45
+ # for group in param_groups:
46
+ # for p in group['params']:
47
+ # yield p
48
+
49
+
50
+ def master_params(optimizer):
51
+ """
52
+ Generator expression that iterates over the params owned by ``optimizer``.
53
+
54
+ Args:
55
+ optimizer: An optimizer previously returned from ``amp.initialize``.
56
+ """
57
+ for group in optimizer.param_groups:
58
+ for p in group['params']:
59
+ yield p
apex/amp/_initialize.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections.abc as container_abcs
2
+ from types import MethodType
3
+ import functools
4
+ import sys
5
+ import warnings
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ from ._amp_state import _amp_state, warn_or_err
11
+ from .handle import disable_casts
12
+ from .scaler import LossScaler
13
+ from ._process_optimizer import _process_optimizer
14
+ from apex.fp16_utils import convert_network
15
+ from ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general
16
+ from ..contrib.optimizers import FP16_Optimizer as FP16_Optimizer_for_fused
17
+
18
+ if torch.distributed.is_available():
19
+ from ..parallel import DistributedDataParallel as apex_DDP
20
+ from ..parallel.LARC import LARC
21
+
22
+
23
+ def to_type(dtype, t):
24
+ if isinstance(t, torch.Tensor):
25
+ if not t.is_cuda:
26
+ # This should not be a hard error, since it may be legitimate.
27
+ warnings.warn("An input tensor was not cuda.")
28
+ # GANs require this.
29
+ # if t.requires_grad:
30
+ # warn_or_err("input data requires grad. Since input data is not a model parameter,\n"
31
+ # "its gradients will not be properly allreduced by DDP.")
32
+ if t.is_floating_point():
33
+ return t.to(dtype)
34
+ return t
35
+ else:
36
+ # Trust the user's custom batch type, that's all I can do here.
37
+ return t.to(dtype)
38
+
39
+
40
+ # Modified from torch.optim.optimizer.py. This is a bit more general than casted_args in utils.py.
41
+ def applier(value, fn):
42
+ if isinstance(value, torch.Tensor):
43
+ return fn(value)
44
+ elif isinstance(value, str):
45
+ return value
46
+ elif isinstance(value, np.ndarray):
47
+ return value
48
+ elif hasattr(value, "to"): # Allow handling of custom batch classes
49
+ return fn(value)
50
+ elif isinstance(value, container_abcs.Mapping):
51
+ return {applier(k, fn) : applier(v, fn) for k, v in value.items()}
52
+ elif isinstance(value, container_abcs.Iterable):
53
+ return type(value)(applier(v, fn) for v in value)
54
+ else:
55
+ # Do I want this to fire off even if someone chooses to pass something ordinary like
56
+ # an int or float? May be more annoying than it's worth.
57
+ # print("Warning: unrecognized type in applier. If your input data is a custom class, "
58
+ # "provide it with a .to(dtype) method which converts its floating-point Tensors to dtype. "
59
+ # "Amp will check for your custom to() and invoke it to cast the batch's "
60
+ # "floating-point Tensors to the appropriate type. "
61
+ # "Also, if your data is a custom class, it is your responsibility to ensure that "
62
+ # "any Tensors you want to be cuda are already cuda."
63
+ return value
64
+
65
+
66
+ def check_models(models):
67
+ for model in models:
68
+ parallel_type = None
69
+ if isinstance(model, torch.nn.parallel.DistributedDataParallel):
70
+ parallel_type = "torch.nn.parallel.DistributedDataParallel"
71
+ if ('apex_DDP' in sys.modules) and isinstance(model, apex_DDP):
72
+ parallel_type = "apex.parallel.DistributedDataParallel"
73
+ if isinstance(model, torch.nn.parallel.DataParallel):
74
+ parallel_type = "torch.nn.parallel.DataParallel"
75
+ if parallel_type is not None:
76
+ raise RuntimeError("Incoming model is an instance of {}. ".format(parallel_type) +
77
+ "Parallel wrappers should only be applied to the model(s) AFTER \n"
78
+ "the model(s) have been returned from amp.initialize.")
79
+
80
+
81
def check_params_fp32(models):
    # Sanity-check that every floating-point parameter and buffer is a CUDA
    # FP32 tensor before Amp takes over casting.  warn_or_err (defined
    # elsewhere in this module) reports the problem; presumably it warns or
    # raises depending on amp state — confirm against _amp_state.
    for model in models:
        for name, param in model.named_parameters():
            if param.is_floating_point():
                if 'Half' in param.type():
                    warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
                        "When using amp.initialize, you do not need to call .half() on your model\n"
                        "before passing it, no matter what optimization level you choose.".format(
                        name, param.type()))
                elif not param.is_cuda:
                    warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
                        "When using amp.initialize, you need to provide a model with parameters\n"
                        "located on a CUDA device before passing it no matter what optimization level\n"
                        "you chose. Use model.to('cuda') to use the default device.".format(
                        name, param.type()))

        # Backward compatibility for PyTorch 0.4
        if hasattr(model, 'named_buffers'):
            buf_iter = model.named_buffers()
        else:
            buf_iter = model._buffers
        for obj in buf_iter:
            # named_buffers() yields (name, buffer) tuples; the 0.4-era
            # _buffers dict yields bare names, so look the buffer up.
            if type(obj)==tuple:
                name, buf = obj
            else:
                name, buf = obj, buf_iter[obj]
            if buf.is_floating_point():
                if 'Half' in buf.type():
                    warn_or_err("Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
                        "When using amp.initialize, you do not need to call .half() on your model\n"
                        "before passing it, no matter what optimization level you choose.".format(
                        name, buf.type()))
                elif not buf.is_cuda:
                    warn_or_err("Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
                        "When using amp.initialize, you need to provide a model with buffers\n"
                        "located on a CUDA device before passing it no matter what optimization level\n"
                        "you chose. Use model.to('cuda') to use the default device.".format(
                        name, buf.type()))
119
+
120
+
121
def check_optimizers(optimizers):
    """Reject optimizers that are already wrapped in a deprecated FP16_Optimizer.

    amp.initialize() must receive bare optimizers; pre-wrapped ones would be
    double-processed.  Raises RuntimeError on the first offender found.
    """
    forbidden_wrappers = ((FP16_Optimizer_general, "apex.fp16_utils.FP16_Optimizer"),
                          (FP16_Optimizer_for_fused, "apex.optimizers.FP16_Optimizer"))
    for optim in optimizers:
        bad_optim_type = None
        # Later table entries win, matching the original's sequential checks.
        for wrapper_cls, wrapper_label in forbidden_wrappers:
            if isinstance(optim, wrapper_cls):
                bad_optim_type = wrapper_label
        if bad_optim_type is not None:
            raise RuntimeError("An incoming optimizer is an instance of {}. ".format(bad_optim_type) +
                               "The optimizer(s) passed to amp.initialize() must be bare \n"
                               "instances of either ordinary Pytorch optimizers, or Apex fused \n"
                               "optimizers.\n")
133
+
134
+
135
class O2StateDictHook(object):
    """State-dict hook that recasts half-precision entries to FP32 in place.

    `fn` is stored at construction (the caller passes a casting partial) but
    the recast in __call__ is done inline.
    """

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, module, state_dict, prefix, local_metadata):
        # Mutate state_dict in place: every Half tensor becomes float32.
        # Only values of existing keys are reassigned, so iterating the dict
        # while updating it is safe.
        for key in state_dict:
            entry = state_dict[key]
            if 'Half' in entry.type():
                state_dict[key] = entry.to(torch.float32)
145
+
146
+
147
def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs=None):
    # Core of amp.initialize(): normalizes the models/optimizers arguments,
    # optionally casts models per `properties`, wraps each optimizer via
    # _process_optimizer, and creates one LossScaler per declared loss.
    # Returns models/optimizers preserving the caller's list-ness.
    from .amp import init as amp_init

    optimizers_was_list = False
    # LARC is only recognized when it was importable ('LARC' in globals()).
    if isinstance(optimizers, torch.optim.Optimizer) or ('LARC' in globals() and isinstance(optimizers, LARC)):
        optimizers = [optimizers]
    elif optimizers is None:
        optimizers = []
    elif isinstance(optimizers, list):
        optimizers_was_list = True
        check_optimizers(optimizers)
    else:
        # check_optimizers raises a more specific error for wrapped optimizers
        # before the generic TypeError below.
        check_optimizers([optimizers])
        raise TypeError("optimizers must be either a single optimizer or a list of optimizers.")

    if isinstance(models, torch.nn.Module):
        models_was_list = False
        models = [models]
    elif isinstance(models, list):
        models_was_list = True
    else:
        raise TypeError("models must be either a single model or a list of models.")

    check_models(models)

    if not _amp_state.allow_incoming_model_not_fp32:
        check_params_fp32(models)

    # In the future, when FP16_Optimizer can be deprecated and master weights can
    # become an attribute, remember to stash master weights before casting the model.

    if properties.cast_model_type:
        # O2-style: cast the model itself (optionally keeping batchnorm FP32).
        if properties.keep_batchnorm_fp32:
            for model in models:
                convert_network(model, properties.cast_model_type)
        else:
            for model in models:
                model.to(properties.cast_model_type)

        input_caster = functools.partial(to_type, properties.cast_model_type)
        if cast_model_outputs is not None:
            output_caster = functools.partial(to_type, cast_model_outputs)
        else:
            output_caster = functools.partial(to_type, torch.float32)

        for model in models:
            # Patch the forward method to cast incoming data to the correct type, and
            # outgoing data to float32, so "the user never needs to call .half()."
            # I like writing things explicitly more than decorators.
            def patch_forward(old_fwd):
                def new_fwd(*args, **kwargs):
                    output = old_fwd(*applier(args, input_caster),
                                     **applier(kwargs, input_caster))
                    return applier(output, output_caster)
                return new_fwd

            model.forward = patch_forward(model.forward)

        # State dict trick to recast any preexisting per-param state tensors
        for optimizer in optimizers:
            optimizer.load_state_dict(optimizer.state_dict())

        # patch model.state_dict() to return float32 params
        for model in models:
            for module in model.modules():
                module._register_state_dict_hook(O2StateDictHook(functools.partial(to_type, torch.float32)))

    elif cast_model_outputs is not None:
        # Model left untouched, but outputs are still cast to the requested type.
        output_caster = functools.partial(to_type, cast_model_outputs)

        for model in models:
            def patch_forward(old_fwd):
                def new_fwd(*args, **kwargs):
                    output = old_fwd(*args, **kwargs)
                    return applier(output, output_caster)
                return new_fwd

            model.forward = patch_forward(model.forward)

    for i, optimizer in enumerate(optimizers):
        optimizers[i] = _process_optimizer(optimizer, properties)

    # One loss scaler per declared loss.
    _amp_state.loss_scalers = []
    for _ in range(num_losses):
        _amp_state.loss_scalers.append(LossScaler(properties.loss_scale,
                                                  min_loss_scale=_amp_state.min_loss_scale,
                                                  max_loss_scale=_amp_state.max_loss_scale))

    if properties.patch_torch_functions:
        # handle is unused here. It's accessible later through a global value anyway.
        handle = amp_init(loss_scale=properties.loss_scale, verbose=(_amp_state.verbosity == 2))
        for optimizer in optimizers:
            # Disable Amp casting for the optimizer step, because it should only be
            # applied to FP32 master params anyway.
            def patch_step(old_step):
                def new_step(self, *args, **kwargs):
                    with disable_casts():
                        output = old_step(*args, **kwargs)
                    return output
                return new_step

            optimizer.step = MethodType(patch_step(optimizer.step), optimizer)

    # Return with the same shapes the caller passed: single objects in,
    # single objects out; lists in, lists out.
    if optimizers_was_list:
        if models_was_list:
            return models, optimizers
        else:
            return models[0], optimizers
    else:
        if models_was_list:
            if len(optimizers) == 0:
                return models
            else:
                return models, optimizers[0]
        else:
            if len(optimizers) == 0:
                return models[0]
            else:
                return models[0], optimizers[0]
apex/amp/_process_optimizer.py ADDED
@@ -0,0 +1,489 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import types
2
+ from ..fp16_utils import master_params_to_model_params
3
+ from ..multi_tensor_apply import multi_tensor_applier
4
+ from ._amp_state import maybe_print
5
+ import torch
6
+ from ..optimizers import FusedSGD
7
+
8
+
9
class AmpOptimizerState(object):
    """Empty attribute container attached to optimizers as ``_amp_stash``.

    Amp hangs its bookkeeping (param group lists, grad stashes, lazy-init
    flags, fused-kernel handles, ...) off instances of this class dynamically.
    """

    def __init__(self):
        pass
12
+
13
+
14
def _master_params_to_model_params(self):
    # Copy FP32 master param values back into the FP16 model params after a
    # step.  Uses the fused multi_tensor_scale kernel (scale factor 1.0, i.e.
    # a pure copy) when the C extension is available, otherwise falls back to
    # the per-group Python helper.
    stash = self._amp_stash
    if multi_tensor_applier.available:
        if len(stash.all_fp16_params) > 0:
            multi_tensor_applier(
                stash.multi_tensor_scale,
                stash.dummy_overflow_buf,
                [stash.all_fp32_from_fp16_params, stash.all_fp16_params],
                1.0)
    else:
        for fp16_group, fp32_from_fp16_group in zip(stash.fp16_groups, stash.fp32_from_fp16_groups):
            master_params_to_model_params(fp16_group, fp32_from_fp16_group)
26
+
27
+
28
def lazy_init_with_master_weights(self):
    # Build FP32 master copies of all FP16 params (O2-style), rewriting the
    # optimizer's param_groups in place to point at the masters.  Called once,
    # lazily, via _amp_lazy_init.
    stash = self._amp_stash
    stash.fp16_groups = []
    stash.fp32_from_fp16_groups = []
    stash.fp32_from_fp32_groups = []
    for i, param_group in enumerate(self.param_groups):
        # maybe_print("FP16_Optimizer processing param group {}:".format(i))
        fp16_params_this_group = []
        fp32_params_this_group = []
        fp32_from_fp16_params_this_group = []
        # NOTE(review): the inner loop reuses `i`, shadowing the group index;
        # harmless today because the outer `i` is only used in the
        # commented-out print above.
        for i, param in enumerate(param_group['params']):
            if param.requires_grad:
                if param.type() == 'torch.cuda.HalfTensor':
                    # maybe_print("FP16_Optimizer received torch.cuda.HalfTensor with {}"
                    #             .format(param.size()))
                    fp16_params_this_group.append(param)
                    master_param = param.detach().clone().float()
                    master_param.requires_grad = True
                    param_group['params'][i] = master_param
                    fp32_from_fp16_params_this_group.append(master_param)
                    # Reset existing state dict key to the new master param.
                    # We still need to recast per-param state tensors, if any, to FP32.
                    if param in self.state:
                        self.state[master_param] = self.state.pop(param)
                elif param.type() == 'torch.cuda.FloatTensor':
                    # maybe_print("FP16_Optimizer received torch.cuda.FloatTensor with {}"
                    #             .format(param.size()))
                    fp32_params_this_group.append(param)
                    param_group['params'][i] = param
                else:
                    raise TypeError("Optimizer's parameters must be either "
                                    "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
                                    "Received {}".format(param.type()))

        stash.fp16_groups.append(fp16_params_this_group)
        stash.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
        stash.fp32_from_fp32_groups.append(fp32_params_this_group)

    # Flattened views over all groups, used by the fused multi-tensor paths.
    stash.all_fp16_params = []
    for group in stash.fp16_groups:
        stash.all_fp16_params += group

    stash.all_fp32_from_fp16_params = []
    for group in stash.fp32_from_fp16_groups:
        stash.all_fp32_from_fp16_params += group

    stash.all_fp32_from_fp32_params = []
    for group in stash.fp32_from_fp32_groups:
        stash.all_fp32_from_fp32_params += group

    # all_fp16_grad_stash is only needed for fused optimizers.
    stash.all_fp16_grad_stash = [None for _ in stash.all_fp16_params]
    # stash.all_fp32_from_fp16_grad_stash = [None for _ in stash.all_fp32_from_fp16_params]
    stash.all_fp32_from_fp32_grad_stash = [None for _ in stash.all_fp32_from_fp32_params]

    for param in stash.all_fp32_from_fp16_params:
        param.grad = None

    for param in stash.all_fp32_from_fp32_params:
        param.grad = None

    # Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors
    self.load_state_dict(self.state_dict())
91
+
92
+
93
def post_backward_models_are_masters(scaler, params, stashed_grads, scale_override=None):
    # Unscale freshly-computed grads and fold back any grads stashed before
    # backward(), for params that are their own masters (no separate FP32
    # copy).  scale_override, if given, is the tuple
    # (grads_have_scale, stashed_have_scale, out_scale).
    grads_have_scale, stashed_have_scale, out_scale = scaler.loss_scale(), 1.0, 1.0

    # not much to do if scale == 1.0 and static scaling
    if scaler.loss_scale() == 1.0 and not scaler.dynamic:
        # Clear the stash.
        # NOTE(review): stashed grads are dropped (not accumulated) on this
        # early-out path — presumably they are always None here; confirm
        # against the prepare_* callers.
        for i in range(len(stashed_grads)):
            stashed_grads[i] = None
        return

    if scale_override is not None:
        grads_have_scale, stashed_have_scale, out_scale = scale_override

    # This is a lot of python overhead...
    grads_needing_unscale = []
    grads_needing_unscale_with_stash = []
    stashed = []
    for param, stashed_grad in zip(params, stashed_grads):
        if param.grad is None and stashed_grad is not None:
            # Only the stashed grad exists: reinstall it as-is.
            param.grad = stashed_grad
        elif param.grad is not None and stashed_grad is None:
            grads_needing_unscale.append(param.grad)
        elif param.grad is not None and stashed_grad is not None:
            grads_needing_unscale_with_stash.append(param.grad)
            stashed.append(stashed_grad)
        else: # param.grad is None and stashed_grad is None
            continue

    # unscale() implements grads*(1/scale), so "scale" should be grads_have_scale/out_scale.
    if len(grads_needing_unscale) > 0:
        scaler.unscale(
            grads_needing_unscale,
            grads_needing_unscale,
            None, # unused_scale, currently present to avoid API breakage elsewhere
            models_are_masters=True,
            scale_override=grads_have_scale/out_scale)

    if len(grads_needing_unscale_with_stash) > 0:
        scaler.unscale_with_stashed(
            grads_needing_unscale_with_stash,
            stashed,
            grads_needing_unscale_with_stash,
            scale_override=(grads_have_scale, stashed_have_scale, out_scale))

    # Clear the stash.
    for i in range(len(stashed_grads)):
        stashed_grads[i] = None
140
+
141
+
142
def prepare_backward_with_master_weights(self):
    """Pre-backward hook for the master-weights case.

    Clears model-side FP16 grads so backward() can write fresh gradients
    directly (grad copy elision), and stashes/clears preexisting grads of
    pure-FP32 params so they can be re-accumulated after backward().
    """
    self._amp_lazy_init()
    stash = self._amp_stash

    # Note: this may behave differently from an unpatched optimizer if
    # zero_grad is used and the param is unused.
    for model_param in stash.all_fp16_params:
        model_param.grad = None

    # fp32-from-fp16 master grads are intentionally left alone here.

    for idx, fp32_param in enumerate(stash.all_fp32_from_fp32_params):
        stash.all_fp32_from_fp32_grad_stash[idx] = fp32_param.grad
        fp32_param.grad = None
159
+
160
+
161
def post_backward_with_master_weights(self, scaler):
    # After backward(): route scaled FP16 model grads into FP32 master grads
    # (unscaling on the way), then treat pure-FP32 params like the
    # no-master-weights case.
    stash = self._amp_stash

    self._amp_lazy_init()

    # This is a lot of python overhead...
    fp16_grads_needing_unscale = []
    new_fp32_grads = []
    fp16_grads_needing_unscale_with_stash = []
    preexisting_fp32_grads = []
    for fp16_param, fp32_param in zip(stash.all_fp16_params,
                                      stash.all_fp32_from_fp16_params):
        if fp16_param.grad is None and fp32_param.grad is not None:
            continue
        elif fp16_param.grad is not None and fp32_param.grad is None:
            # Fresh model grad, no master grad yet: unscale into a new buffer.
            fp32_param.grad = torch.empty_like(fp32_param)
            fp16_grads_needing_unscale.append(fp16_param.grad)
            new_fp32_grads.append(fp32_param.grad)
        elif fp16_param.grad is not None and fp32_param.grad is not None:
            # Master grad already exists: unscale and accumulate into it.
            fp16_grads_needing_unscale_with_stash.append(fp16_param.grad)
            preexisting_fp32_grads.append(fp32_param.grad)
        else: # fp16_param.grad is None and fp32_param.grad is None:
            continue

    if len(fp16_grads_needing_unscale) > 0:
        scaler.unscale(
            fp16_grads_needing_unscale,
            new_fp32_grads,
            scaler.loss_scale(),
            models_are_masters=False)

    if len(fp16_grads_needing_unscale_with_stash) > 0:
        scaler.unscale_with_stashed(
            fp16_grads_needing_unscale_with_stash,
            preexisting_fp32_grads,
            preexisting_fp32_grads)

    # fp32 params can be treated as they would be in the "no_master_weights" case.
    post_backward_models_are_masters(
        scaler,
        stash.all_fp32_from_fp32_params,
        stash.all_fp32_from_fp32_grad_stash)
203
+
204
+
205
def lazy_init_no_master_weights(self):
    """Partition this optimizer's params into FP16/FP32 lists (O1-style:
    no master copies are created) and allocate one grad-stash slot per param.
    Raises TypeError for any param that is not a CUDA Half/Float tensor.
    """
    stash = self._amp_stash
    stash.all_fp16_params = []
    stash.all_fp32_params = []
    for group in self.param_groups:
        for param in group['params']:
            param_type = param.type()
            if param_type == 'torch.cuda.HalfTensor':
                stash.all_fp16_params.append(param)
            elif param_type == 'torch.cuda.FloatTensor':
                stash.all_fp32_params.append(param)
            else:
                raise TypeError("Optimizer's parameters must be either "
                                "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
                                "Received {}".format(param.type()))

    # One stash slot per param, used to carry grads across backward passes.
    stash.all_fp16_grad_stash = [None] * len(stash.all_fp16_params)
    stash.all_fp32_grad_stash = [None] * len(stash.all_fp32_params)
222
+
223
+
224
def prepare_backward_no_master_weights(self):
    """Pre-backward hook for the no-master-weights case: stash every param's
    current grad and clear it so backward() can write fresh gradients
    directly (grad copy elision)."""
    self._amp_lazy_init()
    stash = self._amp_stash

    for idx, param in enumerate(stash.all_fp16_params):
        stash.all_fp16_grad_stash[idx] = param.grad
        param.grad = None

    for idx, param in enumerate(stash.all_fp32_params):
        stash.all_fp32_grad_stash[idx] = param.grad
        param.grad = None
238
+
239
+
240
def post_backward_no_master_weights(self, scaler):
    """Post-backward hook for the no-master-weights case: unscale fresh grads
    and fold back any stashed grads.  FP16 and FP32 params are handled the
    same way here, since the model params ARE the master params."""
    self._amp_lazy_init()
    stash = self._amp_stash

    for params, stashed_grads in ((stash.all_fp16_params, stash.all_fp16_grad_stash),
                                  (stash.all_fp32_params, stash.all_fp32_grad_stash)):
        post_backward_models_are_masters(scaler, params, stashed_grads)
250
+
251
+
252
+ #####################################################################################
253
+ # FusedSGD versions
254
+ #####################################################################################
255
+
256
+ # FusedSGD never explicitly materializes the fp32 gradients for "fp32 from fp16" master params
257
+ # outside the kernel, so we must accumulate directly into the model grads.
258
def prepare_backward_with_master_weights_FusedSGD(self):
    """FusedSGD pre-backward hook.

    FusedSGD never explicitly materializes the fp32 gradients for
    "fp32 from fp16" master params outside the kernel, so when master grads
    are not materialized we stash and clear the model-side grads directly.
    """
    if self.materialize_master_grads:
        prepare_backward_with_master_weights(self)
        return

    self._amp_lazy_init()
    stash = self._amp_stash

    for idx, param in enumerate(stash.all_fp16_params):
        stash.all_fp16_grad_stash[idx] = param.grad
        param.grad = None

    for idx, param in enumerate(stash.all_fp32_from_fp32_params):
        stash.all_fp32_from_fp32_grad_stash[idx] = param.grad
        param.grad = None
275
+
276
+
277
def post_backward_with_master_weights_FusedSGD(self, scaler):
    # FusedSGD post-backward hook.  When master grads are not materialized,
    # the model grads themselves carry the loss scale, so scale bookkeeping
    # (most_recent_scale / scale_set_by_backward) lives on the optimizer and
    # spans multiple backward passes.
    if self.materialize_master_grads:
        post_backward_with_master_weights(self, scaler)
    else:
        stash = self._amp_stash

        self._amp_lazy_init()

        grads_have_scale = scaler.loss_scale()
        stashed_have_scale = self.most_recent_scale
        out_scale = grads_have_scale
        if self.scale_set_by_backward:
            # Keep the smaller of the two scales for the accumulated result.
            out_scale = min(grads_have_scale, self.most_recent_scale)

        split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash),
                       (stash.all_fp32_from_fp32_params, stash.all_fp32_from_fp32_grad_stash))


        # unscale_with_stashed() implements grads*1/scale + stashed_grads*1.
        # stashed_grads are scaled by self.most_recent_scale.
        for params, stashed_grads in split_types:
            post_backward_models_are_masters(scaler, params, stashed_grads,
                                             (grads_have_scale, stashed_have_scale, out_scale))

        self.most_recent_scale = out_scale
        self.scale_set_by_backward = True
303
+
304
+
305
def prepare_backward_no_master_weights_FusedSGD(self):
    # Without master weights, FusedSGD needs no special pre-backward handling.
    prepare_backward_no_master_weights(self)
307
+
308
+
309
def post_backward_no_master_weights_FusedSGD(self, scaler):
    # Without master weights, FusedSGD needs no special post-backward handling.
    post_backward_no_master_weights(self, scaler)
311
+
312
+
313
+ def _amp_lazy_init(self):
314
+ stash = self._amp_stash
315
+
316
+ if not stash.lazy_init_called:
317
+ self._lazy_init_maybe_master_weights()
318
+ stash.lazy_init_called = True
319
+
320
+
321
def _process_optimizer(optimizer, properties):
    # Monkey-patches `optimizer` in place for Amp: attaches an _amp_stash,
    # installs master-weight handling when properties.master_weights is set,
    # and wraps step/zero_grad/add_param_group.  Returns the same optimizer
    # object.
    if hasattr(optimizer, "_amp_stash"):
        raise RuntimeError("A given optimizer should only be passed through amp.initialize once.")
    else:
        optimizer._amp_stash = AmpOptimizerState()

    optimizer._amp_stash.lazy_init_called = False
    optimizer._amp_stash.already_patched = False
    optimizer._amp_stash.params_have_scaled_gradients = False

    # Refuse to clobber any attributes Amp is about to install.
    for name in ("_lazy_init_maybe_master_weights",
                 "_master_params_to_model_params",
                 "_prepare_amp_backward",
                 "_post_amp_backward",
                 "_amp_lazy_init"):
        if hasattr(optimizer, name):
            raise RuntimeError("Incoming optimizer already has {} defined.".format(name))

    # TODO: Centralize exposure and import error checking for the C backend.
    if multi_tensor_applier.available:
        import amp_C
        optimizer._amp_stash.multi_tensor_scale = amp_C.multi_tensor_scale
        optimizer._amp_stash.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
        optimizer._amp_stash.dummy_overflow_buf = torch.cuda.IntTensor([0]);

    if properties.master_weights:
        optimizer._lazy_init_maybe_master_weights = types.MethodType(
            lazy_init_with_master_weights, optimizer)

        optimizer._master_params_to_model_params = types.MethodType(
            _master_params_to_model_params, optimizer)

        old_step = optimizer.step
        def new_step(self, closure=None):
            # Run the original step, then push updated FP32 master values back
            # into the FP16 model params (FusedSGD handles that in-kernel).
            if closure is not None:
                raise RuntimeError("Currently, Amp does not support closure use with optimizers.")
            retval = old_step()
            if not isinstance(self, FusedSGD):
                self._master_params_to_model_params()
                # Clear the master grads that wouldn't be zeroed by model.zero_grad()
                for param in self._amp_stash.all_fp32_from_fp16_params:
                    param.grad = None
            return retval
        optimizer.step = types.MethodType(new_step, optimizer)

        # NOTE(review): old_zero_grad is captured but never called — the
        # patched version fully replaces the original behavior.
        old_zero_grad = optimizer.zero_grad
        def new_zero_grad(self):
            stash = self._amp_stash
            self._amp_lazy_init()
            # Zero the model grads.
            for param in stash.all_fp16_params:
                if param.grad is not None:
                    param.grad.detach_()
                    param.grad.zero_()
            for param in stash.all_fp32_from_fp32_params:
                if param.grad is not None:
                    param.grad.detach_()
                    param.grad.zero_()
            # Clear the master grads that are independent of model grads
            for param in self._amp_stash.all_fp32_from_fp16_params:
                param.grad = None
        optimizer.zero_grad = types.MethodType(new_zero_grad, optimizer)

        if isinstance(optimizer, FusedSGD):
            optimizer._prepare_amp_backward = types.MethodType(
                prepare_backward_with_master_weights_FusedSGD, optimizer)
            optimizer._post_amp_backward = types.MethodType(
                post_backward_with_master_weights_FusedSGD, optimizer)
        else:
            optimizer._prepare_amp_backward = types.MethodType(
                prepare_backward_with_master_weights, optimizer)
            optimizer._post_amp_backward = types.MethodType(
                post_backward_with_master_weights, optimizer)
    else:
        optimizer._lazy_init_maybe_master_weights = types.MethodType(
            lazy_init_no_master_weights, optimizer)

        if isinstance(optimizer, FusedSGD):
            optimizer._prepare_amp_backward = types.MethodType(
                prepare_backward_no_master_weights_FusedSGD, optimizer)
            optimizer._post_amp_backward = types.MethodType(
                post_backward_no_master_weights_FusedSGD, optimizer)
        else:
            optimizer._prepare_amp_backward = types.MethodType(
                prepare_backward_no_master_weights, optimizer)
            optimizer._post_amp_backward = types.MethodType(
                post_backward_no_master_weights, optimizer)

    optimizer._amp_lazy_init = types.MethodType(_amp_lazy_init, optimizer)

    old_add_param_group = optimizer.add_param_group

    def new_add_param_group(self, new_group):
        # Mirror lazy_init_*'s bookkeeping for groups added after initialize.
        stash = self._amp_stash

        if not stash.lazy_init_called:
            self._lazy_init_maybe_master_weights()
            stash.lazy_init_called = True

        assert isinstance(new_group, dict), "param group must be a dict"

        new_params = new_group['params']
        if isinstance(new_params, torch.Tensor):
            new_group['params'] = [new_params]
        elif isinstance(new_params, set):
            raise TypeError('optimizer parameters need to be organized in ordered collections, but '
                            'the ordering of tensors in sets will change between runs. Please use a list instead.')
        else:
            new_group['params'] = list(new_params)

        if properties.master_weights:
            # Mutate new_group in-place to use FP32 master params
            fp16_params_this_group = []
            fp32_params_this_group = []
            fp32_from_fp16_params_this_group = []
            for i, param in enumerate(new_group['params']):
                if param.requires_grad:
                    if param.type() == 'torch.cuda.HalfTensor':
                        fp16_params_this_group.append(param)
                        master_param = param.detach().clone().float()
                        master_param.requires_grad = True
                        new_group['params'][i] = master_param
                        fp32_from_fp16_params_this_group.append(master_param)
                    elif param.type() == 'torch.cuda.FloatTensor':
                        fp32_params_this_group.append(param)
                        new_group['params'][i] = param
                    else:
                        raise TypeError("Optimizer's parameters must be either "
                                        "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
                                        "Received {}".format(param.type()))

            stash.fp16_groups.append(fp16_params_this_group)
            stash.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
            stash.fp32_from_fp32_groups.append(fp32_params_this_group)

            stash.all_fp16_params += fp16_params_this_group
            stash.all_fp32_from_fp16_params += fp32_from_fp16_params_this_group
            stash.all_fp32_from_fp32_params += fp32_params_this_group

            # stash.all_fp32_from_fp16_grad_stash = [None for _ in stash.all_fp32_from_fp16_params]
            stash.all_fp32_from_fp32_grad_stash += [None for _ in fp32_params_this_group]

            # It should be ok to let params be added with existing .grad attributes.
            # for param in fp16_params_this_group:
            #     param.grad = None

            # for param in fp32_from_fp16_params_this_group:
            #     param.grad = None

            # for param in stash.fp32_params_this_group:
            #     param.grad = None
        else:
            for param in new_group['params']:
                if param.type() == 'torch.cuda.HalfTensor':
                    stash.all_fp16_params.append(param)
                    stash.all_fp16_grad_stash.append(None)
                elif param.type() == 'torch.cuda.FloatTensor':
                    stash.all_fp32_params.append(param)
                    stash.all_fp32_grad_stash.append(None)
                else:
                    raise TypeError("Optimizer's parameters must be either "
                                    "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
                                    "Received {}".format(param.type()))

        old_add_param_group(new_group)

    optimizer.add_param_group = types.MethodType(new_add_param_group, optimizer)

    return optimizer
apex/amp/amp.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import itertools
3
+
4
+ import torch
5
+
6
+ from . import compat, rnn_compat, utils, wrap
7
+ from .handle import AmpHandle, NoOpHandle
8
+ from .lists import functional_overrides, torch_overrides, tensor_overrides
9
+ from ._amp_state import _amp_state
10
+ from .frontend import *
11
+
12
+
13
# Amp handle used by the decorator forms below; set by init().
_DECORATOR_HANDLE = None
# User registrations consumed (and cleared) by init():
# (module, name, cast_fn) triples and (module, name) pairs respectively.
_USER_CAST_REGISTRY = set()
_USER_PROMOTE_REGISTRY = set()
16
+
17
+
18
def _decorator_helper(orig_fn, cast_fn, wrap_fn):
    """Wrap orig_fn so that casting only applies while an amp handle is active;
    otherwise the original function is called straight through."""
    def wrapper(*args, **kwargs):
        handle = _DECORATOR_HANDLE
        if handle is not None and handle.is_active():
            verbose_cast = utils.verbosify(cast_fn, orig_fn.__name__,
                                           handle.verbose)
            wrapped = wrap_fn(orig_fn, verbose_cast, handle)
            return wrapped(*args, **kwargs)
        return orig_fn(*args, **kwargs)
    return wrapper
27
+
28
+
29
+ # Decorator form
30
def half_function(fn):
    """Decorator: run fn with inputs cast to FP16 (with cast caching) under amp."""
    from apex import deprecated_warning
    deprecated_warning("apex.amp is deprecated and will be removed by the end of February 2023. Use [PyTorch AMP](https://pytorch.org/docs/stable/amp.html)")
    caster = functools.partial(wrap.make_cast_wrapper, try_caching=True)
    return _decorator_helper(fn, utils.maybe_half, caster)
35
+
36
+
37
def float_function(fn):
    """Decorator: run fn with inputs cast to FP32 (no cast caching) under amp."""
    from apex import deprecated_warning
    deprecated_warning("apex.amp is deprecated and will be removed by the end of February 2023. Use [PyTorch AMP](https://pytorch.org/docs/stable/amp.html)")
    caster = functools.partial(wrap.make_cast_wrapper, try_caching=False)
    return _decorator_helper(fn, utils.maybe_float, caster)
42
+
43
+
44
def promote_function(fn):
    """Decorator: promote fn's mixed-precision inputs to a common type under amp."""
    from apex import deprecated_warning
    deprecated_warning("apex.amp is deprecated and will be removed by the end of February 2023. Use [PyTorch AMP](https://pytorch.org/docs/stable/amp.html)")
    promoter = functools.partial(wrap.make_promote_wrapper)
    return _decorator_helper(fn, utils.maybe_float, promoter)
49
+
50
+
51
+ # Registry form
52
def register_half_function(module, name):
    """Register module.name to be force-cast to FP16 when amp init() runs."""
    if not hasattr(module, name):
        raise ValueError('No function named {} in module {}.'.format(
            name, module))
    entry = (module, name, utils.maybe_half)
    _USER_CAST_REGISTRY.add(entry)
57
+
58
+
59
def register_float_function(module, name):
    """Register module.name to be force-cast to FP32 when amp init() runs."""
    if not hasattr(module, name):
        raise ValueError('No function named {} in module {}.'.format(
            name, module))
    entry = (module, name, utils.maybe_float)
    _USER_CAST_REGISTRY.add(entry)
64
+
65
+
66
def register_promote_function(module, name):
    """Register module.name for type promotion when amp init() runs."""
    if not hasattr(module, name):
        raise ValueError('No function named {} in module {}.'.format(
            name, module))
    entry = (module, name)
    _USER_PROMOTE_REGISTRY.add(entry)
71
+
72
+
73
# Top-level function to insert _all_ the hooks.
def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False, allow_banned=False):
    """Monkey-patch torch / Tensor / functional namespaces for mixed precision
    and return the resulting AmpHandle (or a NoOpHandle when disabled)."""
    global _DECORATOR_HANDLE

    if not enabled:
        handle = NoOpHandle()
        _DECORATOR_HANDLE = handle
        return handle

    handle = AmpHandle(loss_scale, enable_caching, verbose)

    # 0) Force-{fp16, fp32} for user-annotated functions
    for mod, fn, cast_fn in _USER_CAST_REGISTRY:
        try_caching = (cast_fn == utils.maybe_half)
        wrap.cached_cast(mod, fn, cast_fn, handle,
                         try_caching, verbose)
    _USER_CAST_REGISTRY.clear()

    # 0.5) Force-promote for user-annotated functions
    for mod, fn in _USER_PROMOTE_REGISTRY:
        wrap.promote(mod, fn, handle, verbose)
    _USER_PROMOTE_REGISTRY.clear()

    # 1) Force-{fp16, fp32} on white- / black-list functions
    override_modules = [functional_overrides,
                        torch_overrides,
                        tensor_overrides]
    cast_table = [('FP16_FUNCS', utils.maybe_half),
                  ('FP32_FUNCS', utils.maybe_float)]
    for module, (list_name, cast_fn) in itertools.product(override_modules,
                                                          cast_table):
        for fn in getattr(module, list_name):
            try_caching = (cast_fn == utils.maybe_half)
            wrap.cached_cast(module.MODULE, fn, cast_fn, handle,
                             try_caching, verbose)

    # 1.5) Pre-0.4, put the blacklist methods on HalfTensor and whitelist
    # methods on FloatTensor, since they're distinct types.
    if compat.tensor_is_float_tensor():
        for fn in tensor_overrides.FP16_FUNCS:
            wrap.cached_cast(torch.cuda.FloatTensor, fn, utils.maybe_half,
                             handle, try_caching=True, verbose=verbose)
        for fn in tensor_overrides.FP32_FUNCS:
            wrap.cached_cast(torch.cuda.HalfTensor, fn, utils.maybe_float,
                             handle, try_caching=False, verbose=verbose)

    # 2) Enable type-promotion on multi-arg functions and methods.
    # NB: special handling for sequence fns (e.g. `torch.cat`).
    promote_modules = [torch_overrides, tensor_overrides]
    promote_table = [('CASTS', wrap.promote),
                     ('SEQUENCE_CASTS', wrap.sequence_promote)]
    for promote_mod, (list_name, promote_fn) in itertools.product(promote_modules,
                                                                  promote_table):
        for fn in getattr(promote_mod, list_name):
            promote_fn(promote_mod.MODULE, fn, handle, verbose)

    # 2.5) Pre-0.4, add blacklist methods directly to HalfTensor and FloatTensor types
    if compat.tensor_is_float_tensor():
        for cls, (list_name, promote_fn) in itertools.product([torch.cuda.FloatTensor,
                                                               torch.cuda.HalfTensor],
                                                              promote_table):
            for fn in getattr(tensor_overrides, list_name):
                promote_fn(cls, fn, handle, verbose)

    # 3) For any in-place version of a blacklist function, error if any input is fp16.
    # NB: this is overly conservative.
    for fn in utils.as_inplace(torch_overrides.FP32_FUNCS):
        wrap.err_if_any_half(torch_overrides.MODULE, fn, handle)

    # 3.5) For any in-place blacklist method, error if called on fp16 tensor
    for fn in utils.as_inplace(tensor_overrides.FP32_FUNCS):
        wrap.err_if_arg0_half(tensor_overrides.MODULE, fn, handle, verbose)
        if compat.tensor_is_float_tensor():
            wrap.err_if_arg0_half(torch.cuda.HalfTensor, fn, handle, verbose)

    # 4) For other in-place methods, match the type of self tensor
    for fn in utils.as_inplace(itertools.chain(
            tensor_overrides.FP16_FUNCS,
            tensor_overrides.CASTS)):
        wrap.promote_match_arg0(tensor_overrides.MODULE, fn, handle, verbose)
        if compat.tensor_is_float_tensor():
            wrap.promote_match_arg0(torch.cuda.HalfTensor, fn, handle, verbose)
            wrap.promote_match_arg0(torch.cuda.FloatTensor, fn, handle, verbose)

    # 5) RNNs + RNN cells are whitelisted specially
    if rnn_compat.has_old_rnns():
        wrap.rnn_cast(torch.nn.backends.thnn.backend, 'RNN', handle, verbose)
    if not rnn_compat.has_old_rnns():
        # Patch in our own indirection of `_VF` in modules/rnn s.t. it is mutable.
        torch.nn.modules.rnn._VF = rnn_compat.VariableFunctionsShim()
        # Wrap all the rnns
        for x in rnn_compat.RNN_NAMES:
            wrap.new_rnn_cast(x.upper(), handle, verbose)

    # Wrap all the RNN cells
    rnn_compat.whitelist_rnn_cells(handle, verbose)

    # 6) Place error+print message on banned functions.
    # Or, if allow_banned, then cast to FP32.
    for fn, err_msg in functional_overrides.BANNED_FUNCS:
        if allow_banned:
            wrap.cached_cast(functional_overrides.MODULE, fn, utils.maybe_float,
                             handle, try_caching=True, verbose=verbose)
        else:
            wrap.err_if_any_half(functional_overrides.MODULE, fn, handle, err_msg)

    _DECORATOR_HANDLE = handle

    _amp_state.handle = handle

    return handle
apex/amp/compat.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ # True for post-0.4, when Variables/Tensors merged.
4
+ def variable_is_tensor():
5
+ v = torch.autograd.Variable()
6
+ return isinstance(v, torch.Tensor)
7
+
8
+ def tensor_is_variable():
9
+ x = torch.Tensor()
10
+ return type(x) == torch.autograd.Variable
11
+
12
+ # False for post-0.4
13
+ def tensor_is_float_tensor():
14
+ x = torch.Tensor()
15
+ return type(x) == torch.FloatTensor
16
+
17
+ # Akin to `torch.is_tensor`, but returns True for Variable
18
+ # objects in pre-0.4.
19
+ def is_tensor_like(x):
20
+ return torch.is_tensor(x) or isinstance(x, torch.autograd.Variable)
21
+
22
+ # Wraps `torch.is_floating_point` if present, otherwise checks
23
+ # the suffix of `x.type()`.
24
+ def is_floating_point(x):
25
+ if hasattr(torch, 'is_floating_point'):
26
+ return torch.is_floating_point(x)
27
+ try:
28
+ torch_type = x.type()
29
+ return torch_type.endswith('FloatTensor') or \
30
+ torch_type.endswith('HalfTensor') or \
31
+ torch_type.endswith('DoubleTensor')
32
+ except AttributeError:
33
+ return False
34
+
35
+ def scalar_python_val(x):
36
+ if hasattr(x, 'item'):
37
+ return x.item()
38
+ else:
39
+ if isinstance(x, torch.autograd.Variable):
40
+ return x.data[0]
41
+ else:
42
+ return x[0]
43
+
44
+ # Accounts for the possibility that some ops may be removed from a namespace.
45
+ def filter_attrs(module, attrs):
46
+ return list(attrname for attrname in attrs if hasattr(module, attrname))
apex/amp/frontend.py ADDED
@@ -0,0 +1,446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+
3
+ import torch
4
+
5
+ from ._initialize import _initialize
6
+ from ._amp_state import _amp_state, warn_or_err, maybe_print
7
+
8
+
9
+ class Properties(object):
10
+ """
11
+ This class has two purposes: to establish a set of default properties,
12
+ and to route setting of these attributes through __setattr__ so that (in theory)
13
+ they can be checked for consistency with other existing args.
14
+ """
15
+ def __init__(self):
16
+ self.options = {
17
+ "enabled" : False,
18
+ "opt_level" : None,
19
+ "cast_model_type" : None,
20
+ "patch_torch_functions" : False,
21
+ "keep_batchnorm_fp32" : None,
22
+ "master_weights" : None,
23
+ "loss_scale" : 1.0,
24
+ # Reserved for future functionality
25
+ # "fused_optimizer" : False,
26
+ # "enable_ddp_interop" : False,
27
+ }
28
+
29
+ """
30
+ This function allows updating several options at a time without routing through
31
+ __setattr__ checks, to avoid "you can't get there from here" scenarios.
32
+ Currently not intended to be exposed; users are expected to select an opt_level
33
+ and apply consistent modifications.
34
+ """
35
+ def _update_options_dict(self, new_options):
36
+ for k, v in new_options:
37
+ if k in self.options:
38
+ self.options[k] = v
39
+ else:
40
+ raise ValueError("Tried to set unexpected option {}".format(k))
41
+ """
42
+ The members of "options" are not direct attributes of self, so access attempts
43
+ will roll down to __getattr__. This borrows from the logic in torch.nn.Module.
44
+ """
45
+ def __getattr__(self, name):
46
+ if "options" in self.__dict__:
47
+ options = self.__dict__["options"]
48
+ if name in options:
49
+ return options[name]
50
+ raise AttributeError("'{}' object has no attribute '{}'".format(
51
+ type(self).__name__, name))
52
+
53
+ def __setattr__(self, name, value):
54
+ if "options" in self.__dict__:
55
+ if name in self.options:
56
+ # print("setting {} {}".format(name, value))
57
+ if name == "cast_model_type":
58
+ if self.opt_level == "O1" and value is not None:
59
+ if value is not False:
60
+ if value is not torch.float32:
61
+ warn_or_err("O1 inserts casts around Torch functions rather than "
62
+ "model weights, so with O1, the model weights themselves "
63
+ "should remain FP32. If you wish to cast the model to a "
64
+ "different type, use opt_level='O2' or 'O3'. " +
65
+ "cast_model_type was {}".format(value))
66
+ self.options[name] = value
67
+ elif name == "patch_torch_functions":
68
+ if self.opt_level != "O1" and value:
69
+ warn_or_err("Currently, patch_torch_functions=True should only be set by "
70
+ "selecting opt_level='O1'.")
71
+ self.options[name] = value
72
+ elif name == "keep_batchnorm_fp32":
73
+ if self.opt_level == "O1" and value is not None:
74
+ warn_or_err("With opt_level O1, batchnorm functions are automatically patched "
75
+ "to run in FP32, so keep_batchnorm_fp32 should be None." +
76
+ " keep_batchnorm_fp32 was {}".format(value))
77
+ if value == "False":
78
+ self.options[name] = False
79
+ elif value == "True":
80
+ self.options[name] = True
81
+ else:
82
+ assert (value is True or value is False or value is None),\
83
+ "keep_batchnorm_fp32 must be a boolean, the string 'True' or 'False', "\
84
+ "or None, found keep_batchnorm_fp32={}".format(value)
85
+ self.options[name] = value
86
+ elif name == "master_weights":
87
+ if self.opt_level == "O1" and value is not None:
88
+ warn_or_err("It doesn't make sense to use master_weights with O1. "
89
+ "With O1, your model weights themselves should be FP32.")
90
+ self.options[name] = value
91
+ elif name == "loss_scale":
92
+ if value == "dynamic":
93
+ self.options[name] = value
94
+ else:
95
+ self.options[name] = float(value)
96
+ else:
97
+ self.options[name] = value
98
+ else:
99
+ super(Properties, self).__setattr__(name, value)
100
+
101
+
102
+ """ O0-O3 are convenience wrappers to establish defaults for typically used mixed precision options. """
103
+
104
+ class O3:
105
+ brief = "O3: Pure FP16 training."
106
+ more = "Calls .half() on your model, converting the entire model to FP16.\n"\
107
+ "A casting operation is also inserted to cast incoming Tensors to FP16,\n"\
108
+ "so you don't need to change your data pipeline.\n"\
109
+ "This mode is useful for establishing a performance ceiling.\n"\
110
+ "It's also possible training may 'just work' in this mode.\n"\
111
+ "If not, try other optimization levels."
112
+
113
+ def __call__(self, properties):
114
+ properties.enabled = True
115
+ properties.opt_level = "O3"
116
+ properties.cast_model_type = torch.float16
117
+ properties.patch_torch_functions = False
118
+ properties.keep_batchnorm_fp32 = False
119
+ properties.master_weights = False
120
+ properties.loss_scale = 1.0
121
+ # properties.fused_optimizer = False
122
+ # properties.enable_ddp_interop = False
123
+ return properties # modified in place so this isn't really necessary
124
+
125
+
126
+ class O2:
127
+ brief = "O2: FP16 training with FP32 batchnorm and FP32 master weights.\n"
128
+ more = "Calls .half() on your model, converting the entire model (except for batchnorms)\n"\
129
+ "to FP16. Batchnorms are retained in FP32 for additional stability.\n"\
130
+ "The forward pass is patched to cast incoming Tensors to FP16, so you don't need to change\n"\
131
+ "your data pipeline.\n"\
132
+ "O2 creates FP32 master weights outside the model and patches any optimizers to update\n"\
133
+ "these master weights, then copy the master weights into the FP16 model weights.\n"\
134
+ "Master weights can also improve convergence and stability."
135
+
136
+ def __call__(self, properties):
137
+ properties.enabled = True
138
+ properties.opt_level = "O2"
139
+ properties.cast_model_type = torch.float16
140
+ properties.patch_torch_functions = False
141
+ properties.keep_batchnorm_fp32 = True
142
+ properties.master_weights = True
143
+ properties.loss_scale = "dynamic"
144
+ # properties.fused_optimizer = False
145
+ # properties.enable_ddp_interop = False
146
+ return properties # modified in place so this isn't really necessary
147
+
148
+
149
+ class O1:
150
+ brief = "O1: Insert automatic casts around Pytorch functions and Tensor methods.\n"
151
+ more = "The type of your model's weights is not altered. However, internally,\n"\
152
+ "Pytorch functions are patched to cast any Tensor Core-friendly ops to FP16 for speed,\n"\
153
+ "while operations that might benefit from the additional stability of FP32 are patched\n"\
154
+ "to cast their inputs to fp32.\n"\
155
+ "O1 is the safest way to try mixed precision training, and is recommended when\n"\
156
+ "trying mixed precision training for the first time."
157
+
158
+ def __call__(self, properties):
159
+ properties.enabled = True
160
+ properties.opt_level = "O1"
161
+ properties.cast_model_type = None
162
+ properties.patch_torch_functions = True
163
+ properties.keep_batchnorm_fp32 = None
164
+ properties.master_weights = None
165
+ properties.loss_scale = "dynamic"
166
+ # properties.fused_optimizer = False
167
+ # properties.enable_ddp_interop = False
168
+ return properties # modified in place so this isn't really necessary
169
+
170
+
171
+ class O0:
172
+ brief = "O0: Pure FP32 training.\n"
173
+ more = "Your models are checked to make sure parameters are FP32, but otherwise the\n"\
174
+ "types of weights and internal Pytorch operations are not altered. This mode disables any\n"\
175
+ "FP16 arithmetic, although other optimizations like DDP interop may still be requested.\n"
176
+
177
+ def __call__(self, properties):
178
+ properties.enabled = True
179
+ properties.opt_level = "O0"
180
+ properties.cast_model_type = torch.float32
181
+ properties.patch_torch_functions = False
182
+ properties.keep_batchnorm_fp32 = None
183
+ properties.master_weights = False
184
+ properties.loss_scale = 1.0
185
+ # properties.fused_optimizer = False
186
+ # properties.enable_ddp_interop = False
187
+ return properties # modified in place so this isn't really necessary
188
+
189
+
190
+ opt_levels = {"O3": O3(),
191
+ "O2": O2(),
192
+ "O1": O1(),
193
+ "O0": O0()}
194
+
195
+
196
+ # allow user to directly pass Properties struct as well?
197
+ def initialize(
198
+ models,
199
+ optimizers=None,
200
+ enabled=True,
201
+ opt_level="O1",
202
+ cast_model_type=None,
203
+ patch_torch_functions=None,
204
+ keep_batchnorm_fp32=None,
205
+ master_weights=None,
206
+ loss_scale=None,
207
+ cast_model_outputs=None,
208
+ num_losses=1,
209
+ verbosity=1,
210
+ min_loss_scale=None,
211
+ max_loss_scale=2.**24
212
+ ):
213
+ """
214
+ Initialize your models, optimizers, and the Torch tensor and functional namespace according to the
215
+ chosen ``opt_level`` and overridden properties, if any.
216
+
217
+ ``amp.initialize`` should be called **after** you have finished
218
+ constructing your model(s) and
219
+ optimizer(s), but **before** you send your model through any DistributedDataParallel wrapper.
220
+ See `Distributed training`_ in the Imagenet example.
221
+
222
+ Currently, ``amp.initialize`` should only be called **once**,
223
+ although it can process an arbitrary number of
224
+ models and optimizers (see the corresponding `Advanced Amp Usage topic`_).
225
+ If you think your use case requires ``amp.initialize`` to be called more than once,
226
+ `let us know`_.
227
+
228
+ Any property keyword argument that is not ``None`` will be interpreted as a manual override.
229
+
230
+ To prevent having to rewrite anything else in your script, name the returned models/optimizers
231
+ to replace the passed models/optimizers, as in the code sample below.
232
+
233
+ Args:
234
+ models (torch.nn.Module or list of torch.nn.Modules): Models to modify/cast.
235
+ optimizers (optional, torch.optim.Optimizer or list of torch.optim.Optimizers): Optimizers to modify/cast.
236
+ REQUIRED for training, optional for inference.
237
+ enabled (bool, optional, default=True): If False, renders all Amp calls no-ops, so your script
238
+ should run as if Amp were not present.
239
+ opt_level (str, optional, default="O1"): Pure or mixed precision optimization level. Accepted values are
240
+ "O0", "O1", "O2", and "O3", explained in detail above.
241
+ cast_model_type (``torch.dtype``, optional, default=None): Optional property override, see
242
+ above.
243
+ patch_torch_functions (bool, optional, default=None): Optional property override.
244
+ keep_batchnorm_fp32 (bool or str, optional, default=None): Optional property override. If
245
+ passed as a string, must be the string "True" or "False".
246
+ master_weights (bool, optional, default=None): Optional property override.
247
+ loss_scale (float or str, optional, default=None): Optional property override. If passed as a string,
248
+ must be a string representing a number, e.g., "128.0", or the string "dynamic".
249
+ cast_model_outputs (torch.dtype, optional, default=None): Option to ensure that the outputs
250
+ of your model(s) are always cast to a particular type regardless of ``opt_level``.
251
+ num_losses (int, optional, default=1): Option to tell Amp in advance how many losses/backward
252
+ passes you plan to use. When used in conjunction with the ``loss_id`` argument to
253
+ ``amp.scale_loss``, enables Amp to use a different loss scale per loss/backward pass,
254
+ which can improve stability. See "Multiple models/optimizers/losses"
255
+ under `Advanced Amp Usage`_ for examples. If ``num_losses`` is left to 1, Amp will still
256
+ support multiple losses/backward passes, but use a single global loss scale
257
+ for all of them.
258
+ verbosity (int, default=1): Set to 0 to suppress Amp-related output.
259
+ min_loss_scale (float, default=None): Sets a floor for the loss scale values that can be chosen by dynamic
260
+ loss scaling. The default value of None means that no floor is imposed.
261
+ If dynamic loss scaling is not used, `min_loss_scale` is ignored.
262
+ max_loss_scale (float, default=2.**24): Sets a ceiling for the loss scale values that can be chosen by
263
+ dynamic loss scaling. If dynamic loss scaling is not used, `max_loss_scale` is ignored.
264
+
265
+ Returns:
266
+ Model(s) and optimizer(s) modified according to the ``opt_level``.
267
+ If either the ``models`` or ``optimizers`` args were lists, the corresponding return value will
268
+ also be a list.
269
+
270
+ Permissible invocations::
271
+
272
+ model, optim = amp.initialize(model, optim,...)
273
+ model, [optim1, optim2] = amp.initialize(model, [optim1, optim2],...)
274
+ [model1, model2], optim = amp.initialize([model1, model2], optim,...)
275
+ [model1, model2], [optim1, optim2] = amp.initialize([model1, model2], [optim1, optim2],...)
276
+
277
+ # This is not an exhaustive list of the cross product of options that are possible,
278
+ # just a set of examples.
279
+ model, optim = amp.initialize(model, optim, opt_level="O0")
280
+ model, optim = amp.initialize(model, optim, opt_level="O0", loss_scale="dynamic"|128.0|"128.0")
281
+
282
+ model, optim = amp.initialize(model, optim, opt_level="O1") # uses "loss_scale="dynamic" default
283
+ model, optim = amp.initialize(model, optim, opt_level="O1", loss_scale=128.0|"128.0")
284
+
285
+ model, optim = amp.initialize(model, optim, opt_level="O2") # uses "loss_scale="dynamic" default
286
+ model, optim = amp.initialize(model, optim, opt_level="O2", loss_scale=128.0|"128.0")
287
+ model, optim = amp.initialize(model, optim, opt_level="O2", keep_batchnorm_fp32=True|False|"True"|"False")
288
+
289
+ model, optim = amp.initialize(model, optim, opt_level="O3") # uses loss_scale=1.0 default
290
+ model, optim = amp.initialize(model, optim, opt_level="O3", loss_scale="dynamic"|128.0|"128.0")
291
+ model, optim = amp.initialize(model, optim, opt_level="O3", keep_batchnorm_fp32=True|False|"True"|"False")
292
+
293
+ The `Imagenet example`_ demonstrates live use of various opt_levels and overrides.
294
+
295
+ .. _`Distributed training`:
296
+ https://github.com/NVIDIA/apex/tree/master/examples/imagenet#distributed-training
297
+
298
+ .. _`Imagenet example`:
299
+ https://github.com/NVIDIA/apex/tree/master/examples/imagenet
300
+
301
+ .. _`Advanced Amp Usage`:
302
+ https://nvidia.github.io/apex/advanced.html
303
+
304
+ .. _`Advanced Amp Usage topic`:
305
+ https://nvidia.github.io/apex/advanced.html#multiple-models-optimizers-losses
306
+
307
+ .. _`let us know`:
308
+ https://github.com/NVIDIA/apex/issues
309
+ """
310
+ from apex import deprecated_warning
311
+ deprecated_warning("apex.amp is deprecated and will be removed by the end of February 2023. Use [PyTorch AMP](https://pytorch.org/docs/stable/amp.html)")
312
+ _amp_state.opt_properties = Properties()
313
+ _amp_state.verbosity = verbosity
314
+
315
+ if not enabled:
316
+ if optimizers is None:
317
+ return models
318
+ else:
319
+ return models, optimizers
320
+
321
+ if not torch.backends.cudnn.enabled:
322
+ raise RuntimeError(
323
+ "Amp requires torch.backends.cudnn.enabled = True")
324
+
325
+ if opt_level not in opt_levels:
326
+ raise RuntimeError(
327
+ "Unexpected optimization level {}. ".format(opt_level) +
328
+ "Options are 'O0', 'O1', 'O2', 'O3'. Note that in `O0`, `O1`, etc., the prefix O is the letter O, " +
329
+ "not the number zero.")
330
+ else:
331
+ _amp_state.opt_properties = opt_levels[opt_level](_amp_state.opt_properties)
332
+ maybe_print("Selected optimization level {}".format(opt_levels[opt_level].brief), True)
333
+ maybe_print("Defaults for this optimization level are:", True)
334
+ for k, v in _amp_state.opt_properties.options.items():
335
+ maybe_print("{:22} : {}".format(k, v), True)
336
+
337
+ _amp_state.min_loss_scale = min_loss_scale
338
+ _amp_state.max_loss_scale = max_loss_scale
339
+
340
+ maybe_print("Processing user overrides (additional kwargs that are not None)...", True)
341
+ # I chose to have the keyword arguments listed directly in the argument list,
342
+ # instead of **kwargs, so I can't use kwargs.items() here.
343
+ if enabled is not None:
344
+ _amp_state.opt_properties.enabled = enabled
345
+ if opt_level is not None:
346
+ _amp_state.opt_properties.opt_level = opt_level
347
+ if cast_model_type is not None:
348
+ _amp_state.opt_properties.cast_model_type = cast_model_type
349
+ if patch_torch_functions is not None:
350
+ _amp_state.opt_properties.patch_torch_functions = patch_torch_functions
351
+ if keep_batchnorm_fp32 is not None:
352
+ _amp_state.opt_properties.keep_batchnorm_fp32 = keep_batchnorm_fp32
353
+ if master_weights is not None:
354
+ _amp_state.opt_properties.master_weights = master_weights
355
+ if loss_scale is not None:
356
+ _amp_state.opt_properties.loss_scale = loss_scale
357
+
358
+ maybe_print("After processing overrides, optimization options are:", True)
359
+ for k, v in _amp_state.opt_properties.options.items():
360
+ maybe_print("{:22} : {}".format(k, v), True)
361
+
362
+ return _initialize(models, optimizers, _amp_state.opt_properties, num_losses, cast_model_outputs)
363
+
364
+
365
+ def state_dict(destination=None):
366
+ if destination is None:
367
+ destination = OrderedDict()
368
+
369
+ for idx, loss_scaler in enumerate(_amp_state.loss_scalers):
370
+ destination['loss_scaler%d' % idx] = {
371
+ 'loss_scale': loss_scaler.loss_scale(),
372
+ 'unskipped': loss_scaler._unskipped,
373
+ }
374
+ return destination
375
+
376
+
377
+ def load_state_dict(state_dict):
378
+ # Check if state_dict containes the same number of loss_scalers as current setup
379
+ if len(state_dict) != len(_amp_state.loss_scalers):
380
+ print('Warning: state_dict contains {} entries, while {} loss_scalers are used'.format(
381
+ len(state_dict), len(_amp_state.loss_scalers)))
382
+
383
+ state_dict = state_dict.copy()
384
+
385
+ nb_loss_scalers = len(_amp_state.loss_scalers)
386
+ unexpected_keys = []
387
+ # Initialize idx outside, since unexpected_keys will increase it if enumerate is used
388
+ idx = 0
389
+ for key in state_dict:
390
+ if 'loss_scaler' not in key:
391
+ unexpected_keys.append(key)
392
+ else:
393
+ if idx > (nb_loss_scalers - 1):
394
+ print('Skipping loss_scaler[{}], since num_losses was set to {}'.format(
395
+ idx, nb_loss_scalers))
396
+ break
397
+ _amp_state.loss_scalers[idx]._loss_scale = state_dict[key]['loss_scale']
398
+ _amp_state.loss_scalers[idx]._unskipped = state_dict[key]['unskipped']
399
+ idx += 1
400
+
401
+ if len(unexpected_keys) > 0:
402
+ raise RuntimeError(
403
+ 'Error(s) in loading state_dict. Unexpected key(s) in state_dict: {}. '.format(
404
+ ', '.join('"{}"'.format(k) for k in unexpected_keys)))
405
+
406
+
407
+ # TODO: is this necessary/useful?
408
+ # def check_option_consistency(enabled=True,
409
+ # opt_level=None,
410
+ # cast_model_type=None,
411
+ # patch_torch_functions=None,
412
+ # keep_batchnorm_fp32=None,
413
+ # master_weights=None,
414
+ # loss_scale=None,
415
+ # enable_ddp_interop=None,
416
+ # hard_override=False):
417
+ # """
418
+ # Utility function that enables users to quickly check if the option combination they intend
419
+ # to use is permitted. ``check_option_consistency`` does not require models or optimizers
420
+ # to be constructed, and can be called at any point in the script. ``check_option_consistency``
421
+ # is totally self-contained; it does not set any amp global state or affect anything outside
422
+ # of itself.
423
+ # """
424
+ #
425
+ # if not enabled:
426
+ # return
427
+ #
428
+ # if opt_level not in opt_levels:
429
+ # raise RuntimeError("Unexpected optimization level. Options are 'O0', 'O1', 'O2', 'O3'.")
430
+ # else:
431
+ # opt_properties = opt_levels[opt_level](Properties())
432
+ # print("Selected optimization level {}", opt_levels[opt_level].brief)
433
+ # print("Defaults for this optimization level are:")
434
+ # for k, v in opt_properties.options:
435
+ # print("{:22} : {}".format(k, v))
436
+ #
437
+ # print("Processing user overrides (additional kwargs that are not None)...")
438
+ # for k, v in kwargs:
439
+ # if k not in _amp_state.opt_properties.options:
440
+ # raise RuntimeError("Unexpected kwarg {}".format(k))
441
+ # if v is not None:
442
+ # setattr(opt_properties, k, v)
443
+ #
444
+ # print("After processing overrides, optimization options are:")
445
+ # for k, v in opt_properties.options:
446
+ # print("{:22} : {}".format(k, v))
apex/amp/handle.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import warnings
3
+ import sys
4
+ import torch
5
+
6
+ from . import utils
7
+ from .opt import OptimWrapper
8
+ from .scaler import LossScaler
9
+ from ._amp_state import _amp_state, master_params, maybe_print
10
+
11
+ if torch.distributed.is_available():
12
+ from ..parallel.LARC import LARC
13
+
14
+
15
+ # There's no reason to expose the notion of a "handle". Everything can happen through amp.* calls.
16
+ @contextlib.contextmanager
17
+ def scale_loss(loss,
18
+ optimizers,
19
+ loss_id=0,
20
+ model=None,
21
+ delay_unscale=False,
22
+ delay_overflow_check=False):
23
+ """
24
+ On context manager entrance, creates ``scaled_loss = (loss.float())*current loss scale``.
25
+ ``scaled_loss`` is yielded so that the user can call ``scaled_loss.backward()``::
26
+
27
+ with amp.scale_loss(loss, optimizer) as scaled_loss:
28
+ scaled_loss.backward()
29
+
30
+ On context manager exit (if ``delay_unscale=False``), the gradients are checked for infs/NaNs
31
+ and unscaled, so that ``optimizer.step()`` can be called.
32
+
33
+ .. note::
34
+ If Amp is using explicit FP32 master params (which is the default for ``opt_level=O2``, and
35
+ can also be manually enabled by supplying ``master_weights=True`` to ``amp.initialize``)
36
+ any FP16 gradients are copied to FP32 master gradients before being unscaled.
37
+ ``optimizer.step()`` will then apply the unscaled master gradients to the master params.
38
+
39
+ .. warning::
40
+ If Amp is using explicit FP32 master params, only the FP32 master gradients will be
41
+ unscaled. The direct ``.grad`` attributes of any FP16
42
+ model params will remain scaled after context manager exit.
43
+ This subtlety affects gradient clipping. See "Gradient clipping" under
44
+ `Advanced Amp Usage`_ for best practices.
45
+
46
+ Args:
47
+ loss(Tensor): Typically a scalar Tensor. The ``scaled_loss`` that the context
48
+ manager yields is simply ``loss.float()*loss_scale``, so in principle
49
+ ``loss`` could have more than one element, as long as you call
50
+ ``backward()`` on ``scaled_loss`` appropriately within the context manager body.
51
+ optimizers: All optimizer(s) for which the current backward pass is creating gradients.
52
+ Must be an optimizer or list of optimizers returned from an earlier call
53
+ to ``amp.initialize``. For example use with multiple optimizers, see
54
+ "Multiple models/optimizers/losses" under `Advanced Amp Usage`_.
55
+ loss_id(int, optional, default=0): When used in conjunction with the ``num_losses`` argument
56
+ to ``amp.initialize``, enables Amp to use a different loss scale per loss. ``loss_id``
57
+ must be an integer between 0 and ``num_losses`` that tells Amp which loss is
58
+ being used for the current backward pass. See "Multiple models/optimizers/losses"
59
+ under `Advanced Amp Usage`_ for examples. If ``loss_id`` is left unspecified, Amp
60
+ will use the default global loss scaler for this backward pass.
61
+ model(torch.nn.Module, optional, default=None): Currently unused, reserved to enable future
62
+ optimizations.
63
+ delay_unscale(bool, optional, default=False): ``delay_unscale`` is never necessary, and
64
+ the default value of ``False`` is strongly recommended.
65
+ If ``True``, Amp will not unscale the gradients or perform model->master
66
+ gradient copies on context manager exit.
67
+ ``delay_unscale=True`` is a minor ninja performance optimization and can result
68
+ in weird gotchas (especially with multiple models/optimizers/losses),
69
+ so only use it if you know what you're doing.
70
+ "Gradient accumulation across iterations" under `Advanced Amp Usage`_
71
+ illustrates a situation where this CAN (but does not need to) be used.
72
+
73
+ .. warning::
74
+ If ``delay_unscale`` is ``True`` for a given backward pass, ``optimizer.step()`` cannot be
75
+ called yet after context manager exit, and must wait for another, later backward context
76
+ manager invocation with ``delay_unscale`` left to False.
77
+
78
+ .. _`Advanced Amp Usage`:
79
+ https://nvidia.github.io/apex/advanced.html
80
+ """
81
+ if not hasattr(_amp_state, "opt_properties"):
82
+ raise RuntimeError("Invoked 'with amp.scale_loss`, but internal Amp state has not been initialized. "
83
+ "model, optimizer = amp.initialize(model, optimizer, opt_level=...) must be called "
84
+ "before `with amp.scale_loss`.")
85
+
86
+ if not _amp_state.opt_properties.enabled:
87
+ yield loss
88
+ return
89
+
90
+ if isinstance(optimizers, torch.optim.Optimizer) or ('LARC' in globals() and isinstance(optimizers, LARC)):
91
+ optimizers = [optimizers]
92
+
93
+ loss_scaler = _amp_state.loss_scalers[loss_id]
94
+ loss_scale = loss_scaler.loss_scale()
95
+
96
+ if ((not _amp_state.opt_properties.master_weights)
97
+ and (not loss_scaler.dynamic)
98
+ and loss_scale == 1.0):
99
+ yield loss.float()
100
+ # Needing to drop the cache here as well is an ugly gotcha.
101
+ # But for now I think it's necessary to short-circuit.
102
+ # Probably ok to skip this if not delay_unscale
103
+ if _amp_state.opt_properties.patch_torch_functions:
104
+ _amp_state.handle._clear_cache()
105
+ return
106
+
107
+ if not delay_unscale:
108
+ if isinstance(optimizers, list):
109
+ for optimizer in optimizers:
110
+ if not optimizer._amp_stash.params_have_scaled_gradients:
111
+ optimizer._prepare_amp_backward()
112
+
113
+ yield (loss.float())*loss_scale
114
+
115
+ if delay_unscale:
116
+ for optimizer in optimizers:
117
+ optimizer._amp_stash.params_have_scaled_gradients = True
118
+ else:
119
+ # FusedSGD may take care of unscaling as part of their step() methods.
120
+ # if not isinstance(optimizers, FP16_Optimizer_for_fused):
121
+ loss_scaler.clear_overflow_state()
122
+ for optimizer in optimizers:
123
+ optimizer._post_amp_backward(loss_scaler)
124
+ optimizer._amp_stash.params_have_scaled_gradients = False
125
+ # For future fused optimizers that enable sync-free dynamic loss scaling,
126
+ # should_skip will always be False.
127
+ should_skip = False if delay_overflow_check else loss_scaler.update_scale()
128
+ if should_skip:
129
+ for optimizer in optimizers:
130
+ if not optimizer._amp_stash.already_patched:
131
+ # Close on loss_scaler and loss_id as well, to be safe. Probably not
132
+ # necessary because amp.scale_loss is already creating a temporary scope.
133
+ def patch_step(opt, loss_scaler, loss_id):
134
+ opt_step = opt.step
135
+ def skip_step(closure=None):
136
+ if closure is not None:
137
+ raise RuntimeError("Currently, Amp does not support closure use with optimizers.")
138
+ maybe_print(("Gradient overflow. Skipping step, loss scaler " +
139
+ "{} reducing loss scale to {}").format(loss_id,
140
+ loss_scaler.loss_scale()))
141
+ # TODO: I don't like the special casing for different optimizer implementations.
142
+ # Maybe skip should delegate to a method owned by the optimizers themselves.
143
+ if hasattr(opt._amp_stash, "all_fp32_from_fp16_params"):
144
+ # Clear the master grads that wouldn't be zeroed by model.zero_grad()
145
+ for param in opt._amp_stash.all_fp32_from_fp16_params:
146
+ param.grad = None
147
+ if hasattr(opt, "most_recent_scale"):
148
+ opt.most_recent_scale = 1.0
149
+ opt.scale_set_by_backward = False
150
+ opt.step = opt_step
151
+ opt._amp_stash.already_patched = False
152
+ return skip_step
153
+ optimizer.step = patch_step(optimizer, loss_scaler, loss_id)
154
+ optimizer._amp_stash.already_patched = True
155
+
156
+ # Probably ok to skip this if not delay_unscale
157
+ if _amp_state.opt_properties.patch_torch_functions:
158
+ _amp_state.handle._clear_cache()
159
+
160
+
161
+ # Free function version of AmpHandle.disable_casts, another step on the
162
+ # path to removing the concept of "AmpHandle"
163
+ @contextlib.contextmanager
164
+ def disable_casts():
165
+ _amp_state.handle._is_active = False
166
+ yield
167
+ _amp_state.handle._is_active = True
168
+
169
+
170
+ class AmpHandle(object):
171
+ def __init__(self, loss_scale="dynamic", enable_caching=True, verbose=False):
172
+ self._enable_caching = enable_caching
173
+ self._verbose = verbose
174
+ self._cache = dict()
175
+ self._default_scaler = LossScaler(loss_scale)
176
+ self._is_active = True
177
+ self._all_wrappers = []
178
+
179
+ def is_active(self):
180
+ return self._is_active
181
+
182
+ @contextlib.contextmanager
183
+ def _disable_casts(self):
184
+ self._is_active = False
185
+ yield
186
+ self._is_active = True
187
+
188
+ def wrap_optimizer(self, optimizer, num_loss=1):
189
+ self._default_scaler = None
190
+ return OptimWrapper(optimizer, self, num_loss)
191
+
192
+ @contextlib.contextmanager
193
+ def scale_loss(self, loss, optimizer):
194
+ raise RuntimeError("The old Amp API is no longer supported. Please move to the new API, "
195
+ "documented here: https://nvidia.github.io/apex/amp.html. Transition guide: "
196
+ "https://nvidia.github.io/apex/amp.html#transition-guide-for-old-api-users")
197
+
198
+ if not self.is_active():
199
+ yield loss
200
+ return
201
+
202
+ if self._default_scaler is None:
203
+ raise RuntimeError(
204
+ 'After calling `handle.wrap_optimizer()`, you must explicitly ' +
205
+ 'use `optimizer.scale_loss(loss)`.')
206
+
207
+ # TODO: this code block is duplicated here and `opt.py`. Unify.
208
+ loss_scale = self._default_scaler.loss_scale()
209
+ yield loss * loss_scale
210
+
211
+ self._default_scaler.clear_overflow_state()
212
+ self._default_scaler.unscale(
213
+ master_params(optimizer),
214
+ master_params(optimizer),
215
+ loss_scale)
216
+ should_skip = self._default_scaler.update_scale()
217
+ if should_skip:
218
+ optimizer_step = optimizer.step
219
+ def skip_step():
220
+ maybe_print('Gradient overflow, skipping update')
221
+ optimizer.step = optimizer_step
222
+ optimizer.step = skip_step
223
+
224
+ self._clear_cache()
225
+
226
+ def _clear_cache(self):
227
+ self._cache.clear()
228
+
229
+ # Experimental support for saving / restoring uncasted versions of functions
230
+ def _save_func(self, mod, fn, func):
231
+ self._all_wrappers.append((mod, fn, func))
232
+
233
+ def _deactivate(self):
234
+ for mod, fn, func in self._all_wrappers:
235
+ utils.set_func(mod, fn, func)
236
+ self._all_wrappers = []
237
+
238
+ @property
239
+ def has_cache(self):
240
+ return self._enable_caching
241
+
242
+ @property
243
+ def cache(self):
244
+ return self._cache
245
+
246
+ def remove_cache(self, param):
247
+ if self.has_cache and param in self.cache:
248
+ del self.cache[param]
249
+
250
+ @property
251
+ def verbose(self):
252
+ return self._verbose
253
+
254
+ class NoOpHandle(object):
255
+ def is_active(self):
256
+ return False
257
+
258
+ @contextlib.contextmanager
259
+ def _disable_casts(self):
260
+ yield
261
+
262
+ def wrap_optimizer(self, optimizer, num_loss=1):
263
+ return OptimWrapper(optimizer, self, num_loss)
264
+
265
+ @contextlib.contextmanager
266
+ def scale_loss(self, loss, optimizer):
267
+ yield loss
268
+
269
+ @property
270
+ def has_cache(self):
271
+ return False
272
+
273
+ @property
274
+ def verbose(self):
275
+ return False
276
+
277
+ def _clear_cache(self):
278
+ pass
279
+
280
+ def _deactivate(self):
281
+ pass
apex/amp/lists/__init__.py ADDED
File without changes
apex/amp/lists/functional_overrides.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # TODO: think about the following two. They do weird things.
3
+ # - torch.nn.utils.clip_grad (but it should always be fp32 anyway)
4
+ # - torch.nn.utils.weight_norm
5
+
6
+ # Notes:
7
+ # F.instance_norm uses batch_norm internally. Which correctly handles
8
+ # fp16 in/out with fp32 weights. So we shouldn't do anything for
9
+ # either of these.
10
+ # F.normalize calls `input.norm()` internally, so it's redundant, but
11
+ # kept here in case impl. changes.
12
+ # F.cosine_similarity is same: calls `x.norm()` internally.
13
+
14
+ import torch.nn.functional
15
+
16
+ MODULE = torch.nn.functional
17
+
18
+ FP16_FUNCS = [
19
+ 'conv1d',
20
+ 'conv2d',
21
+ 'conv3d',
22
+ 'conv_transpose1d',
23
+ 'conv_transpose2d',
24
+ 'conv_transpose3d',
25
+ 'conv_tbc', # Undocumented / maybe new?
26
+ 'linear',
27
+ ]
28
+
29
+ FP32_FUNCS = [
30
+
31
+ # Interpolation/Upsampling TODO: Remove for 1.2
32
+ 'interpolate',
33
+ 'grid_sample',
34
+
35
+ # Pointwise
36
+ 'softplus',
37
+ 'softmin',
38
+ 'log_softmax',
39
+ 'softmax',
40
+ 'gelu',
41
+
42
+ # Normalization
43
+ 'layer_norm',
44
+ 'group_norm',
45
+ 'local_response_norm',
46
+ 'normalize',
47
+ 'cosine_similarity',
48
+
49
+ # Loss functions
50
+ # TODO: which of these can be fp16?
51
+ 'poisson_nll_loss',
52
+ 'cosine_embedding_loss',
53
+ 'cross_entropy',
54
+ 'hinge_embedding_loss',
55
+ 'kl_div',
56
+ 'l1_loss',
57
+ 'mse_loss',
58
+ 'margin_ranking_loss',
59
+ 'multilabel_margin_loss',
60
+ 'multilabel_soft_margin_loss',
61
+ 'multi_margin_loss',
62
+ 'nll_loss',
63
+ 'binary_cross_entropy_with_logits',
64
+ 'smooth_l1_loss',
65
+ 'soft_margin_loss',
66
+ 'triplet_margin_loss',
67
+ 'ctc_loss'
68
+ ]
69
+
70
+ BANNED_FUNCS = [
71
+ ('binary_cross_entropy',
72
+ ("\namp does not work out-of-the-box with `F.binary_cross_entropy` or `torch.nn.BCELoss.` "
73
+ "It requires that the output of the previous function be already a FloatTensor. \n\n"
74
+ "Most models have a Sigmoid right before BCELoss. In that case, you can use\n"
75
+ " torch.nn.BCEWithLogitsLoss\nto combine Sigmoid+BCELoss into a single layer "
76
+ "that is compatible with amp.\nAnother option is to add\n"
77
+ " amp.register_float_function(torch, 'sigmoid')\nbefore calling `amp.init()`.\n"
78
+ "If you _really_ know what you are doing, you can disable this warning by passing "
79
+ "allow_banned=True to `amp.init()`."))
80
+ ]
apex/amp/lists/tensor_overrides.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .. import compat
2
+ from . import torch_overrides
3
+
4
+ import importlib
5
+
6
+ import torch
7
+
8
+ # if compat.variable_is_tensor() and not compat.tensor_is_variable():
9
+ MODULE = torch.Tensor
10
+ # else:
11
+ # MODULE = torch.autograd.Variable
12
+
13
+
14
+ FP16_FUNCS = compat.filter_attrs(MODULE, [
15
+ '__matmul__',
16
+ ])
17
+
18
+ FP32_FUNCS = compat.filter_attrs(MODULE, [
19
+ '__ipow__',
20
+ '__pow__',
21
+ '__rpow__',
22
+
23
+ # Cast to fp32 before transfer to CPU
24
+ 'cpu',
25
+ ])
26
+
27
+ CASTS = compat.filter_attrs(MODULE, [
28
+ '__add__',
29
+ '__div__',
30
+ '__eq__',
31
+ '__ge__',
32
+ '__gt__',
33
+ '__iadd__',
34
+ '__idiv__',
35
+ '__imul__',
36
+ '__isub__',
37
+ '__itruediv__',
38
+ '__le__',
39
+ '__lt__',
40
+ '__mul__',
41
+ '__ne__',
42
+ '__radd__',
43
+ '__rdiv__',
44
+ '__rmul__',
45
+ '__rsub__',
46
+ '__rtruediv__',
47
+ '__sub__',
48
+ '__truediv__',
49
+ ])
50
+
51
+ # None of these, but here to make code cleaner.
52
+ SEQUENCE_CASTS = []
53
+
54
+ # We need to grab all the methods from torch_overrides and add them to
55
+ # the Tensor lists as well, as almost all methods are duplicated
56
+ # between `torch` and `torch.Tensor` (and check with `hasattr`,
57
+ # because a few random ones aren't defined on Tensor)
58
+ _self_mod = importlib.import_module(__name__)
59
+ for attrname in ['FP16_FUNCS', 'FP32_FUNCS', 'CASTS', 'SEQUENCE_CASTS']:
60
+ lst = getattr(_self_mod, attrname)
61
+ for fn in getattr(torch_overrides, attrname):
62
+ if hasattr(MODULE, fn):
63
+ lst.append(fn)
apex/amp/lists/torch_overrides.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from .. import utils
4
+
5
+ MODULE = torch
6
+
7
+ FP16_FUNCS = [
8
+ # Low level functions wrapped by torch.nn layers.
9
+ # The wrapper layers contain the weights which are then passed in as a parameter
10
+ # to these functions.
11
+ 'conv1d',
12
+ 'conv2d',
13
+ 'conv3d',
14
+ 'conv_transpose1d',
15
+ 'conv_transpose2d',
16
+ 'conv_transpose3d',
17
+ 'conv_tbc',
18
+ 'prelu',
19
+
20
+ # BLAS
21
+ 'addmm',
22
+ 'addmv',
23
+ 'addr',
24
+ 'matmul',
25
+ 'mm',
26
+ 'mv',
27
+ ]
28
+
29
+ FP32_FUNCS = [
30
+ # Pointwise
31
+ 'acos',
32
+ 'asin',
33
+ 'cosh',
34
+ 'erfinv',
35
+ 'exp',
36
+ 'expm1',
37
+ 'log',
38
+ 'log10',
39
+ 'log2',
40
+ 'reciprocal',
41
+ 'rsqrt',
42
+ 'sinh',
43
+ 'tan',
44
+
45
+ # Other math
46
+ 'pow',
47
+
48
+ # Reduction
49
+ 'cumprod',
50
+ 'cumsum',
51
+ 'dist',
52
+ # 'mean',
53
+ 'norm',
54
+ 'prod',
55
+ 'std',
56
+ 'sum',
57
+ 'var',
58
+
59
+ # Misc
60
+ 'renorm'
61
+ ]
62
+
63
+ version_strings = torch.__version__.split('.')
64
+ version_major = version_strings[0]
65
+ version_minor = version_strings[1]
66
+ version_num = float(version_major + "." + version_minor)
67
+ # Before torch 1.1, mean must be blacklisted.
68
+ if version_num < 1.1:
69
+ FP32_FUNCS.append('mean')
70
+
71
+ # Before CUDA 9.1, batched matmul was missing fast FP16 kernels. We
72
+ # check the CUDA version -- if at least 9.1, then put the bmm
73
+ # functions on the fp16 list. Otherwise, put them on the fp32 list.
74
+ _bmms = ['addbmm',
75
+ 'baddbmm',
76
+ 'bmm']
77
+
78
+ if utils.is_cuda_enabled():
79
+ # workaround https://github.com/facebookresearch/maskrcnn-benchmark/issues/802
80
+ if utils.get_cuda_version() >= (9, 1, 0):
81
+ FP16_FUNCS.extend(_bmms)
82
+ else:
83
+ FP32_FUNCS.extend(_bmms)
84
+
85
+ # Multi-tensor fns that may need type promotion
86
+ CASTS = [
87
+ # Multi-tensor math
88
+ 'addcdiv',
89
+ 'addcmul',
90
+ 'atan2',
91
+ 'cross',
92
+ 'bilinear',
93
+ 'dot',
94
+
95
+ # Element-wise _or_ tensor-wise math
96
+ 'add',
97
+ 'div',
98
+ 'mul',
99
+
100
+ # Comparison
101
+ 'eq',
102
+ 'equal',
103
+ 'ge',
104
+ 'gt',
105
+ 'le',
106
+ 'lt',
107
+ 'ne'
108
+ ]
109
+
110
+ # Functions that take sequence arguments. We need to inspect the whole
111
+ # sequence and cast to the widest type.
112
+ SEQUENCE_CASTS = [
113
+ 'cat',
114
+ 'stack'
115
+ ]
apex/amp/opt.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import warnings
3
+
4
+ from .scaler import LossScaler, master_params
5
+ from ._amp_state import maybe_print
6
+
7
+ import numpy as np
8
+
9
+ class OptimWrapper(object):
10
+ def __init__(self, optimizer, amp_handle, num_loss):
11
+ self._optimizer = optimizer
12
+ self._amp_handle = amp_handle
13
+ self._num_loss = num_loss
14
+ self._loss_idx = 0
15
+ self._skip_next = [False] * num_loss
16
+ self._loss_scaler = [LossScaler('dynamic') for _ in range(num_loss)]
17
+
18
+ @contextlib.contextmanager
19
+ def scale_loss(self, loss):
20
+ if not self._amp_handle.is_active():
21
+ yield loss
22
+ return
23
+
24
+ # When there are multiple losses per-optimizer, we need
25
+ # to save out current grad accumulation, since we won't be
26
+ # able to unscale this particulare loss once the grads are
27
+ # all mixed together.
28
+ cached_grads = []
29
+ if self._loss_idx > 0:
30
+ for p in master_params(self._optimizer):
31
+ if p.grad is not None:
32
+ cached_grads.append(p.grad.data.detach().clone())
33
+ else:
34
+ cached_grads.append(None)
35
+ self._optimizer.zero_grad()
36
+
37
+ loss_scale = self._cur_loss_scaler().loss_scale()
38
+ yield loss * loss_scale
39
+
40
+ self._cur_loss_scaler().clear_overflow_state()
41
+ self._cur_loss_scaler().unscale(
42
+ master_params(self._optimizer),
43
+ master_params(self._optimizer),
44
+ loss_scale)
45
+ self._skip_next[self._loss_idx] = self._cur_loss_scaler().update_scale()
46
+ self._loss_idx += 1
47
+
48
+ if len(cached_grads) > 0:
49
+ for p, cached_grad in zip(master_params(self._optimizer),
50
+ cached_grads):
51
+ if cached_grad is not None:
52
+ p.grad.data.add_(cached_grad)
53
+ cached_grads = []
54
+
55
+ def _cur_loss_scaler(self):
56
+ assert 0 <= self._loss_idx < self._num_loss
57
+ return self._loss_scaler[self._loss_idx]
58
+
59
+ def step(self, closure=None):
60
+ if not self._amp_handle.is_active():
61
+ return self._optimizer.step(closure=closure)
62
+
63
+ self._loss_idx = 0
64
+
65
+ for group in self._optimizer.param_groups:
66
+ for p in group['params']:
67
+ self._amp_handle.remove_cache(p)
68
+
69
+ if closure is not None:
70
+ raise NotImplementedError(
71
+ 'The `closure` argument is unsupported by the amp ' +
72
+ 'optimizer wrapper.')
73
+ if any(self._skip_next):
74
+ maybe_print('Gradient overflow, skipping update')
75
+ self._skip_next = [False] * self._num_loss
76
+ else:
77
+ return self._optimizer.step(closure=closure)
78
+
79
+ # Forward any attribute lookups
80
+ def __getattr__(self, attr):
81
+ return getattr(self._optimizer, attr)
82
+
83
+ # Forward all torch.optim.Optimizer methods
84
+ def __getstate__(self):
85
+ return self._optimizer.__getstate__()
86
+
87
+ def __setstate__(self):
88
+ return self._optimizer.__setstate__()
89
+
90
+ def __repr__(self):
91
+ return self._optimizer.__repr__()
92
+
93
+ def state_dict(self):
94
+ return self._optimizer.state_dict()
95
+
96
+ def load_state_dict(self, state_dict):
97
+ return self._optimizer.load_state_dict(state_dict)
98
+
99
+ def zero_grad(self):
100
+ return self._optimizer.zero_grad()
101
+
102
+ def add_param_group(self, param_group):
103
+ return self._optimizer.add_param_group(param_group)
apex/amp/rnn_compat.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from . import utils, wrap
2
+
3
+ import torch
4
+ _VF = torch._C._VariableFunctions
5
+ RNN_NAMES = ['rnn_relu', 'rnn_tanh', 'gru', 'lstm']
6
+
7
+ def _gen_VF_wrapper(name):
8
+ def wrapper(*args, **kwargs):
9
+ return getattr(_VF, name)(*args, **kwargs)
10
+ return wrapper
11
+
12
+ # Some python magic to generate an object that has the rnn cell functions
13
+ # defined on it, all of which call into corresponding _VF version.
14
+ # Intended to patch torch.nn.modules.rnn._VF (aka, the ref named "_VF"
15
+ # imported at module scope within torch.nn.modules.rnn). This should
16
+ # not affect third-party importers of _VF.py.
17
+ class VariableFunctionsShim(object):
18
+ def __init__(self):
19
+ for name in RNN_NAMES:
20
+ for suffix in ['', '_cell']:
21
+ fn_name = name + suffix
22
+ setattr(self, fn_name, _gen_VF_wrapper(fn_name))
23
+
24
+ def has_old_rnns():
25
+ try:
26
+ torch.nn.backends.thnn.backend.LSTMCell
27
+ return True
28
+ except:
29
+ return False
30
+
31
+ def whitelist_rnn_cells(handle, verbose):
32
+ # Different module + function names in old/new RNN cases
33
+ if has_old_rnns():
34
+ fn_names = ['RNNReLUCell', 'RNNTanhCell', 'LSTMCell', 'GRUCell']
35
+ mod = torch.nn.backends.thnn.backend
36
+ else:
37
+ fn_names = [x + '_cell' for x in RNN_NAMES]
38
+ mod = torch.nn.modules.rnn._VF
39
+ assert isinstance(mod, VariableFunctionsShim)
40
+
41
+ # Insert casts on cell functions
42
+ for fn in fn_names:
43
+ wrap.cached_cast(mod, fn, utils.maybe_half, handle,
44
+ try_caching=True, verbose=verbose)
45
+
46
+ if has_old_rnns():
47
+ # Special handling of `backward` for fused gru / lstm:
48
+ # The `backward` method calls Tensor.sum() (blacklist) internally,
49
+ # and then the resulting grad_input has the wrong type.
50
+ # TODO: where else is this a problem?
51
+ for rnn_type in ['GRUFused', 'LSTMFused']:
52
+ mod = getattr(torch.nn._functions.thnn.rnnFusedPointwise, rnn_type)
53
+ wrap.disable_casts(mod, 'backward', handle)
apex/amp/scaler.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from ..multi_tensor_apply import multi_tensor_applier
3
+ from ._amp_state import _amp_state, master_params, maybe_print
4
+ from itertools import product
5
+
6
+ def scale_check_overflow_python(model_grad, master_grad, scale, check_overflow=False):
7
+ # Exception handling for 18.04 compatibility
8
+ if check_overflow:
9
+ cpu_sum = float(model_grad.float().sum())
10
+ if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
11
+ return True
12
+
13
+ if master_grad is not model_grad: # copy_ probably internally short-circuits this
14
+ master_grad.copy_(model_grad)
15
+ if scale != 1.0:
16
+ master_grad.mul_(scale)
17
+ return False
18
+
19
+ def axpby_check_overflow_python(model_grad, stashed_grad, master_grad, a, b, check_overflow=False):
20
+ # Exception handling for 18.04 compatibility
21
+ if check_overflow:
22
+ cpu_sum = float(model_grad.float().sum())
23
+ if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
24
+ return True
25
+
26
+ # if master_grad is not model_grad: # copy_ probably internally short-circuits this
27
+ # master_grad.copy_(model_grad)
28
+ assert stashed_grad.dtype == master_grad.dtype
29
+ converted_model_grad = model_grad.data.to(master_grad.dtype)
30
+ master_grad.data = a*converted_model_grad.data + b*stashed_grad.data
31
+ return False
32
+
33
+ class LossScaler(object):
34
+ warned_no_fused_kernel = False
35
+ warned_unscaling_non_fp32_grad = False
36
+ has_fused_kernel = False
37
+
38
+ def __init__(self,
39
+ loss_scale,
40
+ init_scale=2.**16,
41
+ scale_factor=2.,
42
+ scale_window=2000,
43
+ min_loss_scale=None,
44
+ max_loss_scale=2.**24):
45
+ if loss_scale == "dynamic":
46
+ self.dynamic = True
47
+ self._loss_scale = min(max_loss_scale, init_scale)
48
+ else:
49
+ self.dynamic = False
50
+ self._loss_scale = loss_scale
51
+ self._max_loss_scale = max_loss_scale
52
+ self._min_loss_scale = min_loss_scale
53
+ self._scale_seq_len = scale_window
54
+ self._unskipped = 0
55
+ self._has_overflow = False
56
+ self._overflow_buf = torch.cuda.IntTensor([0])
57
+ if multi_tensor_applier.available:
58
+ import amp_C
59
+ LossScaler.has_fused_kernel = multi_tensor_applier.available
60
+ LossScaler.multi_tensor_scale_cuda = amp_C.multi_tensor_scale
61
+ LossScaler.multi_tensor_axpby_cuda = amp_C.multi_tensor_axpby
62
+ else:
63
+ if not LossScaler.warned_no_fused_kernel:
64
+ maybe_print(
65
+ "Warning: multi_tensor_applier fused unscale kernel is unavailable, "
66
+ "possibly because apex was installed without --cuda_ext --cpp_ext. "
67
+ "Using Python fallback. Original ImportError was: " +
68
+ repr(multi_tensor_applier.import_err),
69
+ True)
70
+ LossScaler.has_fused_kernel = False
71
+ LossScaler.warned_no_fused_kernel = True
72
+
73
+ def loss_scale(self):
74
+ return self._loss_scale
75
+
76
+ def unscale_python(self, model_grads, master_grads, scale):
77
+ for model, master in zip(model_grads, master_grads):
78
+ if model is not None:
79
+ if not LossScaler.warned_unscaling_non_fp32_grad:
80
+ if master.dtype != torch.float32:
81
+ maybe_print(
82
+ "Attempting to unscale a grad with type {} ".format(master.type()) +
83
+ "Unscaling non-fp32 grads may indicate an error. "
84
+ "When using Amp, you don't need to call .half() on your model.")
85
+ LossScaler.warned_unscaling_non_fp32_grad = True
86
+ self._has_overflow = scale_check_overflow_python(model,
87
+ master,
88
+ 1./scale,
89
+ self.dynamic)
90
+ if self._has_overflow and self.dynamic:
91
+ break
92
+
93
+ # unused_scale keeps some of the old API alive for hopefully a short time.
94
+ def unscale(self, model_grads, master_grads, unused_scale, models_are_masters=False, scale_override=None):
95
+ if self._has_overflow:
96
+ return
97
+
98
+ scale = self._loss_scale
99
+ if scale_override is not None:
100
+ scale = scale_override
101
+
102
+ if scale == 1.0 and models_are_masters and not self.dynamic:
103
+ return
104
+
105
+ if LossScaler.has_fused_kernel:
106
+ # if (not LossScaler.warned_unscaling_non_fp32_grad
107
+ # and master_grads[0].dtype == torch.float16):
108
+ # print("Warning: unscaling grads that are not FP32. "
109
+ # "Unscaling non-fp32 grads may indicate an error. "
110
+ # "When using Amp, you don't need to call .half() on your model.")
111
+ # # Setting this to True unconditionally allows the possibility of an escape
112
+ # # if never-before-seen non-fp32 grads are created in some later iteration.
113
+ # LossScaler.warned_unscaling_non_fp32_grad = True
114
+ multi_tensor_applier(LossScaler.multi_tensor_scale_cuda,
115
+ self._overflow_buf,
116
+ [model_grads, master_grads],
117
+ 1./scale)
118
+ else:
119
+ self.unscale_python(model_grads, master_grads, scale)
120
+
121
+ # Defer to update_scale
122
+ # If the fused kernel is available, we only need one D2H memcopy and sync.
123
+ # if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:
124
+ # self._has_overflow = self._overflow_buf.item()
125
+
126
+ def unscale_with_stashed_python(self,
127
+ model_grads,
128
+ stashed_master_grads,
129
+ master_grads,
130
+ a,
131
+ b):
132
+ for model, stashed, master in zip(model_grads, stashed_master_grads, master_grads):
133
+ if model is None and stashed is None:
134
+ continue
135
+ else:
136
+ if not LossScaler.warned_unscaling_non_fp32_grad:
137
+ if master.dtype != torch.float32:
138
+ maybe_print(
139
+ "Attempting to unscale a grad with type {} ".format(master.type()) +
140
+ "Unscaling non-fp32 grads may indicate an error. "
141
+ "When using Amp, you don't need to call .half() on your model.")
142
+ LossScaler.warned_unscaling_non_fp32_grad = True
143
+ self._has_overflow = axpby_check_overflow_python(model,
144
+ stashed,
145
+ master,
146
+ a,
147
+ b,
148
+ self.dynamic)
149
+ if self._has_overflow and self.dynamic:
150
+ break
151
+
152
+ def unscale_with_stashed(self,
153
+ model_grads,
154
+ stashed_master_grads,
155
+ master_grads,
156
+ scale_override=None):
157
+ if self._has_overflow:
158
+ return
159
+
160
+ grads_have_scale, stashed_have_scale, out_scale = self._loss_scale, 1.0, 1.0
161
+ if scale_override is not None:
162
+ grads_have_scale, stashed_have_scale, out_scale = scale_override
163
+
164
+ if LossScaler.has_fused_kernel:
165
+ if (not LossScaler.warned_unscaling_non_fp32_grad
166
+ and master_grads[0].dtype == torch.float16):
167
+ print("Warning: unscaling grads that are not FP32. "
168
+ "Unscaling non-fp32 grads may indicate an error. "
169
+ "When using Amp, you don't need to call .half() on your model.")
170
+ # Setting this to True unconditionally allows the possibility of an escape
171
+ # if never-before-seen non-fp32 grads are created in some later iteration.
172
+ LossScaler.warned_unscaling_non_fp32_grad = True
173
+ multi_tensor_applier(LossScaler.multi_tensor_axpby_cuda,
174
+ self._overflow_buf,
175
+ [model_grads, stashed_master_grads, master_grads],
176
+ out_scale/grads_have_scale, # 1./scale,
177
+ out_scale/stashed_have_scale, # 1.0,
178
+ 0) # check only arg 0, aka the incoming model grads, for infs
179
+ else:
180
+ self.unscale_with_stashed_python(model_grads,
181
+ stashed_master_grads,
182
+ master_grads,
183
+ out_scale/grads_have_scale,
184
+ out_scale/stashed_have_scale)
185
+
186
+ # Defer to update_scale
187
+ # If the fused kernel is available, we only need one D2H memcopy and sync.
188
+ # if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:
189
+ # self._has_overflow = self._overflow_buf.item()
190
+
191
+ def clear_overflow_state(self):
192
+ self._has_overflow = False
193
+ if self.has_fused_kernel:
194
+ self._overflow_buf.zero_()
195
+
196
+ # Separate so unscale() can be called more that once before updating.
197
+ def update_scale(self):
198
+ # If the fused kernel is available, we only need one D2H memcopy and sync.
199
+ if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:
200
+ self._has_overflow = self._overflow_buf.item()
201
+
202
+ if self._has_overflow and self.dynamic:
203
+ should_skip = True
204
+ if(self._min_loss_scale):
205
+ self._loss_scale = max(self._min_loss_scale, self._loss_scale/2.)
206
+ else:
207
+ self._loss_scale = self._loss_scale/2.
208
+ self._unskipped = 0
209
+ else:
210
+ should_skip = False
211
+ self._unskipped += 1
212
+
213
+ if self._unskipped == self._scale_seq_len and self.dynamic:
214
+ self._loss_scale = min(self._max_loss_scale, self._loss_scale*2.)
215
+ self._unskipped = 0
216
+
217
+ return should_skip
apex/amp/utils.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from . import compat
2
+
3
+ import functools
4
+ import itertools
5
+
6
+ import torch
7
+
8
+ def is_cuda_enabled():
9
+ return torch.version.cuda is not None
10
+
11
+ def get_cuda_version():
12
+ return tuple(int(x) for x in torch.version.cuda.split('.'))
13
+
14
+ def is_fp_tensor(x):
15
+ if is_nested(x):
16
+ # Fast-fail version of all(is_fp_tensor)
17
+ for y in x:
18
+ if not is_fp_tensor(y):
19
+ return False
20
+ return True
21
+ return compat.is_tensor_like(x) and compat.is_floating_point(x)
22
+
23
+ def is_nested(x):
24
+ return isinstance(x, tuple) or isinstance(x, list)
25
+
26
+ def should_cache(x):
27
+ if is_nested(x):
28
+ # Fast-fail version of all(should_cache)
29
+ for y in x:
30
+ if not should_cache(y):
31
+ return False
32
+ return True
33
+ return isinstance(x, torch.nn.parameter.Parameter) and \
34
+ type_string(x) == 'FloatTensor'
35
+
36
+ def collect_fp_tensor_types(args, kwargs):
37
+ def collect_types(x, types):
38
+ if is_nested(x):
39
+ for y in x:
40
+ collect_types(y, types)
41
+ else:
42
+ types.add(type_string(x))
43
+
44
+ all_args = itertools.chain(args, kwargs.values())
45
+ types = set()
46
+ for x in all_args:
47
+ if is_fp_tensor(x):
48
+ collect_types(x, types)
49
+ return types
50
+
51
+ def type_string(x):
52
+ return x.type().split('.')[-1]
53
+
54
+ def maybe_half(x, name='', verbose=False):
55
+ if is_nested(x):
56
+ return type(x)([maybe_half(y) for y in x])
57
+
58
+ if not x.is_cuda or type_string(x) == 'HalfTensor':
59
+ return x
60
+ else:
61
+ if verbose:
62
+ print('Float->Half ({})'.format(name))
63
+ return x.half()
64
+
65
+ def maybe_float(x, name='', verbose=False):
66
+ if is_nested(x):
67
+ return type(x)([maybe_float(y) for y in x])
68
+
69
+ if not x.is_cuda or type_string(x) == 'FloatTensor':
70
+ return x
71
+ else:
72
+ if verbose:
73
+ print('Half->Float ({})'.format(name))
74
+ return x.float()
75
+
76
+ # NB: returneds casted `args`, mutates `kwargs` in-place
77
+ def casted_args(cast_fn, args, kwargs):
78
+ new_args = []
79
+ for x in args:
80
+ if is_fp_tensor(x):
81
+ new_args.append(cast_fn(x))
82
+ else:
83
+ new_args.append(x)
84
+ for k in kwargs:
85
+ val = kwargs[k]
86
+ if is_fp_tensor(val):
87
+ kwargs[k] = cast_fn(val)
88
+ return new_args
89
+
90
+ def cached_cast(cast_fn, x, cache):
91
+ if is_nested(x):
92
+ return type(x)([cached_cast(y) for y in x])
93
+ if x in cache:
94
+ cached_x = cache[x]
95
+ if x.requires_grad and cached_x.requires_grad:
96
+ # Make sure x is actually cached_x's autograd parent.
97
+ if cached_x.grad_fn.next_functions[1][0].variable is not x:
98
+ raise RuntimeError("x and cache[x] both require grad, but x is not "
99
+ "cache[x]'s parent. This is likely an error.")
100
+ # During eval, it's possible to end up caching casted weights with
101
+ # requires_grad=False. On the next training iter, if cached_x is found
102
+ # and reused from the cache, it will not actually have x as its parent.
103
+ # Therefore, we choose to invalidate the cache (and force refreshing the cast)
104
+ # if x.requires_grad and cached_x.requires_grad do not match.
105
+ #
106
+ # During eval (i.e. running under with torch.no_grad()) the invalidation
107
+ # check would cause the cached value to be dropped every time, because
108
+ # cached_x would always be created with requires_grad=False, while x would
109
+ # still have requires_grad=True. This would render the cache effectively
110
+ # useless during eval. Therefore, if we are running under the no_grad()
111
+ # context manager (torch.is_grad_enabled=False) we elide the invalidation
112
+ # check, and use the cached value even though its requires_grad flag doesn't
113
+ # match. During eval, we don't care that there's no autograd-graph
114
+ # connection between x and cached_x.
115
+ if torch.is_grad_enabled() and x.requires_grad != cached_x.requires_grad:
116
+ del cache[x]
117
+ else:
118
+ return cached_x
119
+
120
+ casted_x = cast_fn(x)
121
+ cache[x] = casted_x
122
+ return casted_x
123
+
124
+ def verbosify(cast_fn, fn_name, verbose):
125
+ if verbose:
126
+ return functools.partial(cast_fn, name=fn_name, verbose=verbose)
127
+ else:
128
+ return cast_fn
129
+
130
+ def as_inplace(fns):
131
+ for x in fns:
132
+ yield x + '_'
133
+
134
+ def has_func(mod, fn):
135
+ if isinstance(mod, dict):
136
+ return fn in mod
137
+ else:
138
+ return hasattr(mod, fn)
139
+
140
+ def get_func(mod, fn):
141
+ if isinstance(mod, dict):
142
+ return mod[fn]
143
+ else:
144
+ return getattr(mod, fn)
145
+
146
+ def set_func(mod, fn, new_fn):
147
+ if isinstance(mod, dict):
148
+ mod[fn] = new_fn
149
+ else:
150
+ setattr(mod, fn, new_fn)
151
+
152
+ def set_func_save(handle, mod, fn, new_fn):
153
+ cur_fn = get_func(mod, fn)
154
+ handle._save_func(mod, fn, cur_fn)
155
+ set_func(mod, fn, new_fn)
156
+
157
+ # A couple problems get solved here:
158
+ # - The flat_weight buffer is disconnected from autograd graph,
159
+ # so the fp16 weights need to be derived from the input weights
160
+ # to this forward call, not the flat buffer.
161
+ # - The ordering of weights in the flat buffer is...idiosyncratic.
162
+ # First problem is solved with combination of set_ (to set up
163
+ # correct storage) and copy_ (so the fp16 weight derives from the
164
+ # fp32 one in autograd.
165
+ # Second is solved by doing ptr arithmetic on the fp32 weights
166
+ # to derive the correct offset.
167
+ #
168
+ # TODO: maybe this should actually use
169
+ # `torch._cudnn_rnn_flatten_weight`? But then I need to call
170
+ # on first iter and cache the right offsets. Ugh.
171
+ def synthesize_flattened_rnn_weights(fp32_weights,
172
+ fp16_flat_tensor,
173
+ rnn_fn='',
174
+ verbose=False):
175
+ fp16_weights = []
176
+ fp32_base_ptr = fp32_weights[0][0].data_ptr()
177
+ for layer_weights in fp32_weights:
178
+ fp16_layer_weights = []
179
+ for w_fp32 in layer_weights:
180
+ w_fp16 = w_fp32.new().half()
181
+ offset = (w_fp32.data_ptr() - fp32_base_ptr) // w_fp32.element_size()
182
+ w_fp16.set_(fp16_flat_tensor.storage(),
183
+ offset,
184
+ w_fp32.shape)
185
+ w_fp16.copy_(w_fp32)
186
+ if verbose:
187
+ print('Float->Half ({})'.format(rnn_fn))
188
+ fp16_layer_weights.append(w_fp16)
189
+ fp16_weights.append(fp16_layer_weights)
190
+ return fp16_weights
191
+
192
+ # Roughly same as above, just the `fp32_weights` aren't nested.
193
+ # Code kept separate for readability.
194
+ def new_synthesize_flattened_rnn_weights(fp32_weights,
195
+ fp16_flat_tensor,
196
+ rnn_fn='',
197
+ verbose=False):
198
+ fp16_weights = []
199
+ fp32_base_ptr = fp32_weights[0].data_ptr()
200
+ for w_fp32 in fp32_weights:
201
+ w_fp16 = w_fp32.new().half()
202
+ offset = (w_fp32.data_ptr() - fp32_base_ptr) // w_fp32.element_size()
203
+ w_fp16.set_(fp16_flat_tensor.storage(),
204
+ offset,
205
+ w_fp32.shape)
206
+ w_fp16.copy_(w_fp32)
207
+ if verbose:
208
+ print('Float->Half ({})'.format(rnn_fn))
209
+ fp16_weights.append(w_fp16)
210
+ return fp16_weights
apex/amp/wrap.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from . import compat
2
+ from . import utils
3
+ from ._amp_state import _amp_state
4
+ from . import rnn_compat
5
+
6
+ import functools
7
+
8
+ import torch
9
+
10
+ def make_cast_wrapper(orig_fn, cast_fn, handle,
11
+ try_caching=False):
12
+ @functools.wraps(orig_fn)
13
+ def wrapper(*args, **kwargs):
14
+ if not handle.is_active():
15
+ return orig_fn(*args, **kwargs)
16
+
17
+ if try_caching and handle.has_cache:
18
+ args = list(args)
19
+ for i in range(len(args)):
20
+ if utils.should_cache(args[i]):
21
+ args[i] = utils.cached_cast(cast_fn, args[i], handle.cache)
22
+ for k in kwargs:
23
+ if utils.should_cache(kwargs[k]):
24
+ kwargs[k] = utils.cached_cast(cast_fn, kwargs[k], handle.cache)
25
+ new_args = utils.casted_args(cast_fn,
26
+ args,
27
+ kwargs)
28
+ return orig_fn(*new_args, **kwargs)
29
+ return wrapper
30
+
31
+ def cached_cast(mod, fn, cast_fn, handle,
32
+ try_caching=False, verbose=False):
33
+ if not utils.has_func(mod, fn):
34
+ return
35
+
36
+ orig_fn = utils.get_func(mod, fn)
37
+ cast_fn = utils.verbosify(cast_fn, fn, verbose)
38
+ wrapper = make_cast_wrapper(orig_fn, cast_fn, handle, try_caching)
39
+ utils.set_func_save(handle, mod, fn, wrapper)
40
+
41
+ # `handle` arg is unused, but simplifies API to make `make_cast_wrapper`
42
+ # Annoyingly, make_promote_wrapper still uses the global handle. Once everyone
43
+ # is on the new API and I am free to get rid of handle, I can clean this up.
44
+ def make_promote_wrapper(orig_fn, cast_fn, handle=None):
45
+ @functools.wraps(orig_fn)
46
+ def wrapper(*args, **kwargs):
47
+ if not _amp_state.handle.is_active():
48
+ return orig_fn(*args, **kwargs)
49
+
50
+ types = utils.collect_fp_tensor_types(args, kwargs)
51
+
52
+ if len(types) <= 1:
53
+ return orig_fn(*args, **kwargs)
54
+ elif len(types) == 2 and types == set(['HalfTensor', 'FloatTensor']):
55
+ new_args = utils.casted_args(cast_fn,
56
+ args,
57
+ kwargs)
58
+ return orig_fn(*new_args, **kwargs)
59
+ else:
60
+ raise NotImplementedError('Do not know how to handle ' +
61
+ 'these types to promote: {}'
62
+ .format(types))
63
+ return wrapper
64
+
65
+ def promote(mod, fn, handle, verbose=False):
66
+ orig_fn = utils.get_func(mod, fn)
67
+ maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)
68
+ wrapper = make_promote_wrapper(orig_fn, maybe_float)
69
+ utils.set_func_save(handle, mod, fn, wrapper)
70
+
71
def sequence_promote(mod, fn, handle, verbose=False):
    # Like `promote`, but for functions whose first argument is a sequence of
    # tensors (e.g. torch.cat): promotes sequence elements to float when the
    # sequence mixes HalfTensor and FloatTensor.
    orig_fn = utils.get_func(mod, fn)
    maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)
    @functools.wraps(orig_fn)
    def wrapper(seq, *args, **kwargs):
        # Pass straight through when amp is not active.
        if not _amp_state.handle.is_active():
            return orig_fn(seq, *args, **kwargs)

        types = set([utils.type_string(x) for x in seq])
        if len(types) <= 1:
            # Homogeneous sequence -- nothing to promote.
            return orig_fn(seq, *args, **kwargs)
        elif types == set(['HalfTensor', 'FloatTensor']):
            cast_seq = utils.casted_args(maybe_float,
                                         seq, {})
            return orig_fn(cast_seq, *args, **kwargs)
        else:
            # TODO: other mixed-type cases aren't due to amp.
            # Just pass through?
            return orig_fn(seq, *args, **kwargs)
    utils.set_func_save(handle, mod, fn, wrapper)
91
+
92
def promote_match_arg0(mod, fn, handle, verbose=False):
    # Patch `mod.fn` (typically a Tensor method) so the remaining arguments
    # are cast to match the floating type of the first argument. No-op if
    # `mod` has no attribute `fn`.
    if not utils.has_func(mod, fn):
        return

    orig_fn = utils.get_func(mod, fn)
    @functools.wraps(orig_fn)
    def wrapper(arg0, *args, **kwargs):
        assert compat.is_tensor_like(arg0)
        if not _amp_state.handle.is_active():
            return orig_fn(arg0, *args, **kwargs)

        # Choose the cast direction from arg0's floating type.
        if utils.type_string(arg0) == 'HalfTensor':
            cast_fn = utils.maybe_half
        elif utils.type_string(arg0) == 'FloatTensor':
            cast_fn = utils.maybe_float
        else:
            # arg0 is not a floating tensor amp manages -- leave args as-is.
            return orig_fn(arg0, *args, **kwargs)
        cast_fn = utils.verbosify(cast_fn, fn, verbose)
        new_args = utils.casted_args(cast_fn, args, kwargs)
        return orig_fn(arg0, *new_args, **kwargs)
    utils.set_func_save(handle, mod, fn, wrapper)
113
+
114
def err_if_any_half(mod, fn, handle, custom_err_msg=None):
    # Patch `mod.fn` to raise NotImplementedError if any fp16 tensor argument
    # is passed (used for functions that are unsafe to run in fp16 under amp).
    # No-op if `mod` has no attribute `fn`.
    if not utils.has_func(mod, fn):
        return

    orig_fn = utils.get_func(mod, fn)
    @functools.wraps(orig_fn)
    def wrapper(*args, **kwargs):
        types = utils.collect_fp_tensor_types(args, kwargs)
        if 'HalfTensor' in types:
            if custom_err_msg:
                raise NotImplementedError(custom_err_msg)
            else:
                raise NotImplementedError('Cannot call in-place function ' +
                                          '{} with fp16 arguments.'.format(fn))
        else:
            return orig_fn(*args, **kwargs)
    utils.set_func_save(handle, mod, fn, wrapper)
131
+
132
def err_if_arg0_half(mod, fn, handle, verbose=False):
    # Patch `mod.fn` (an in-place method) to raise if the first argument is
    # fp16; otherwise cast the remaining args to float to match arg0.
    # No-op if `mod` has no attribute `fn`.
    if not utils.has_func(mod, fn):
        return

    orig_fn = utils.get_func(mod, fn)
    @functools.wraps(orig_fn)
    def wrapper(arg0, *args, **kwargs):
        assert compat.is_tensor_like(arg0)
        if utils.type_string(arg0) == 'HalfTensor':
            raise NotImplementedError('Cannot call in-place method ' +
                                      '{} on fp16 Tensors.'.format(fn))
        else:
            cast_fn = utils.verbosify(utils.maybe_float, fn, verbose)
            new_args = utils.casted_args(cast_fn, args, kwargs)
            return orig_fn(arg0, *new_args, **kwargs)
    utils.set_func_save(handle, mod, fn, wrapper)
148
+
149
# Current RNN approach:
# - Wrap top-level `RNN` function in thnn backend
# - Will call into either CudnnRNN or AutogradRNN
#   - Each of these are factory functions that return a per-iter
#     `forward` function
# - We interpose on the factory function to:
#   1) Interpose on the actual forward function and put in casts
#   2) Insert an fp16 `flat_weight` if necessary
def rnn_cast(backend, fn, handle, verbose=False):
    # Patch the RNN factory function `backend.fn` so that each per-iteration
    # forward it produces runs with fp16 inputs/weights/hiddens.
    orig_rnn = utils.get_func(backend, fn)
    @functools.wraps(orig_rnn)
    def rnn_wrapper(*args, **kwargs):
        flat_weight = kwargs.get('flat_weight')
        if flat_weight is not None:
            # We replace `flat_weight` with an uninitialized fp16
            # Tensor. The "actual" weight tensors (provided in `forward`),
            # will then be set up as ptrs into the buffer and have the
            # corresponding fp32 values copied in.
            # We need to call `copy` on the "actual" weights so that the
            # autograd graph correctly backprops from the wgrads computed
            # inside cuDNN (on fp16 weights) into the fp32 weights.
            assert utils.type_string(flat_weight) == 'FloatTensor'
            if compat.tensor_is_float_tensor() or compat.tensor_is_variable():
                # Pre-0.4. A little slower, since it zeros out memory.
                flat_weight_fp16 = flat_weight.new().half().resize_(flat_weight.shape)
            else:
                flat_weight_fp16 = torch.empty_like(flat_weight,
                                                    dtype=torch.float16)
            kwargs['flat_weight'] = flat_weight_fp16
        else:
            flat_weight_fp16 = None

        forward = orig_rnn(*args, **kwargs)
        @functools.wraps(forward)
        def fwd_wrapper(*fargs, **fkwargs):
            # fargs is (inputs, weights, hiddens) plus, on PyTorch >= 0.4,
            # a trailing batch-sizes argument.
            assert len(fargs) == 3 or len(fargs) == 4
            inputs, weights, hiddens = fargs[:3]
            assert utils.is_fp_tensor(inputs)
            assert isinstance(weights, list)
            cast_fn = utils.verbosify(utils.maybe_half,
                                      fn,
                                      verbose)
            new_args = []

            # 0) Inputs
            new_args.append(cast_fn(inputs))

            # 1) Weights
            if flat_weight_fp16 is not None:
                fp16_weights = utils.synthesize_flattened_rnn_weights(
                    weights, flat_weight_fp16, fn, verbose)
            else:
                fp16_weights = [[cast_fn(w) for w in layer]
                                for layer in weights]
            new_args.append(fp16_weights)

            # 2) Inputs: either a tuple (for LSTM) or single tensor
            if isinstance(hiddens, tuple):
                new_args.append(tuple(cast_fn(x) for x in hiddens))
            elif utils.is_fp_tensor(hiddens):
                new_args.append(cast_fn(hiddens))
            else:
                # Hiddens can, in principle, be `None` -- pass through
                new_args.append(hiddens)

            # 3) Batch sizes (0.4 or later only)
            if len(fargs) == 4:
                new_args.append(fargs[3])

            return forward(*new_args, **fkwargs)
        return fwd_wrapper
    utils.set_func_save(handle, backend, fn, rnn_wrapper)
221
+
222
def new_rnn_cast(fn, handle, verbose=False):
    # Patch the post-#15744 RNN entry points so tensor args (and the flattened
    # weight list) are cast to fp16 while amp is active.
    # Forward+backward compatibility around https://github.com/pytorch/pytorch/pull/15744
    # For rnn backend calls that route through _rnn_impls, we must patch the ref
    # that _rnn_impls stashed. For rnn backend calls that directly invoke
    # _VF.<backend>, e.g. _VF.lstm, we can patch onto VariableFunctionsShim,
    # which in turn has patched the ref named "_VF" in torch.nn.modules.rnn.
    if utils.has_func(torch.nn.modules.rnn._rnn_impls, fn):
        mod = torch.nn.modules.rnn._rnn_impls
    else:
        mod = torch.nn.modules.rnn._VF
        assert isinstance(mod, rnn_compat.VariableFunctionsShim)
        fn = fn.lower()
    orig_fn = utils.get_func(mod, fn)
    cast_fn = utils.verbosify(utils.maybe_half, fn, verbose)
    @functools.wraps(orig_fn)
    def wrapper(*args, **kwargs):
        # Exact call signature from modules/rnn.py
        assert len(args) == 9
        assert len(kwargs) == 0

        if not _amp_state.handle.is_active():
            return orig_fn(*args, **kwargs)

        # The position of the flat-weight list differs between the plain and
        # PackedSequence call signatures; args[6] is a bool only in the former.
        if isinstance(args[6], bool):
            params_idx = 2  # Not PackedSequence case
        else:
            params_idx = 3  # PackedSequence case

        new_args = []
        for i, arg in enumerate(args):
            if i == params_idx:
                # Synthesize a single contiguous fp16 buffer for all weights.
                num_params = sum([x.numel() for x in arg])
                fp16_weight_buf = args[0].new_empty((num_params,),
                                                    dtype=torch.half)
                casted_weights = utils.new_synthesize_flattened_rnn_weights(
                    arg, fp16_weight_buf, fn, verbose)
                new_args.append(casted_weights)
            elif utils.is_fp_tensor(arg):
                new_args.append(cast_fn(arg))
            else:
                new_args.append(arg)

        return orig_fn(*new_args)
    utils.set_func_save(handle, mod, fn, wrapper)
266
+
267
def disable_casts(mod, fn, handle):
    # Patch `mod.fn` so that it always executes with amp casting disabled
    # (via the handle's `_disable_casts` context manager). No-op if `mod`
    # has no attribute `fn`.
    if not utils.has_func(mod, fn):
        return

    orig_fn = utils.get_func(mod, fn)
    @functools.wraps(orig_fn)
    def wrapper(*args, **kwargs):
        with handle._disable_casts():
            return orig_fn(*args, **kwargs)
    utils.set_func_save(handle, mod, fn, wrapper)
apex/contrib/__init__.py ADDED
File without changes
apex/contrib/bottleneck/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .bottleneck import Bottleneck, SpatialBottleneck
2
+ from .halo_exchangers import HaloExchangerNoComm, HaloExchangerAllGather, HaloExchangerSendRecv, HaloExchangerPeer
apex/contrib/bottleneck/bottleneck.py ADDED
@@ -0,0 +1,749 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools as func
2
+
3
+ import torch
4
+ import torch.distributed as dist
5
+ from torch import nn
6
+
7
+ from apex import check_cudnn_version_and_warn
8
+ import fast_bottleneck
9
+ import nccl_p2p_cuda as inc
10
+
11
+
12
+ assert check_cudnn_version_and_warn(__name__, 8400)
13
+
14
+
15
def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    """In-place Kaiming-uniform init; thin wrapper over nn.init.kaiming_uniform_."""
    nn.init.kaiming_uniform_(tensor, a=a, mode=mode, nonlinearity=nonlinearity)
18
+
19
def compute_scale_bias_one(nhwc, weight, bias, running_mean, running_var, w_scale, w_bias):
    """Fold frozen-BN stats into per-channel scale/bias, writing into w_scale/w_bias in place."""
    folded_scale = weight * running_var.rsqrt()
    w_scale.copy_(folded_scale)
    w_bias.copy_(bias - running_mean * folded_scale)
24
+
25
def compute_scale_bias_method(nhwc, args):
    """Apply compute_scale_bias_one to each entry of `args`.

    Each entry is a tuple of (weight, bias, running_mean, running_var,
    w_scale, w_bias).
    """
    for entry in args:
        compute_scale_bias_one(nhwc, *entry)
29
+
30
class FrozenBatchNorm2d(torch.jit.ScriptModule):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed
    """
    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        # Stats/affine params are registered as buffers (not Parameters),
        # so they get no gradients and are skipped by optimizers.
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    @torch.jit.script_method
    def get_scale_bias(self, nhwc):
        # type: (bool) -> List[torch.Tensor]
        # Fold the frozen BN into per-channel scale/bias so that
        # forward reduces to y = x * scale + bias. `nhwc` selects the
        # broadcast layout of the returned tensors.
        # NOTE(review): the `# type:` comment above is parsed by TorchScript,
        # while the function returns a 2-tuple -- confirm before changing it.
        scale = self.weight * self.running_var.rsqrt()
        bias = self.bias - self.running_mean * scale
        if nhwc:
            scale = scale.reshape(1, 1, 1, -1)
            bias = bias.reshape(1, 1, 1, -1)
        else:
            scale = scale.reshape(1, -1, 1, 1)
            bias = bias.reshape(1, -1, 1, 1)
        return scale, bias

    @torch.jit.script_method
    def forward(self, x):
        # NCHW layout assumed for the plain forward path.
        scale, bias = self.get_scale_bias(False)
        return x * scale + bias
58
+
59
@torch.jit.script
def drelu_dscale1(grad_o, output, scale1):
    """Backward of relu followed by one scale: returns (grad * mask * scale1, grad * mask)."""
    mask = (output > 0)
    masked_grad = grad_o * mask
    scaled_grad = masked_grad * scale1
    return scaled_grad, masked_grad
65
+
66
@torch.jit.script
def drelu_dscale2(grad_o, output, scale1, scale2):
    """Backward of relu followed by two scales: returns (grad * mask * scale1, grad * mask * scale2)."""
    mask = (output > 0)
    masked_grad = grad_o * mask
    return masked_grad * scale1, masked_grad * scale2
73
+
74
class BottleneckFunction(torch.autograd.Function):
    # Custom autograd function dispatching the fused ResNet bottleneck
    # forward/backward to the `fast_bottleneck` CUDA extension.
    @staticmethod
    def forward(ctx, nhwc, stride_1x1, scale, bias, x, *conv):
        # TODO: clean up order of tensors
        args = [x, *conv[0:3], *scale[0:3], *bias[0:3]]
        # A 4th conv weight means the block has a downsample path.
        ctx.downsample = len(conv) > 3
        if ctx.downsample:
            args.append(conv[3])
            args.append(scale[3])
            args.append(bias[3])

        # weight buffers are always in nhwc while shape can be nhwc or channels_last
        # here we pass in flag and let c++ handle it
        # alternatively, we can put all sizes into a fixed format and pass it in
        outputs = fast_bottleneck.forward(nhwc, stride_1x1, args)
        ctx.save_for_backward(*(args+outputs))
        # save relu outputs for drelu
        ctx.nhwc = nhwc
        ctx.stride_1x1 = stride_1x1
        return outputs[2]

    # backward relu is not exposed, MUL with mask used now
    # only support dgrad
    @staticmethod
    def backward(ctx, grad_o):
        # Last three saved tensors are the forward outputs (relu results).
        outputs = ctx.saved_tensors[-3:]

        # saved_tensors[6] is scale3; [11] is the downsample scale when present.
        if ctx.downsample:
            grad_conv3, grad_conv4 = drelu_dscale2(grad_o, outputs[2], ctx.saved_tensors[6], ctx.saved_tensors[11])
        else:
            grad_conv3, grad_conv4 = drelu_dscale1(grad_o, outputs[2], ctx.saved_tensors[6])

        # create input vector for backward
        t_list = [*ctx.saved_tensors[0:10]]
        t_list.append(grad_conv3)
        t_list.append(grad_conv4)

        # outputs used for wgrad and generating drelu mask
        t_list.append(outputs[0])
        t_list.append(outputs[1])

        # in case there is downsample
        if ctx.downsample:
            t_list.append(ctx.saved_tensors[10])

        grads = fast_bottleneck.backward(ctx.nhwc, ctx.stride_1x1, t_list)

        # One None per non-tensor forward arg (nhwc, stride_1x1, scale, bias).
        return (None, None, None, None, *grads)

# Convenience alias matching the usual autograd.Function.apply idiom.
bottleneck_function = BottleneckFunction.apply
124
+
125
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
129
+
130
def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
133
+
134
class Bottleneck(torch.nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
    # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
    # here we put it at 1x1

    def __init__(self, in_channels, bottleneck_channels, out_channels, stride=1, groups=1,
                 dilation=1, norm_func=None, use_cudnn=False, explicit_nhwc=False):
        """ResNet V1.5 bottleneck block with optional fused-cuDNN fast path.

        Only groups == 1, dilation == 1 and frozen batch norm are supported.
        When `explicit_nhwc` is True all parameters are permuted to NHWC
        layout in-place (requires the cuDNN path for forward).
        """
        super(Bottleneck, self).__init__()
        if groups != 1:
            raise RuntimeError('Only support groups == 1')
        if dilation != 1:
            raise RuntimeError('Only support dilation == 1')
        # FIX: identity comparison with None (PEP 8) instead of `== None`.
        if norm_func is None:
            norm_func = FrozenBatchNorm2d
        else:
            raise RuntimeError('Only support frozen BN now.')

        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                conv1x1(in_channels, out_channels, stride),
                norm_func(out_channels),
            )
        else:
            self.downsample = None

        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(in_channels, bottleneck_channels, stride)
        self.conv2 = conv3x3(bottleneck_channels, bottleneck_channels)
        self.conv3 = conv1x1(bottleneck_channels, out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.stride = stride

        self.bn1 = norm_func(bottleneck_channels)
        self.bn2 = norm_func(bottleneck_channels)
        self.bn3 = norm_func(out_channels)
        # w_scale stays None until get_scale_bias_callable() precomputes it.
        self.w_scale = None

        self.use_cudnn = use_cudnn

        # setup conv weights
        self.w_conv = [self.conv1.weight, self.conv2.weight, self.conv3.weight]
        if self.downsample is not None:
            self.w_conv.append(self.downsample[0].weight)

        # init weight in nchw format before possible transpose
        for w in self.w_conv:
            kaiming_uniform_(w, a=1)

        # TODO: prevent unsupported case usage
        # support cases
        #                 native      cudnn
        # normal           yes          no
        # channel_last     yes         yes
        # explicit_nhwc     no         yes
        self.explicit_nhwc = explicit_nhwc
        if self.explicit_nhwc:
            for p in self.parameters():
                with torch.no_grad():
                    p.data = p.data.permute(0, 2, 3, 1).contiguous()

    # Returns single callable that recomputes scale and bias for all frozen batch-norms.
    # This method must be called before cuda graphing.
    # The callable it returns can be called anytime.
    # Calling this method will prevent these from being computed every forward call.
    def get_scale_bias_callable(self):
        self.w_scale, self.w_bias, args = [], [], []
        batch_norms = [self.bn1, self.bn2, self.bn3]
        if self.downsample is not None:
            batch_norms.append(self.downsample[1])
        for bn in batch_norms:
            s = torch.empty_like(bn.weight)
            b = torch.empty_like(s)
            args.append( (bn.weight, bn.bias, bn.running_mean, bn.running_var, s, b) )
            if self.explicit_nhwc:
                self.w_scale.append( s.reshape(1, 1, 1, -1) )
                self.w_bias.append( b.reshape(1, 1, 1, -1) )
            else:
                self.w_scale.append( s.reshape(1, -1, 1, 1) )
                self.w_bias.append( b.reshape(1, -1, 1, 1) )
        return func.partial(compute_scale_bias_method, self.explicit_nhwc, args)

    def forward(self, x):
        """Run the block; uses the fused cuDNN kernel when use_cudnn is set,
        otherwise falls back to native conv/bn/relu ops (NCHW/channels_last only)."""
        if self.use_cudnn:
            if self.w_scale is None:
                # calculate scale/bias from registered buffers
                # TODO: make this better
                s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc)
                s2, b2 = self.bn2.get_scale_bias(self.explicit_nhwc)
                s3, b3 = self.bn3.get_scale_bias(self.explicit_nhwc)
                w_scale = [s1, s2, s3]
                w_bias = [b1, b2, b3]
                if self.downsample is not None:
                    s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc)
                    w_scale.append(s4)
                    w_bias.append(b4)
                out = bottleneck_function(self.explicit_nhwc, self.stride, w_scale, w_bias, x, *self.w_conv)
            else:
                # Precomputed scale/bias from get_scale_bias_callable().
                out = bottleneck_function(self.explicit_nhwc, self.stride, self.w_scale, self.w_bias, x, *self.w_conv)
            return out

        if self.explicit_nhwc:
            raise RuntimeError('explicit nhwc with native ops is not supported.')

        # fallback to native ops
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
263
+
264
+
265
class SpatialBottleneckFunction(torch.autograd.Function):
    # Spatially-parallel variant of BottleneckFunction: the input is split
    # along H across `spatial_group_size` ranks, and 1-row halos are
    # exchanged between neighbors so the 3x3 convolution sees correct
    # boundary data. Halo transfer/compute is overlapped using side streams.
    @staticmethod
    def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, use_delay_kernel, explicit_nhwc, stride_1x1, scale, bias, thresholdTop, thresholdBottom, x, *conv):
        if spatial_group_size > 1:
            stream1 = spatial_halo_exchanger.stream1
            stream2 = spatial_halo_exchanger.stream2
            stream3 = spatial_halo_exchanger.stream3

        # TODO: clean up order of tensors
        args = [x, *conv[0:3], *scale[0:3], *bias[0:3]]
        # A 4th conv weight means the block has a downsample path.
        ctx.downsample = len(conv) > 3
        if ctx.downsample:
            args.append(conv[3])
            args.append(scale[3])
            args.append(bias[3])

        # weight buffers are always in explicit_nhwc while shape can be explicit_nhwc or channels_last
        # here we pass in flag and let c++ handle it
        # alternatively, we can put all sizes into a fixed format and pass it in
        outputs = fast_bottleneck.forward_init(explicit_nhwc, stride_1x1, args)
        fast_bottleneck.forward_out1(explicit_nhwc, stride_1x1, args, outputs)

        if spatial_group_size > 1:
            # Build out1_pad: out1 with one extra halo row at top and bottom.
            out1 = outputs[0]
            if explicit_nhwc:
                N,Hs,W,C = list(out1.shape)
                memory_format = torch.contiguous_format
                out1_pad = torch.empty([N,Hs+2,W,C], dtype=out1.dtype, device='cuda')
            else:
                N,C,Hs,W = list(out1.shape)
                memory_format = torch.channels_last if out1.is_contiguous(memory_format=torch.channels_last) else torch.contiguous_format
                out1_pad = torch.empty([N,C,Hs+2,W], dtype=out1.dtype, device='cuda', memory_format=memory_format)
            stream1.wait_stream(torch.cuda.current_stream())
            if spatial_method != 2: stream3.wait_stream(torch.cuda.current_stream())
            # Exchange boundary rows of out1 with neighbor ranks on stream1.
            with torch.cuda.stream(stream1):
                if explicit_nhwc:
                    top_out1_halo = out1_pad[:,:1,:,:]
                    btm_out1_halo = out1_pad[:,Hs+1:Hs+2,:,:]
                    spatial_halo_exchanger.left_right_halo_exchange(out1[:,:1,:,:], out1[:,Hs-1:,:,:], top_out1_halo, btm_out1_halo)
                else:
                    top_out1_halo = out1_pad[:,:,:1,:]
                    btm_out1_halo = out1_pad[:,:,Hs+1:Hs+2,:]
                    spatial_halo_exchanger.left_right_halo_exchange(out1[:,:,:1,:], out1[:,:,Hs-1:,:], top_out1_halo, btm_out1_halo)
            if spatial_method == 1:
                # overlap mid convolution with halo transfer
                if spatial_group_rank < spatial_group_size-1:
                    stream2.wait_stream(stream1)
                    with torch.cuda.stream(stream2):
                        # 3-row "fat" halo: last 2 local rows + received bottom row.
                        if explicit_nhwc:
                            btm_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
                            btm_fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
                            btm_fat_halo[:,2:,:,:].copy_(btm_out1_halo)
                        else:
                            btm_fat_halo = torch.empty((N,C,3,W),dtype=out1.dtype,device=out1.device)
                            btm_fat_halo[:,:,0:2,:].copy_(out1[:,:,Hs-2:,:])
                            btm_fat_halo[:,:,2:,:].copy_(btm_out1_halo)
                        btm_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, btm_fat_halo, args)
                if spatial_group_rank > 0:
                    with torch.cuda.stream(stream1):
                        # 3-row "fat" halo: received top row + first 2 local rows.
                        if explicit_nhwc:
                            top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
                            top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
                            top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
                        else:
                            top_fat_halo = torch.empty((N,C,3,W),dtype=out1.dtype,device=out1.device)
                            top_fat_halo[:,:,:1,:].copy_(top_out1_halo)
                            top_fat_halo[:,:,1:3,:].copy_(out1[:,:,:2,:])
                        top_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, top_fat_halo, args)
                if use_delay_kernel: inc.add_delay(10)
            elif spatial_method != 2 and spatial_method != 3:
                assert(False), "spatial_method must be 1, 2 or 3"

        # Main 3x3 convolution (out2), strategy depends on spatial_method.
        if spatial_group_size <= 1:
            fast_bottleneck.forward_out2(explicit_nhwc, stride_1x1, args, outputs)
        elif spatial_method == 1:
            fast_bottleneck.forward_out2(explicit_nhwc, stride_1x1, args, outputs)
            with torch.cuda.stream(stream3):
                if explicit_nhwc:
                    out1_pad[:,1:Hs+1,:,:].copy_(out1)
                else:
                    out1_pad[:,:,1:Hs+1,:].copy_(out1)
        elif spatial_method == 2:
            # wait for halo transfer to finish before doing a full convolution of padded x
            if explicit_nhwc:
                out1_pad[:,1:Hs+1,:,:].copy_(out1)
            else:
                out1_pad[:,:,1:Hs+1,:].copy_(out1)
            torch.cuda.current_stream().wait_stream(stream1)
            fast_bottleneck.forward_out2_pad(explicit_nhwc, stride_1x1, args, outputs, out1_pad)
        elif spatial_method == 3:
            fast_bottleneck.forward_out2_mask(explicit_nhwc, stride_1x1, args, outputs, thresholdTop, thresholdBottom)
            with torch.cuda.stream(stream3):
                if explicit_nhwc:
                    out1_pad[:,1:Hs+1,:,:].copy_(out1)
                else:
                    out1_pad[:,:,1:Hs+1,:].copy_(out1)

        # compute halo cells for outputs[1] (out2)
        if spatial_group_size > 1:
            out2 = outputs[1]
            if explicit_nhwc:
                top_out2_halo = out2[:,:1,:,:]
                btm_out2_halo = out2[:,Hs-1:,:,:]
            else:
                top_out2_halo = out2[:,:,:1,:]
                btm_out2_halo = out2[:,:,Hs-1:,:]
            if spatial_method == 1:
                if spatial_group_rank > 0:
                    torch.cuda.current_stream().wait_stream(stream1)
                    top_out2_halo.copy_(top_out2)
                if spatial_group_rank < spatial_group_size-1:
                    torch.cuda.current_stream().wait_stream(stream2)
                    btm_out2_halo.copy_(btm_out2)
            elif spatial_method == 3:
                # Note
                # out2 halo correction cannot overlap with anything since it has
                # to wait for out2_mask to finish, but itself has to finish before
                # the first kernel of _forward_rest can launch.
                # At least we can overlap the two halo correction kernels.
                if spatial_group_rank < spatial_group_size-1:
                    stream2.wait_stream(stream1) # wait for halo transfers to finish
                    stream2.wait_stream(torch.cuda.current_stream()) # wait for *_out2_mask to finish
                    with torch.cuda.stream(stream2):
                        w1by3 = args[2][:,2:3,:,:].clone()
                        btm_out1_halo = btm_out1_halo.clone()
                        btm_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, btm_out1_halo, args, w1by3, btm_out2_halo.clone())
                        btm_out2_halo.copy_(btm_out2)
                if spatial_group_rank > 0:
                    stream1.wait_stream(torch.cuda.current_stream()) # wait for *_out2_mask to finish
                    with torch.cuda.stream(stream1):
                        w1by3 = args[2][:,:1,:,:].clone()
                        top_out1_halo = top_out1_halo.clone()
                        top_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, top_out1_halo, args, w1by3, top_out2_halo.clone())
                        top_out2_halo.copy_(top_out2)
                if spatial_group_rank < spatial_group_size-1:
                    torch.cuda.current_stream().wait_stream(stream2)
                if spatial_group_rank > 0:
                    torch.cuda.current_stream().wait_stream(stream1)

        fast_bottleneck.forward_rest(explicit_nhwc, stride_1x1, args, outputs)
        # save halos for backward pass
        if spatial_group_size > 1:
            if spatial_method != 2:
                # make sure copy of mid-section of out1 into out1_pad is done before exiting
                torch.cuda.current_stream().wait_stream(stream3)
            ctx.save_for_backward(*(args+outputs+[out1_pad,]))
        else:
            ctx.save_for_backward(*(args+outputs))
        # save relu outputs for drelu
        ctx.explicit_nhwc = explicit_nhwc
        ctx.stride_1x1 = stride_1x1
        ctx.spatial_group_size = spatial_group_size
        if spatial_group_size > 1:
            ctx.spatial_group_rank = spatial_group_rank
            ctx.spatial_halo_exchanger = spatial_halo_exchanger
            ctx.spatial_method = spatial_method
            ctx.use_delay_kernel = use_delay_kernel
            ctx.thresholdTop = thresholdTop
            ctx.thresholdBottom = thresholdBottom
            ctx.stream1 = stream1
            ctx.stream2 = stream2
            ctx.stream3 = stream3
        return outputs[2]
428
+
429
+ # backward relu is not exposed, MUL with mask used now
430
+ # only support dgrad
431
+ @staticmethod
432
+ def backward(ctx, grad_o):
433
+ if ctx.spatial_group_size > 1:
434
+ out1_pad = ctx.saved_tensors[-1]
435
+ outputs = ctx.saved_tensors[-4:-1]
436
+ else:
437
+ outputs = ctx.saved_tensors[-3:]
438
+
439
+ if ctx.downsample:
440
+ grad_conv3, grad_conv4 = drelu_dscale2(grad_o, outputs[2], ctx.saved_tensors[6], ctx.saved_tensors[11])
441
+ else:
442
+ grad_conv3, grad_conv4 = drelu_dscale1(grad_o, outputs[2], ctx.saved_tensors[6])
443
+
444
+ # create input vector for backward
445
+ t_list = [*ctx.saved_tensors[0:10]]
446
+ t_list.append(grad_conv3)
447
+ t_list.append(grad_conv4)
448
+
449
+ # outputs used for wgrad and generating drelu mask
450
+ t_list.append(outputs[0])
451
+ t_list.append(outputs[1])
452
+
453
+ # in case there is downsample
454
+ if ctx.downsample:
455
+ t_list.append(ctx.saved_tensors[10])
456
+
457
+ grads = fast_bottleneck.backward_init(ctx.explicit_nhwc, ctx.stride_1x1, t_list)
458
+ wgrad3_stream = torch.cuda.Stream()
459
+ wgrad3_stream.wait_stream(torch.cuda.current_stream())
460
+ grad_out2 = fast_bottleneck.backward_grad_out2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads)
461
+ wgrad2_stream = torch.cuda.Stream()
462
+ wgrad2_stream.wait_stream(torch.cuda.current_stream())
463
+ # do halo exchange of grad_out2 here
464
+ # compute halo cells for grad_out1
465
+ if ctx.spatial_group_size > 1:
466
+ if ctx.explicit_nhwc:
467
+ N,Hs,W,C = list(grad_out2.shape)
468
+ else:
469
+ N,C,Hs,W = list(grad_out2.shape)
470
+ relu1 = t_list[12]
471
+ ctx.stream1.wait_stream(torch.cuda.current_stream())
472
+ with torch.cuda.stream(ctx.stream1):
473
+ top_halo, btm_halo = ctx.spatial_halo_exchanger.left_right_halo_exchange(grad_out2[:,:1,:,:], grad_out2[:,Hs-1:,:,:])
474
+ # copy halos to send buffer
475
+ if ctx.spatial_method == 1 or ctx.spatial_method == 2:
476
+ # 1 -> halo recompute approach
477
+ # 2 -> wait for concatenated halos, then do single conv on full input (not implemented yet for bprop)
478
+ if ctx.spatial_group_rank < ctx.spatial_group_size-1:
479
+ ctx.stream2.wait_stream(ctx.stream1)
480
+ with torch.cuda.stream(ctx.stream2):
481
+ if ctx.explicit_nhwc:
482
+ btm_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
483
+ btm_fat_halo[:,:2,:,:].copy_(grad_out2[:,Hs-2:,:,:])
484
+ btm_fat_halo[:,2:,:,:].copy_(btm_halo)
485
+ btm_fat_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
486
+ btm_fat_relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2:,:,:])
487
+ btm_fat_relu_halo[:,2:,:,:].zero_()
488
+ else:
489
+ btm_fat_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
490
+ btm_fat_halo[:,:,:2,:].copy_(grad_out2[:,:,Hs-2:,:])
491
+ btm_fat_halo[:,:,2:,:].copy_(btm_halo)
492
+ btm_fat_relu_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
493
+ btm_fat_relu_halo[:,:,:2,:].copy_(relu1[:,:,Hs-2:,:])
494
+ btm_fat_relu_halo[:,:,2:,:].zero_()
495
+ btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, btm_fat_halo, btm_fat_relu_halo)
496
+ if ctx.explicit_nhwc:
497
+ btm_grad_out1_halo = btm_grad_out1_halo[:,1:2,:,:]
498
+ else:
499
+ btm_grad_out1_halo = btm_grad_out1_halo[:,:,1:2,:]
500
+ if ctx.spatial_group_rank > 0:
501
+ with torch.cuda.stream(ctx.stream1):
502
+ if ctx.explicit_nhwc:
503
+ top_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
504
+ top_fat_halo[:,:1,:,:].copy_(top_halo)
505
+ top_fat_halo[:,1:,:,:].copy_(grad_out2[:,:2,:,:])
506
+ top_fat_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
507
+ top_fat_relu_halo[:,:1,:,:].zero_()
508
+ top_fat_relu_halo[:,1:,:,:].copy_(relu1[:,:2,:,:])
509
+ else:
510
+ top_fat_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
511
+ top_fat_halo[:,:,:1,:].copy_(top_halo)
512
+ top_fat_halo[:,:,1:,:].copy_(grad_out2[:,:,:2,:])
513
+ top_fat_relu_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
514
+ top_fat_relu_halo[:,:,:1,:].zero_()
515
+ top_fat_relu_halo[:,:,1:,:].copy_(relu1[:,:,:2,:])
516
+ top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, top_fat_halo, top_fat_relu_halo)
517
+ if ctx.explicit_nhwc:
518
+ top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
519
+ else:
520
+ top_grad_out1_halo = top_grad_out1_halo[:,:,1:2,:]
521
+ if ctx.use_delay_kernel: inc.add_delay(10)
522
+ elif ctx.spatial_method != 3:
523
+ assert(False), "spatial_method must be 1, 2 or 3"
524
+
525
+ # compute grad_out1 for internal cells
526
+ if ctx.spatial_group_size <= 1 or ctx.spatial_method == 1 or ctx.spatial_method == 2:
527
+ grad_out1 = fast_bottleneck.backward_grad_out1(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
528
+ elif ctx.spatial_group_size > 1 and ctx.spatial_method == 3:
529
+ grad_out1 = fast_bottleneck.backward_grad_out1_mask(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2, ctx.thresholdTop, ctx.thresholdBottom)
530
+
531
+ # apply halo cells to grad_out1
532
+ if ctx.spatial_group_size > 1:
533
+ w = t_list[2]
534
+ z = t_list[4]
535
+ relu1 = t_list[12]
536
+ #print("w.shape = %s, z.shape = %s, relu1.shape = %s" % (str(list(w.shape)), str(list(z.shape)), str(list(relu1.shape))))
537
+ if ctx.spatial_method == 1 or ctx.spatial_method == 2:
538
+ if ctx.spatial_group_rank < ctx.spatial_group_size-1:
539
+ torch.cuda.current_stream().wait_stream(ctx.stream2)
540
+ if ctx.explicit_nhwc:
541
+ grad_out1[:,Hs-1:,:,:].copy_(btm_grad_out1_halo)
542
+ else:
543
+ grad_out1[:,:,Hs-1:,:].copy_(btm_grad_out1_halo)
544
+ #print("ctx.spatial_group_rank = %d, apply grad_out1 btm halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
545
+ if ctx.spatial_group_rank > 0:
546
+ torch.cuda.current_stream().wait_stream(ctx.stream1)
547
+ if ctx.explicit_nhwc:
548
+ grad_out1[:,:1,:,:].copy_(top_grad_out1_halo)
549
+ else:
550
+ grad_out1[:,:,:1,:].copy_(top_grad_out1_halo)
551
+ #print("ctx.spatial_group_rank = %d, apply grad_out1 top halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
552
+ elif ctx.spatial_method == 3:
553
+ if ctx.spatial_group_rank < ctx.spatial_group_size-1:
554
+ if ctx.explicit_nhwc:
555
+ btm_relu_halo = relu1[:,Hs-1:,:,:].clone()
556
+ btm_grad_out1 = grad_out1[:,Hs-1:,:,:]
557
+ else:
558
+ btm_relu_halo = relu1[:,:,Hs-1:,:].clone()
559
+ btm_grad_out1 = grad_out1[:,:,Hs-1:,:]
560
+ w1by3 = w[:,:1,:,:].clone()
561
+ ctx.stream2.wait_stream(ctx.stream1) # wait for halo transfers to finish
562
+ ctx.stream2.wait_stream(torch.cuda.current_stream()) # wait for backward_grad_out1_mask to finish before launching halo correction kernel
563
+ with torch.cuda.stream(ctx.stream2):
564
+ btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo_corr(ctx.explicit_nhwc, ctx.stride_1x1, t_list, w1by3, grads, btm_halo, btm_relu_halo, btm_grad_out1.clone())
565
+ btm_grad_out1.copy_(btm_grad_out1_halo)
566
+ if ctx.spatial_group_rank > 0:
567
+ if ctx.explicit_nhwc:
568
+ top_relu_halo = relu1[:,:1,:,:].clone()
569
+ top_grad_out1 = grad_out1[:,:1,:,:]
570
+ else:
571
+ top_relu_halo = relu1[:,:,:1,:].clone()
572
+ top_grad_out1 = grad_out1[:,:,:1,:]
573
+ w1by3 = w[:,2:,:,:].clone()
574
+ ctx.stream1.wait_stream(torch.cuda.current_stream()) # wait for backward_grad_out1_mask to finish before launching halo correction kernel
575
+ with torch.cuda.stream(ctx.stream1):
576
+ top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo_corr(ctx.explicit_nhwc, ctx.stride_1x1, t_list, w1by3, grads, top_halo, top_relu_halo, top_grad_out1.clone())
577
+ top_grad_out1.copy_(top_grad_out1_halo)
578
+ if ctx.spatial_group_rank < ctx.spatial_group_size-1:
579
+ torch.cuda.current_stream().wait_stream(ctx.stream2) # wait for halo correction to finish
580
+ if ctx.spatial_group_rank > 0:
581
+ torch.cuda.current_stream().wait_stream(ctx.stream1)
582
+
583
+ wgrad1_stream = torch.cuda.Stream()
584
+ wgrad1_stream.wait_stream(torch.cuda.current_stream())
585
+ fast_bottleneck.backward_rest(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1)
586
+ with torch.cuda.stream(wgrad3_stream):
587
+ fast_bottleneck.backward_wgrad3(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads)
588
+ with torch.cuda.stream(wgrad2_stream):
589
+ if ctx.spatial_group_size > 1:
590
+ fast_bottleneck.backward_wgrad2_pad(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, out1_pad, grad_out2)
591
+ else:
592
+ fast_bottleneck.backward_wgrad2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
593
+ with torch.cuda.stream(wgrad1_stream):
594
+ fast_bottleneck.backward_wgrad1(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out1)
595
+ torch.cuda.current_stream().wait_stream(wgrad3_stream)
596
+ torch.cuda.current_stream().wait_stream(wgrad2_stream)
597
+ torch.cuda.current_stream().wait_stream(wgrad1_stream)
598
+
599
+ return (None, None, None, None, None, None, None, None, None, None, None, None, *grads)
600
+
601
# Functional alias so callers can invoke the custom autograd Function directly.
spatial_bottleneck_function = SpatialBottleneckFunction.apply
602
+
603
class SpatialBottleneck(torch.nn.Module):
    """Spatially-parallel ResNet bottleneck block (1x1 -> 3x3 -> 1x1 convolutions).

    Bottleneck in torchvision places the stride for downsampling at the 3x3
    convolution (self.conv2) while the original implementation places the stride
    at the first 1x1 convolution (self.conv1) according to
    "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
    This variant is also known as ResNet V1.5 and improves accuracy according to
    https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
    Here we put it at 1x1.

    When ``use_cudnn`` is set, forward/backward run through the fused
    ``spatial_bottleneck_function``, optionally cooperating with other ranks as
    described by ``spatial_parallel_args``.
    """

    def __init__(self, in_channels, bottleneck_channels, out_channels, stride=1, groups=1,
                 dilation=1, norm_func=None, use_cudnn=False, explicit_nhwc=False,
                 spatial_parallel_args=None):
        super(SpatialBottleneck, self).__init__()
        if groups != 1:
            raise RuntimeError('Only support groups == 1')
        if dilation != 1:
            raise RuntimeError('Only support dilation == 1')
        # Compare against None with `is` (identity), not `==`.
        if norm_func is None:
            norm_func = FrozenBatchNorm2d
        else:
            raise RuntimeError('Only support frozen BN now.')

        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                conv1x1(in_channels, out_channels, stride),
                norm_func(out_channels),
            )
        else:
            self.downsample = None

        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(in_channels, bottleneck_channels, stride)
        self.conv2 = conv3x3(bottleneck_channels, bottleneck_channels)
        self.conv3 = conv1x1(bottleneck_channels, out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.stride = stride

        self.bn1 = norm_func(bottleneck_channels)
        self.bn2 = norm_func(bottleneck_channels)
        self.bn3 = norm_func(out_channels)
        # None means "recompute BN scale/bias every forward"; populated by
        # get_scale_bias_callable() for the CUDA-graph-friendly path.
        self.w_scale = None

        self.use_cudnn = use_cudnn

        # setup conv weights
        self.w_conv = [self.conv1.weight, self.conv2.weight, self.conv3.weight]
        if self.downsample is not None:
            self.w_conv.append(self.downsample[0].weight)

        # init weight in nchw format before possible transpose
        for w in self.w_conv:
            kaiming_uniform_(w, a=1)

        # Lazily initialized on the first forward(), once the input height is known.
        self.thresholdTop, self.thresholdBottom = None, None

        # TODO: prevent unsupported case usage
        # support cases
        #                 native      cudnn
        # normal          yes         no
        # channel_last    yes         yes
        # explicit_nhwc   no          yes
        self.explicit_nhwc = explicit_nhwc
        if self.explicit_nhwc:
            # Permute all parameter storage from NCHW to NHWC in place.
            for p in self.parameters():
                with torch.no_grad():
                    p.data = p.data.permute(0,2,3,1).contiguous()

        # spatial communicator
        if spatial_parallel_args is None:
            # Defaults describe a single-rank (no spatial parallelism) setup.
            self.spatial_parallel_args = (1, 0, None, None, 0, False)
        else:
            self.spatial_parallel_args = spatial_parallel_args
        return

    # Returns single callable that recomputes scale and bias for all frozen batch-norms.
    # This method must be called before cuda graphing.
    # The callable it returns can be called anytime.
    # Calling this method will prevent these from being computed every forward call.
    def get_scale_bias_callable(self):
        self.w_scale, self.w_bias, args = [], [], []
        batch_norms = [self.bn1, self.bn2, self.bn3]
        if self.downsample is not None:
            batch_norms.append(self.downsample[1])
        for bn in batch_norms:
            # Pre-allocate scale/bias outputs; the returned callable fills them.
            s = torch.empty_like(bn.weight)
            b = torch.empty_like(s)
            args.append( (bn.weight, bn.bias, bn.running_mean, bn.running_var, s, b) )
            if self.explicit_nhwc:
                self.w_scale.append( s.reshape(1, 1, 1, -1) )
                self.w_bias.append( b.reshape(1, 1, 1, -1) )
            else:
                self.w_scale.append( s.reshape(1, -1, 1, 1) )
                self.w_bias.append( b.reshape(1, -1, 1, 1) )
        return func.partial(compute_scale_bias_method, self.explicit_nhwc, args)

    def forward(self, x):
        """Run the bottleneck; dispatches to the fused cuDNN path when enabled."""
        if self.use_cudnn:
            if self.thresholdTop is None:
                # First call: derive per-rank halo row thresholds from the
                # (possibly H-split) input height.
                spatial_group_size, spatial_group_rank, _, _, _, _ = self.spatial_parallel_args
                if self.explicit_nhwc:
                    N,H,W,C = list(x.shape)
                else:
                    N,C,H,W = list(x.shape)
                self.thresholdTop = torch.tensor([1 if spatial_group_rank > 0 else 0], dtype=torch.int32, device='cuda')
                self.thresholdBottom = torch.tensor([H-2 if spatial_group_rank < spatial_group_size - 1 else H-1], dtype=torch.int32, device='cuda')

            if self.w_scale is None:
                # calculate scale/bias from registered buffers
                # TODO: make this better
                s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc)
                s2, b2 = self.bn2.get_scale_bias(self.explicit_nhwc)
                s3, b3 = self.bn3.get_scale_bias(self.explicit_nhwc)
                w_scale = [s1, s2, s3]
                w_bias = [b1, b2, b3]
                if self.downsample is not None:
                    s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc)
                    w_scale.append(s4)
                    w_bias.append(b4)
                out = spatial_bottleneck_function(*self.spatial_parallel_args, self.explicit_nhwc, self.stride, w_scale, w_bias, self.thresholdTop, self.thresholdBottom, x, *self.w_conv)
            else:
                # Scale/bias were precomputed via get_scale_bias_callable().
                out = spatial_bottleneck_function(*self.spatial_parallel_args, self.explicit_nhwc, self.stride, self.w_scale, self.w_bias, self.thresholdTop, self.thresholdBottom, x, *self.w_conv)
            return out

        if self.explicit_nhwc:
            raise RuntimeError('explicit nhwc with native ops is not supported.')

        # fallback to native ops
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
749
+
apex/contrib/bottleneck/halo_exchangers.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.distributed as dist
3
+ from torch import nn
4
+ import nccl_p2p_cuda as inc
5
+ import peer_memory_cuda as pm
6
+
7
# NOTE: the warnings below describe HaloExchangerNoComm (defined further down),
# not the HaloExchanger base class that immediately follows.
# Communication free halo exchanger.
# NB! HaloExchangerNoComm does not exchange halos with neighbors as it should; it merely swaps the inputs.
# NB! This is only useful for performance testing.
# NB! Do not use for actual production runs.
11
class HaloExchanger(object):
    """Common state for 1d halo exchangers.

    Records the spatial group layout (global ranks, this rank's position in the
    group, and its left/right neighbors) plus three side streams that
    subclasses may use to overlap communication with compute.
    """
    def __init__(self, ranks, rank_in_group):
        self.stream1 = torch.cuda.Stream()
        self.stream2 = torch.cuda.Stream()
        self.stream3 = torch.cuda.Stream()
        self.group_size = len(ranks)
        self.ranks = ranks
        self.rank_in_group = rank_in_group
        # Wrap-around neighbor indices within the group; boundary ranks wrap
        # to the opposite end (used by exchangers that gather from every rank).
        self.wrap_around_left_rank_in_group = (rank_in_group + self.group_size - 1) % self.group_size
        self.wrap_around_right_rank_in_group = (rank_in_group + 1) % self.group_size
        # Global ranks of the true neighbors; -1 marks "no neighbor" at the
        # group boundary. The *_zero flags tell subclasses to zero-fill the
        # corresponding halo instead of exchanging it.
        self.left_rank = ranks[rank_in_group - 1] if rank_in_group > 0 else -1
        self.left_zero = rank_in_group == 0
        self.right_rank = ranks[rank_in_group + 1] if rank_in_group < self.group_size - 1 else -1
        self.right_zero = rank_in_group == self.group_size - 1
25
+
26
class HaloExchangerNoComm(HaloExchanger):
    """Communication-free stand-in exchanger.

    NB! Does not exchange halos with real neighbors; it merely swaps the two
    local outputs. Only useful for performance testing — do not use for actual
    production runs.
    """
    def __init__(self, ranks, rank_in_group):
        super(HaloExchangerNoComm, self).__init__(ranks, rank_in_group)

    def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
        # In-place variant: caller supplied destination tensors.
        if left_input_halo is not None:
            left_input_halo.copy_(right_output_halo)
            right_input_halo.copy_(left_output_halo)
            return
        # Out-of-place variant: hand back the swapped output tensors directly.
        return right_output_halo, left_output_halo
36
+
37
class HaloExchangerAllGather(HaloExchanger):
    """Halo exchange implemented with a single all_gather over the spatial group.

    Every rank contributes its packed (left, right) output halos; each rank
    then picks the halos it needs from its wrap-around neighbors' slots.
    Simple, but moves group_size times more data than strictly necessary.
    """
    def __init__(self, ranks, rank_in_group, comm):
        super(HaloExchangerAllGather, self).__init__(ranks, rank_in_group)
        # self.comm must be NCCL process_group created with torch.distributed.new_group(ranks=ranks)
        self.comm = comm

    def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
        # Halo tensors are unpacked as (N, Hh, W, C) — assumes explicit-NHWC
        # layout with the halo rows in dim 1; TODO confirm for callers.
        N,Hh,W,C = list(left_output_halo.shape)
        # Pack both outgoing halos into one contiguous buffer so a single
        # all_gather call suffices.
        send_halos = torch.empty((N,2*Hh,W,C),dtype=left_output_halo.dtype,device=left_output_halo.device)
        send_halos[:,:Hh,:,:].copy_(left_output_halo)
        send_halos[:,Hh:,:,:].copy_(right_output_halo)
        all_halos = torch.empty((N,2*Hh*self.group_size,W,C),dtype=left_output_halo.dtype,device=left_output_halo.device)
        all_halos = [all_halos[:,i*2*Hh:(i+1)*2*Hh,:,:] for i in range(self.group_size)]
        # NOTE(review): `no_copy` is not an argument of stock
        # torch.distributed.all_gather — presumably an NVIDIA-extended build;
        # verify against the target torch version.
        torch.distributed.all_gather(all_halos,send_halos,group=self.comm,no_copy=True)
        # My left neighbor's right-edge halo becomes my left input halo, and
        # symmetrically for the right side.
        ag_left_input_halo = all_halos[self.wrap_around_left_rank_in_group][:,Hh:,:,:]
        ag_right_input_halo = all_halos[self.wrap_around_right_rank_in_group][:,:Hh,:,:]
        if left_input_halo is None:
            # Out-of-place: return views into the gathered buffer, zeroing the
            # sides that have no real neighbor (group boundary).
            if self.left_zero:
                ag_left_input_halo.zero_()
            if self.right_zero:
                ag_right_input_halo.zero_()
            return ag_left_input_halo, ag_right_input_halo
        else:
            # In-place: write results into the caller-provided tensors.
            if self.left_zero:
                left_input_halo.zero_()
            else:
                left_input_halo.copy_(ag_left_input_halo)
            if self.right_zero:
                right_input_halo.zero_()
            else:
                right_input_halo.copy_(ag_right_input_halo)
68
+
69
class HaloExchangerSendRecv(HaloExchanger):
    """Halo exchange via explicit point-to-point transfers on a dedicated NCCL
    communicator provided by the nccl_p2p_cuda extension (imported as inc)."""
    def __init__(self, ranks, rank_in_group):
        super(HaloExchangerSendRecv, self).__init__(ranks, rank_in_group)
        # Bootstrap a fresh NCCL communicator: rank 0's unique id is broadcast
        # to all ranks through the default process group.
        nccl_id = inc.get_unique_nccl_id(1).cuda()
        torch.distributed.broadcast(nccl_id, 0)
        nccl_id = nccl_id.cpu()
        print("%d :: nccl_id = %s" % (torch.distributed.get_rank(), str(nccl_id)))
        # Create another global nccl communicator in addition to the one created by torch.distributed.init_process_group("nccl")
        # This is unavoidable because the underlying NCCL communicator torch.distributed creates is a protected variable, hence
        # it cannot be accessed from another class.
        # TODO: Figure out a way to avoid creating a second global communicator
        assert(torch.distributed.get_rank() == self.ranks[self.rank_in_group]), "ranks[%d](%d) != torch.distributed.get_rank()(%d)" % (self.rank_in_group, self.ranks[self.rank_in_group], torch.distributed.get_rank())
        self.handle = inc.init_nccl_comm(nccl_id, torch.distributed.get_rank(), torch.distributed.get_world_size())

    def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
        # left_rank/right_rank are -1 at the group boundary; presumably the
        # extension treats that sentinel as "no neighbor" — confirm in inc.
        if left_input_halo is None:
            # Out-of-place: the extension allocates and returns the inputs.
            left_input_halo, right_input_halo = inc.left_right_halo_exchange(self.handle, self.left_rank, self.right_rank , left_output_halo, right_output_halo)
            return left_input_halo, right_input_halo
        else:
            # In-place: results are written into the caller-provided tensors.
            inc.left_right_halo_exchange_inplace(self.handle, self.left_rank, self.right_rank, left_output_halo, right_output_halo, left_input_halo, right_input_halo)
89
+
90
class HaloExchangerPeer(HaloExchanger):
    """Halo exchange through CUDA peer-to-peer memory (peer_memory_cuda, pm).

    Each rank pushes its outgoing halos into peer-visible transfer buffers and
    pulls the neighbors' halos directly with a single fused kernel.
    """
    def __init__(self, ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=0):
        super(HaloExchangerPeer, self).__init__(ranks, rank_in_group)
        self.diagnostics = False
        self.explicit_nhwc = explicit_nhwc
        # numSM=0 presumably lets the kernel choose its own SM count — confirm
        # against the peer_memory_cuda extension.
        self.numSM = numSM
        self.peer_pool = peer_pool

    def _allocate_peer_tensor(self, halo):
        # Allocate one peer-visible transfer buffer per rank, sized for `halo`.

        # Compute size in bytes
        # Note: Pad buffer so each CUDA block gets required buffer size
        size = 4 * halo.numel() * halo.element_size()
        size_per_block = 128 * 2 * 16 # 128 threads each require two 128b buffers
        size = (size + size_per_block - 1) // size_per_block * size_per_block

        # Construct dtype peer buffer with desired size
        shape = [1, 1, 1, size // halo.element_size()]
        return self.peer_pool.allocate_peer_tensors(shape, halo.dtype, False, True)

    def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
        # In-place when the caller supplies destination tensors; otherwise
        # allocate outputs shaped like the opposite side's outgoing halo.
        inplace = False if left_input_halo is None and right_input_halo is None else True
        if not inplace:
            left_input_halo = torch.empty_like(right_output_halo)
            right_input_halo = torch.empty_like(left_output_halo)
        # NOTE(review): channels_last is computed but never used below — dead
        # local, or an argument the kernel used to take?
        channels_last = left_output_halo.is_contiguous(memory_format=torch.channels_last) and not self.explicit_nhwc
        left_tx = self._allocate_peer_tensor(left_input_halo)
        right_tx = self._allocate_peer_tensor(right_input_halo)
        # Per side: (zero flag, outgoing halo, my tx buffer, the neighbor's tx
        # buffer to pull from, destination tensor).
        pm.push_pull_halos_1d(
                self.diagnostics, self.explicit_nhwc, self.numSM, self.rank_in_group,
                self.left_zero, left_output_halo, left_tx[self.rank_in_group], right_tx[self.wrap_around_left_rank_in_group], left_input_halo,
                self.right_zero, right_output_halo, right_tx[self.rank_in_group], left_tx[self.wrap_around_right_rank_in_group], right_input_halo,
                )
        if not inplace:
            return left_input_halo, right_input_halo
125
+
126
+ # Class that combines input volume with halos from neighbors (1d).
127
+ class HaloPadder:
128
+ def __init__(self, halo_ex):
129
+ self.halo_ex = halo_ex
130
+ self.stream1 = torch.cuda.Stream()
131
+ self.stream2 = torch.cuda.Stream()
132
+
133
+ def __call__(self, y, half_halo, explicit_nhwc, H_split):
134
+ channels_last = not explicit_nhwc and y.is_contiguous(memory_format=torch.channels_last)
135
+ if explicit_nhwc:
136
+ N,H,W,C = list(y.shape)
137
+ if H_split:
138
+ padded_shape = [N,H+2*half_halo,W,C]
139
+ ypad = torch.empty(shape=padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.contiguous_format)
140
+ yleft = ypad[:,:half_halo,:,:]
141
+ ymid = ypad[:,half_halo:H+half_halo,:,:]
142
+ yright = ypad[:,H+half_halo:H+2*half_halo,:,:]
143
+ oleft = y[:,:half_halo,:,:]
144
+ oright = y[:,H-half_halo:,:,:]
145
+ else:
146
+ padded_shape = [N,H,W+2*half_halo,C]
147
+ ypad = torch.empty(shape=padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.contiguous_format)
148
+ yleft = ypad[:,:,:half_halo,:]
149
+ ymid = ypad[:,:,half_halo:W+half_halo,:]
150
+ yright = ypad[:,:,W+half_halo:W+2*half_halo,:]
151
+ oleft = y[:,:,:half_halo,:]
152
+ oright = y[:,:,W-half_halo:,:]
153
+ else:
154
+ N,C,H,W = list(y.shape)
155
+ if H_split:
156
+ padded_shape = [N,C,H+2*half_halo,W]
157
+ ypad = torch.empty(shape=padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.channels_last)
158
+ yleft = ypad[:,:,:half_halo,:]
159
+ ymid = ypad[:,:,half_halo:H+half_halo,:]
160
+ yright = ypad[:,:,H+half_halo:H+2*half_halo,:]
161
+ oleft = y[:,:,:half_halo,:]
162
+ oright = y[:,:,H-half_halo:,:]
163
+ else:
164
+ padded_shape = [N,C,H,W+2*half_halo]
165
+ ypad = torch.empty(shape=padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.channels_last)
166
+ yleft = ypad[:,:,:,:half_halo]
167
+ ymid = ypad[:,:,:,half_halo:W+half_halo]
168
+ yright = ypad[:,:,:,W+half_halo:W+2*half_halo]
169
+ oleft = y[:,:,:,:half_halo]
170
+ oright = y[:,:,:,W-half_halo:]
171
+ with torch.cuda.stream(self.stream1):
172
+ self.halo_ex(oleft, oright, yleft, yright)
173
+ with torch.cuda.stream(self.stream2):
174
+ ymid.copy_(y)
175
+ return ypad
176
+
177
+ def wait(self):
178
+ current_stream = torch.cuda.current_stream()
179
+ current_stream.wait_stream(self.stream1)
180
+ current_stream.wait_stream(self.stream2)
apex/contrib/bottleneck/test.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Numerical cross-check for the Bottleneck block: compares the native PyTorch
# path, the fused cuDNN channels_last path, and the explicit-NHWC path on the
# same inputs, printing max absolute errors for fprop, dgrad and wgrads.
# Requires a CUDA device.
import torch
from bottleneck import Bottleneck
torch.manual_seed(23337)

# use True to print layerwise sum for all outputs in reference code path
DEBUG = False#True

for stride, o_channel in [(1,32), (1,128), (2,32)]:
    print("testing stride ==", stride, ", in_channel == 32 , out_channel ==", o_channel)
    a_ = torch.randn(17,32,28,28)

    a = a_.cuda().half().to(memory_format=torch.channels_last).requires_grad_()
    model = Bottleneck(32,8,o_channel,stride=stride).cuda().half().to(memory_format=torch.channels_last)

    # test model
    # Reference pass: native ops (use_cudnn defaults to False).
    b = model(a)
    b.mean().backward()
    d_grad = a.grad.float()
    a.grad = None
    torch.cuda.synchronize()

    if DEBUG:
        print("[DEBUG] ref dx :", d_grad.sum().item())
        # print wgrad. we don't need to reset since later cpp print before accumulation
        for i, w in enumerate(model.w_conv):
            print("[DEBUG] ref wgrad{} :".format(i+1), w.grad.sum().item())

    # Snapshot reference weight grads before re-running with the fused path.
    wgrads = []
    for w in model.w_conv:
        wgrads.append(w.grad.float())

    # Second pass: same model, fused cuDNN path, same input.
    model.use_cudnn = True
    model.zero_grad()
    c = model(a)
    c.mean().backward()

    torch.cuda.synchronize()
    print("comparing native and channels_last:")
    print("max error fprop:", (b-c).abs().max().item(), "max elem:", b.abs().max().item())
    print("max error dgrad:", (d_grad-a.grad.float()).abs().max().item(), "max elem:", d_grad.abs().max().item())
    for i, (w, wgrad) in enumerate(zip(model.w_conv, wgrads)):
        print("max error wgrad{}:".format(i+1), (wgrad - w.grad.float()).abs().max().item(), "max elem:", wgrad.abs().max().item())

    # Third pass: explicit-NHWC model initialized from the channels_last model.
    nhwc_a = a_.permute(0,2,3,1).contiguous().cuda().half().requires_grad_()
    nhwc_model = Bottleneck(32,8,o_channel,stride=stride,explicit_nhwc=True, use_cudnn=True).cuda().half()
    for p,q in zip(model.parameters(), nhwc_model.parameters()):
        # model's storage is already in nhwc, we clone and assign to explicit nhwc model
        q.data.copy_(p.data.permute(0,2,3,1).contiguous())
    for p,q in zip(model.buffers(), nhwc_model.buffers()):
        q.data.copy_(p.data)

    d = nhwc_model(nhwc_a)
    d.mean().backward()
    torch.cuda.synchronize()

    # reset reference to cudnn channels_last permute
    #c_s = c.storage().tolist()
    #d_s = d.storage().tolist()
    #print(max([x-y for x,y in zip(c_s,d_s)]))
    # Permute the channels_last results into NHWC so they compare elementwise.
    c = c.contiguous(memory_format=torch.contiguous_format).permute(0,2,3,1).contiguous()
    d_grad = a.grad.float().permute(0,2,3,1).contiguous()
    wgrads = []
    for w in model.w_conv:
        wgrads.append(w.grad.float().permute(0,2,3,1).contiguous())

    torch.cuda.synchronize()
    print("comparing nhwc and channels_last:")
    print("max error fprop:", (d-c).abs().max().item(), "max elem:", c.abs().max().item())
    print("max error dgrad:", (d_grad-nhwc_a.grad.float()).abs().max().item(), "max elem:", d_grad.abs().max().item())
    for i, (w, wgrad) in enumerate(zip(nhwc_model.w_conv, wgrads)):
        print("max error wgrad{}:".format(i+1), (wgrad - w.grad.float()).abs().max().item(), "max elem:", wgrad.abs().max().item())
apex/contrib/clip_grad/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .clip_grad import clip_grad_norm_
apex/contrib/clip_grad/clip_grad.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import Union, Iterable

import torch

_kernel_import_succeeded = False
try:
    import amp_C
    from apex.multi_tensor_apply import multi_tensor_applier
    _kernel_import_succeeded = True
except ImportError:
    _kernel_import_succeeded = False

_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]


def clip_grad_norm_(
        parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0,
        error_if_nonfinite: bool = False) -> torch.Tensor:
    r"""Clip the gradient norm of an iterable of parameters, in place.

    Drop-in replacement for ``torch.nn.utils.clip_grad_norm_`` that dispatches
    to a fused multi-tensor CUDA kernel for the 2-norm of float32/float16
    gradients on the GPU; anything else (other norms, CPU-only parameters, or a
    missing ``amp_C`` extension) falls back to the stock implementation.

    Args:
        parameters: a Tensor or iterable of Tensors whose gradients to clip.
        max_norm: maximum allowed total norm of the gradients.
        norm_type: p of the p-norm; ``'inf'`` selects the infinity norm.
        error_if_nonfinite: raise if the total norm is nan/inf instead of
            scaling by it anyway.

    Returns:
        Total norm of the parameters (viewed as a single vector).
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    max_norm = float(max_norm)
    norm_type = float(norm_type)

    # Nothing to clip.
    if not parameters:
        return torch.tensor(0.)

    # Use the fused kernel only for the 2-norm with at least one GPU parameter.
    use_fused = (_kernel_import_succeeded
                 and norm_type == 2.0
                 and any(p.is_cuda for p in parameters))
    if not use_fused:
        return torch.nn.utils.clip_grad_norm_(
            parameters,
            max_norm,
            norm_type=norm_type,
            error_if_nonfinite=error_if_nonfinite,
        )

    # Bucket gradients: fp32/fp16 on the first CUDA device go through the
    # fused kernel, everything else is handled individually.
    device = next(p.device for p in parameters if p.is_cuda)
    grads_fp32 = []
    grads_fp16 = []
    grads_misc = []
    for p in parameters:
        g = p.grad.detach()
        if p.device == device and p.dtype == torch.float32:
            grads_fp32.append(g)
        elif p.device == device and p.dtype == torch.float16:
            grads_fp16.append(g)
        else:
            grads_misc.append(g)

    # Per-bucket L2 norms, combined into the total norm.
    dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device=device)
    norms = []
    for bucket in (grads_fp32, grads_fp16):
        if bucket:
            norms.append(
                multi_tensor_applier(
                    amp_C.multi_tensor_l2norm,
                    dummy_overflow_buf,
                    [bucket],
                    False,
                )[0]
            )
    for g in grads_misc:
        norms.append(torch.linalg.norm(g).unsqueeze(0).to(device))
    total_norm = torch.linalg.norm(torch.cat(norms))

    # Optionally reject nan/inf totals.
    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
        raise RuntimeError(
            f'The total norm of order {norm_type} for gradients from '
            '`parameters` is non-finite, so it cannot be clipped. To disable '
            'this error and scale the gradients by the non-finite norm anyway, '
            'set `error_if_nonfinite=False`')

    # Scale all gradients by min(1, max_norm / total_norm).
    clip_coef_clamped = torch.clamp(max_norm / (total_norm + 1e-6), max=1.0)
    for bucket in (grads_fp32, grads_fp16):
        if bucket:
            multi_tensor_applier(
                amp_C.multi_tensor_scale,
                dummy_overflow_buf,
                [bucket, bucket],
                clip_coef_clamped,
            )
    for g in grads_misc:
        g.mul_(clip_coef_clamped.to(g.device))

    return total_norm
apex/contrib/conv_bias_relu/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .conv_bias_relu import ConvBiasReLU, ConvBias, ConvBiasMaskReLU
2
+
apex/contrib/conv_bias_relu/conv_bias_relu.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pdb
2
+
3
+ import torch
4
+ from torch.autograd import gradcheck
5
+
6
+ from apex import check_cudnn_version_and_warn
7
+ import fused_conv_bias_relu
8
+
9
+ check_cudnn_version_and_warn(__name__, 8400)
10
+
11
+
12
class ConvBiasReLU_(torch.autograd.Function):
    """Autograd wrapper for the fused conv + bias + ReLU extension kernel.

    Under AMP, inputs are cast to half precision. The forward saves
    (x, weight, output) for the backward pass; backward returns gradients for
    (x, weight, bias) and None for the non-tensor (padding, stride) arguments.
    """
    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.half)
    def forward(ctx, x, weight, bias, padding, stride):
        result = fused_conv_bias_relu.forward([x, weight, bias], padding, stride)[0]
        ctx.save_for_backward(x, weight, result)
        ctx.padding = padding
        ctx.stride = stride
        return result

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, grad_output):
        input_grads = fused_conv_bias_relu.backward(
            [*ctx.saved_tensors, grad_output], ctx.padding, ctx.stride)
        return input_grads[0], input_grads[1], input_grads[2], None, None
32
+
33
+
34
class ConvBiasMaskReLU_(torch.autograd.Function):
    """Autograd wrapper for the fused conv + bias + mask + ReLU kernel.

    Same contract as ConvBiasReLU_ with an extra ``mask`` input that receives
    no gradient. NOTE(review): backward reuses the un-masked ``backward``
    kernel — presumably valid because the saved output already reflects the
    mask; confirm against the extension.
    """
    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.half)
    def forward(ctx, x, weight, bias, mask, padding, stride):
        result = fused_conv_bias_relu.forward_mask([x, weight, bias, mask], padding, stride)[0]
        ctx.save_for_backward(x, weight, result)
        ctx.padding = padding
        ctx.stride = stride
        return result

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, grad_output):
        input_grads = fused_conv_bias_relu.backward(
            [*ctx.saved_tensors, grad_output], ctx.padding, ctx.stride)
        # Six forward inputs -> six gradient slots; mask/padding/stride are
        # non-differentiable.
        return input_grads[0], input_grads[1], input_grads[2], None, None, None
54
+
55
+
56
class ConvBias_(torch.autograd.Function):
    """Autograd wrapper for the fused conv + bias kernel (no activation).

    Uses the *_no_relu entry points; only (x, weight) need to be saved since
    no activation output enters the backward computation.
    """
    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.half)
    def forward(ctx, x, weight, bias, padding, stride):
        result = fused_conv_bias_relu.forward_no_relu([x, weight, bias], padding, stride)[0]
        ctx.save_for_backward(x, weight)
        ctx.padding = padding
        ctx.stride = stride
        return result

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, grad_output):
        input_grads = fused_conv_bias_relu.backward_no_relu(
            [*ctx.saved_tensors, grad_output], ctx.padding, ctx.stride)
        return input_grads[0], input_grads[1], input_grads[2], None, None
76
+
77
+
78
# Public functional entry points (the .apply form of each autograd Function).
ConvBiasReLU = ConvBiasReLU_.apply
ConvBiasMaskReLU = ConvBiasMaskReLU_.apply
ConvBias = ConvBias_.apply
81
+
apex/contrib/csrc/bottleneck/bottleneck.cpp ADDED
The diff for this file is too large to render. See raw diff
 
apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp ADDED
@@ -0,0 +1,1639 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/ATen.h>
2
+ #include <ATen/cudnn/Handle.h> // for getcudnnhandle
3
+ #include <torch/extension.h>
4
+ #include <torch/torch.h>
5
+ #include <vector>
6
+ #include <cudnn_frontend.h>
7
+
8
+ #include <iostream>
9
+
10
+ #ifdef DEBUG
11
+ #define DEBUG_MSG(str) do { std::cout << str << std::endl; } while( false )
12
+ #else
13
+ #define DEBUG_MSG(str) do { } while ( false )
14
+ #endif
15
+
16
+ #ifdef DEBUG_CUDNN
17
+ #define DEBUG_CUDNN_MSG(buf, str) do { buf << str << std::endl; } while( false )
18
+ #else
19
+ #define DEBUG_CUDNN_MSG(buf, str) do { } while ( false )
20
+ #endif
21
+
22
+ #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
23
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(at::MemoryFormat::ChannelsLast), #x " must be contiguous")
24
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
25
+
26
+ #define checkCudnnErr(...) \
27
+ do { \
28
+ int err = checkCudnnError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); \
29
+ if (err) { \
30
+ return; \
31
+ } \
32
+ } while (0)
33
+
34
+
35
+ int checkCudnnError(cudnnStatus_t code, const char* expr, const char* file, int line) {
36
+ if (code) {
37
+ printf("CUDNN error at %s:%d, code=%d (%s) in '%s'\n", file, line, (int)code, cudnnGetErrorString(code), expr);
38
+ return 1;
39
+ }
40
+ return 0;
41
+ }
42
+
43
+ void checkError(cudaError_t code, char const * func, const char *file, const int line, bool abort = true);
44
+ #define checkCUDAError(val) { checkError((val), #val, __FILE__, __LINE__); } // in-line regular function
45
+
46
+ void checkError(cudaError_t code, char const * func, const char *file, const int line, bool abort) {
47
+ if (code != cudaSuccess)
48
+ {
49
+ const char * errorMessage = cudaGetErrorString(code);
50
+ fprintf(stderr, "CUDA error returned from \"%s\" at %s:%d, Error code: %d (%s)\n", func, file, line, code, errorMessage);
51
+ if (abort){
52
+ cudaDeviceReset();
53
+ exit(code);
54
+ }
55
+ }
56
+ }
57
+
58
+ void generateStrides(const int64_t* dimA, int64_t* strideA, int nbDims, cudnnTensorFormat_t filterFormat) {
59
+ // For INT8x4 and INT8x32 we still compute standard strides here to input
60
+ // into the cuDNN functions. We will manually scale by resizeFactor in the cpu ref.
61
+ if (filterFormat == CUDNN_TENSOR_NCHW) {
62
+ strideA[nbDims - 1] = 1;
63
+ for (int64_t d = nbDims - 2; d >= 0; d--) {
64
+ strideA[d] = strideA[d + 1] * dimA[d + 1];
65
+ }
66
+ } else {
67
+ // Here we assume that the format is CUDNN_TENSOR_NHWC
68
+ strideA[1] = 1;
69
+ strideA[nbDims - 1] = strideA[1] * dimA[1];
70
+ for (int64_t d = nbDims - 2; d >= 2; d--) {
71
+ strideA[d] = strideA[d + 1] * dimA[d + 1];
72
+ }
73
+ strideA[0] = strideA[2] * dimA[2];
74
+ }
75
+ }
76
+
77
+
78
// Effective filter extent after dilation: dilation-1 implicit zeros are
// inserted between adjacent taps.
int getFwdConvDilatedFilterDim(int filterDim, int dilation) {
    return dilation * (filterDim - 1) + 1;
}
81
+
82
+
83
// Spatial extent after symmetric zero-padding (pad added on both sides).
int getFwdConvPaddedImageDim(int tensorDim, int pad) {
    return tensorDim + 2 * pad;
}
86
+
87
+
88
+ int getFwdConvOutputDim(int tensorDim,
89
+ int pad,
90
+ int filterDim,
91
+ int stride,
92
+ int dilation) {
93
+ int p = (getFwdConvPaddedImageDim(tensorDim, pad) - getFwdConvDilatedFilterDim(filterDim, dilation)) / stride + 1;
94
+ return (p);
95
+ }
96
+
97
+
98
// Process-wide cache of built cuDNN execution plans, keyed by the problem
// string produced by getConvFusionString (dims + pad + stride + dilation +
// dtype + graph tag). NOTE(review): access is unsynchronized — this assumes
// plans are created/looked up from one thread at a time; confirm before
// calling these ops from multiple CPU threads concurrently.
std::unordered_map<std::string, cudnn_frontend::ExecutionPlan> plan_cache;
100
+
101
+
102
+ std::string getConvFusionString(int64_t* x_dim_padded,
103
+ int64_t* padA,
104
+ int64_t* convstrideA,
105
+ int64_t* dilationA,
106
+ int64_t* w_dim_padded,
107
+ cudnnDataType_t dataType,
108
+ std::string fusion_string) {
109
+
110
+ for(int i=0;i<4;i++) {
111
+ fusion_string += 'X';
112
+ fusion_string += std::to_string(x_dim_padded[i]);
113
+ }
114
+ for(int i=0;i<4;i++) {
115
+ fusion_string += 'W';
116
+ fusion_string += std::to_string(w_dim_padded[i]);
117
+ }
118
+ for(int i=0;i<2;i++) {
119
+ fusion_string += 'P';
120
+ fusion_string += std::to_string(padA[i]);
121
+ }
122
+ for(int i=0;i<2;i++) {
123
+ fusion_string += 'S';
124
+ fusion_string += std::to_string(convstrideA[i]);
125
+ }
126
+ for(int i=0;i<2;i++) {
127
+ fusion_string += 'D';
128
+ fusion_string += std::to_string(dilationA[i]);
129
+ }
130
+ fusion_string += 'T';
131
+ fusion_string += std::to_string(dataType);
132
+ return fusion_string;
133
+ }
134
+
135
+
136
+ cudnn_frontend::ExecutionPlan& getOrCreatePlan(cudnnHandle_t handle_,
137
+ std::stringstream& log_buf,
138
+ cudnn_frontend::OperationGraph& opGraph,
139
+ std::string cache_string,
140
+ bool use_heuristic = true){
141
+ auto it = plan_cache.find(cache_string);
142
+ if (it != plan_cache.end()) {
143
+ DEBUG_CUDNN_MSG(log_buf, "Found plan in cache");
144
+ return it->second;
145
+ } else {
146
+ if (use_heuristic){
147
+ // TODO: confirm which mode to use
148
+ auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
149
+ .setOperationGraph(opGraph)
150
+ .setHeurMode(CUDNN_HEUR_MODE_INSTANT)
151
+ .build();
152
+ // try 3 times for now as WAR for no heuristic training
153
+ int max_tries = 3, count = 0;
154
+ auto& engine_configs = heuristics.getEngineConfig(max_tries);
155
+ while(true) {
156
+ try {
157
+ plan_cache.emplace(cache_string, std::move(cudnn_frontend::ExecutionPlanBuilder()
158
+ .setHandle(handle_)
159
+ .setEngineConfig(engine_configs[count], opGraph.getTag())
160
+ .build()));
161
+ break;
162
+ } catch (cudnn_frontend::cudnnException e) {
163
+ if (++count == max_tries) throw e;
164
+ }
165
+ }
166
+ }else{
167
+ DEBUG_CUDNN_MSG(log_buf, "No plan in cache");
168
+ // How many engines support this operation graph ?
169
+ auto total_engines = opGraph.getEngineCount();
170
+ DEBUG_CUDNN_MSG(log_buf, opGraph.describe() << " has " << total_engines << " engines.");
171
+ // We have to randomly pick one engine from [0, total_engines)
172
+ // Selecting "0" by default
173
+ auto engine = cudnn_frontend::EngineBuilder().setGlobalEngineIdx(0).setOperationGraph(opGraph).build();
174
+ DEBUG_CUDNN_MSG(log_buf, engine.describe());
175
+ auto& knobs = engine.getSupportedKnobs();
176
+ for (auto it = std::begin(knobs); it != std::end(knobs); ++it) {
177
+ DEBUG_CUDNN_MSG(log_buf, it->describe());
178
+ }
179
+ if (knobs.begin() != knobs.end()) {
180
+ DEBUG_CUDNN_MSG(log_buf, "Updated knob choice");
181
+ knobs.begin()->setChoice(knobs.begin()->getMinValue() + 1);
182
+ DEBUG_CUDNN_MSG(log_buf, knobs.begin()->describe());
183
+ }
184
+
185
+ // Createmplacee the requisite engine config
186
+ auto engine_config = cudnn_frontend::EngineConfigBuilder().setEngine(engine).build();
187
+ DEBUG_CUDNN_MSG(log_buf, engine_config.describe());
188
+ plan_cache.emplace(cache_string, std::move(cudnn_frontend::ExecutionPlanBuilder().setHandle(handle_).setEngineConfig(engine_config).build()));
189
+ }
190
+
191
+ return plan_cache.find(cache_string)->second;
192
+ }
193
+ }
194
+
195
+
196
+ void
197
+ run_conv_bias(int64_t* x_dim,
198
+ int64_t* w_dim,
199
+ int64_t* y_dim,
200
+ int64_t* conv_pad,
201
+ int64_t* convstride,
202
+ int64_t* dilation,
203
+ cudnnDataType_t dataType,
204
+ at::Half* devPtrX,
205
+ at::Half* devPtrW,
206
+ at::Half* devPtrB,
207
+ at::Half* devPtrY) {
208
+
209
+ cudnnHandle_t handle_ = torch::native::getCudnnHandle();
210
+ std::stringstream log_buf;
211
+
212
+ try {
213
+ int convDim = 2;
214
+ float alpha = 1.0f;
215
+ float beta = 0.0f;
216
+ int64_t b_dim[] = {1, y_dim[1], 1, 1};
217
+
218
+ // Creates the necessary tensor descriptors
219
+ int64_t stride[4];
220
+ generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);
221
+ auto xTensor = cudnn_frontend::TensorBuilder()
222
+ .setDim(4, x_dim)
223
+ .setStrides(4, stride)
224
+ .setId('x')
225
+ .setAlignment(16)
226
+ .setDataType(dataType)
227
+ .build();
228
+ DEBUG_CUDNN_MSG(log_buf, xTensor.describe());
229
+
230
+ generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC);
231
+ auto wTensor = cudnn_frontend::TensorBuilder()
232
+ .setDim(4, w_dim)
233
+ .setStrides(4, stride)
234
+ .setId('w')
235
+ .setAlignment(16)
236
+ .setDataType(dataType)
237
+ .build();
238
+ DEBUG_CUDNN_MSG(log_buf, wTensor.describe());
239
+
240
+ generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
241
+ auto afterConvTensor = cudnn_frontend::TensorBuilder()
242
+ .setDim(4, y_dim)
243
+ .setStrides(4, stride)
244
+ .setId('c')
245
+ .setAlignment(16)
246
+ .setDataType(CUDNN_DATA_FLOAT)
247
+ .setVirtual()
248
+ .build();
249
+ DEBUG_CUDNN_MSG(log_buf, afterConvTensor.describe());
250
+
251
+ generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC);
252
+ auto bTensor = cudnn_frontend::TensorBuilder()
253
+ .setDim(4, b_dim)
254
+ .setStrides(4, stride)
255
+ .setId('b')
256
+ .setAlignment(16)
257
+ .setDataType(dataType)
258
+ .build();
259
+ DEBUG_CUDNN_MSG(log_buf, bTensor.describe());
260
+
261
+ generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
262
+ auto afterBiasTensor = cudnn_frontend::TensorBuilder()
263
+ .setDim(4, y_dim)
264
+ .setStrides(4, stride)
265
+ .setId('y')
266
+ .setAlignment(16)
267
+ .setDataType(dataType)
268
+ .build();
269
+ DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe());
270
+
271
+ // Define the bias operation
272
+ auto biasDesc = cudnn_frontend::PointWiseDescBuilder()
273
+ .setMode(CUDNN_POINTWISE_ADD)
274
+ .setMathPrecision(CUDNN_DATA_FLOAT)
275
+ .build();
276
+ DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
277
+
278
+ // Define the convolution problem
279
+ auto convDesc = cudnn_frontend::ConvDescBuilder()
280
+ .setDataType(CUDNN_DATA_FLOAT)
281
+ .setMathMode(CUDNN_CROSS_CORRELATION)
282
+ .setNDims(convDim)
283
+ .setStrides(convDim, convstride)
284
+ .setPrePadding(convDim, conv_pad)
285
+ .setPostPadding(convDim, conv_pad)
286
+ .setDilation(convDim, dilation)
287
+ .build();
288
+ DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
289
+
290
+
291
+ // Create a convolution Node
292
+ auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
293
+ .setxDesc(xTensor)
294
+ .setwDesc(wTensor)
295
+ .setyDesc(afterConvTensor)
296
+ .setcDesc(convDesc)
297
+ .setAlpha(alpha)
298
+ .setBeta(beta)
299
+ .build();
300
+ DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
301
+
302
+ // Create a Bias Node.
303
+ auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
304
+ .setxDesc(conv_op.getOutputTensor())
305
+ .setbDesc(bTensor)
306
+ .setyDesc(afterBiasTensor)
307
+ .setpwDesc(biasDesc)
308
+ .build();
309
+ DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
310
+
311
+ // Create an Operation Graph. In this case it is convolution bias activation
312
+ std::array<cudnn_frontend::Operation const*, 2> ops = {&conv_op, &bias_op};
313
+
314
+ auto opGraph = cudnn_frontend::OperationGraphBuilder()
315
+ .setHandle(handle_)
316
+ .setOperationGraph(2, ops.data())
317
+ .build();
318
+
319
+ // Create string encoding for plan caching
320
+ auto cache_string = getConvFusionString(x_dim, conv_pad, convstride, dilation, w_dim, dataType, opGraph.getTag());
321
+ DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
322
+
323
+ auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
324
+ DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
325
+
326
+ auto workspace_size = plan.getWorkspaceSize();
327
+ DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
328
+
329
+ void* workspace_ptr = nullptr;
330
+ auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
331
+ if (workspace_size > 0) {
332
+ workspace_ptr = workspace_tensor.data_ptr<float>();
333
+ }
334
+ void* data_ptrs[] = {devPtrX, devPtrW, devPtrB, devPtrY};
335
+ int64_t uids[] = {'x', 'w', 'b', 'y'};
336
+ auto variantPack = cudnn_frontend::VariantPackBuilder()
337
+ .setWorkspacePointer(workspace_ptr)
338
+ .setDataPointers(4, data_ptrs)
339
+ .setUids(4, uids)
340
+ .build();
341
+ DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
342
+ cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
343
+ checkCudnnErr(status);
344
+ cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
345
+ } catch (cudnn_frontend::cudnnException e) {
346
+ std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
347
+ }
348
+ }
349
+
350
+
351
+ void
352
+ run_conv_bias_mask_relu(int64_t* x_dim,
353
+ int64_t* w_dim,
354
+ int64_t* y_dim,
355
+ int64_t* conv_pad,
356
+ int64_t* conv_stride,
357
+ int64_t* conv_dilation,
358
+ cudnnDataType_t dataType,
359
+ at::Half* devPtrX,
360
+ at::Half* devPtrW,
361
+ at::Half* devPtrB,
362
+ int8_t* devPtrM,
363
+ at::Half* devPtrY) {
364
+
365
+ cudnnHandle_t handle_ = torch::native::getCudnnHandle();
366
+ std::stringstream log_buf;
367
+
368
+ try {
369
+ int conv_dim = 2;
370
+ float alpha = 1.0f;
371
+ float beta = 0.0f;
372
+ int64_t b_dim[] = {1, y_dim[1], 1, 1};
373
+
374
+ // Creates the necessary tensor descriptors
375
+ int64_t stride[4];
376
+ generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);
377
+ auto xTensor = cudnn_frontend::TensorBuilder()
378
+ .setDim(4, x_dim)
379
+ .setStrides(4, stride)
380
+ .setId('x')
381
+ .setAlignment(16)
382
+ .setDataType(dataType)
383
+ .build();
384
+ DEBUG_CUDNN_MSG(log_buf, xTensor.describe());
385
+
386
+ generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC);
387
+ auto wTensor = cudnn_frontend::TensorBuilder()
388
+ .setDim(4, w_dim)
389
+ .setStrides(4, stride)
390
+ .setId('w')
391
+ .setAlignment(16)
392
+ .setDataType(dataType)
393
+ .build();
394
+ DEBUG_CUDNN_MSG(log_buf, wTensor.describe());
395
+
396
+ generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
397
+ auto mTensor = cudnn_frontend::TensorBuilder()
398
+ .setDim(4, y_dim)
399
+ .setStrides(4, stride)
400
+ .setId('m')
401
+ .setAlignment(16)
402
+ .setDataType(CUDNN_DATA_INT8)
403
+ .build();
404
+ DEBUG_CUDNN_MSG(log_buf, wTensor.describe());
405
+
406
+ generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
407
+ auto afterConvTensor = cudnn_frontend::TensorBuilder()
408
+ .setDim(4, y_dim)
409
+ .setStrides(4, stride)
410
+ .setId('c')
411
+ .setAlignment(16)
412
+ .setDataType(CUDNN_DATA_FLOAT)
413
+ .setVirtual()
414
+ .build();
415
+ DEBUG_CUDNN_MSG(log_buf, afterConvTensor.describe());
416
+
417
+ generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC);
418
+ auto bTensor = cudnn_frontend::TensorBuilder()
419
+ .setDim(4, b_dim)
420
+ .setStrides(4, stride)
421
+ .setId('b')
422
+ .setAlignment(16)
423
+ .setDataType(dataType)
424
+ .build();
425
+ DEBUG_CUDNN_MSG(log_buf, bTensor.describe());
426
+
427
+ generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
428
+ auto afterBiasTensor = cudnn_frontend::TensorBuilder()
429
+ .setDim(4, y_dim)
430
+ .setStrides(4, stride)
431
+ .setId('B')
432
+ .setAlignment(16)
433
+ .setDataType(CUDNN_DATA_FLOAT)
434
+ .setVirtual()
435
+ .build();
436
+ DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe());
437
+
438
+ generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
439
+ auto afterMaskTensor = cudnn_frontend::TensorBuilder()
440
+ .setDim(4, y_dim)
441
+ .setStrides(4, stride)
442
+ .setId('M')
443
+ .setAlignment(16)
444
+ .setDataType(CUDNN_DATA_FLOAT)
445
+ .setVirtual()
446
+ .build();
447
+ DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe());
448
+
449
+ generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
450
+ auto afterReLUTensor = cudnn_frontend::TensorBuilder()
451
+ .setDim(4, y_dim)
452
+ .setStrides(4, stride)
453
+ .setId('y')
454
+ .setAlignment(16)
455
+ .setDataType(dataType)
456
+ .build();
457
+ DEBUG_CUDNN_MSG(log_buf, afterReLUTensor.describe());
458
+
459
+ // Define the convolution problem
460
+ auto convDesc = cudnn_frontend::ConvDescBuilder()
461
+ .setDataType(CUDNN_DATA_FLOAT)
462
+ .setMathMode(CUDNN_CROSS_CORRELATION)
463
+ .setNDims(conv_dim)
464
+ .setStrides(conv_dim, conv_stride)
465
+ .setPrePadding(conv_dim, conv_pad)
466
+ .setPostPadding(conv_dim, conv_pad)
467
+ .setDilation(conv_dim, conv_dilation)
468
+ .build();
469
+ DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
470
+
471
+ // Define the bias operation
472
+ auto biasDesc = cudnn_frontend::PointWiseDescBuilder()
473
+ .setMode(CUDNN_POINTWISE_ADD)
474
+ .setMathPrecision(CUDNN_DATA_FLOAT)
475
+ .build();
476
+ DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
477
+
478
+ // Define the mask operation
479
+ auto maskDesc = cudnn_frontend::PointWiseDescBuilder()
480
+ .setMode(CUDNN_POINTWISE_MUL)
481
+ .setMathPrecision(CUDNN_DATA_FLOAT)
482
+ .build();
483
+
484
+ // Define the activation operation
485
+ auto actDesc = cudnn_frontend::PointWiseDescBuilder()
486
+ .setMode(CUDNN_POINTWISE_RELU_FWD)
487
+ .setMathPrecision(CUDNN_DATA_FLOAT)
488
+ .build();
489
+ DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
490
+
491
+ // Create a convolution Node
492
+ auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
493
+ .setxDesc(xTensor)
494
+ .setwDesc(wTensor)
495
+ .setyDesc(afterConvTensor)
496
+ .setcDesc(convDesc)
497
+ .setAlpha(alpha)
498
+ .setBeta(beta)
499
+ .build();
500
+ DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
501
+
502
+ // Create a Bias Node
503
+ auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
504
+ .setxDesc(conv_op.getOutputTensor())
505
+ .setbDesc(bTensor)
506
+ .setyDesc(afterBiasTensor)
507
+ .setpwDesc(biasDesc)
508
+ .build();
509
+ DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
510
+
511
+ // create a Mask Node
512
+ auto mask_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
513
+ .setxDesc(bias_op.getOutputTensor())
514
+ .setbDesc(mTensor)
515
+ .setyDesc(afterMaskTensor)
516
+ .setpwDesc(maskDesc)
517
+ .build();
518
+ DEBUG_CUDNN_MSG(log_buf, mask_op.describe());
519
+
520
+ // Create an Activation Node
521
+ auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
522
+ .setxDesc(mask_op.getOutputTensor())
523
+ .setyDesc(afterReLUTensor)
524
+ .setpwDesc(actDesc)
525
+ .build();
526
+ DEBUG_CUDNN_MSG(log_buf, act_op.describe());
527
+
528
+ // Create an Operation Graph. In this case it is convolution bias activation
529
+ std::array<cudnn_frontend::Operation const*, 4> ops = {&conv_op, &bias_op, &mask_op, &act_op};
530
+
531
+ auto opGraph = cudnn_frontend::OperationGraphBuilder()
532
+ .setHandle(handle_)
533
+ .setOperationGraph(4, ops.data())
534
+ .build();
535
+
536
+ // Create string encoding for plan caching
537
+ auto cache_string = getConvFusionString(x_dim, conv_pad, conv_stride, conv_dilation, w_dim, dataType, opGraph.getTag());
538
+ DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
539
+
540
+ auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
541
+ DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
542
+
543
+ auto workspace_size = plan.getWorkspaceSize();
544
+ DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
545
+
546
+ void* workspace_ptr = nullptr;
547
+ auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
548
+ if (workspace_size > 0) {
549
+ workspace_ptr = workspace_tensor.data_ptr<float>();
550
+ }
551
+ void* data_ptrs[] = {devPtrX, devPtrW, devPtrB, devPtrM, devPtrY};
552
+ int64_t uids[] = {'x', 'w', 'b', 'm', 'y'};
553
+ auto variantPack = cudnn_frontend::VariantPackBuilder()
554
+ .setWorkspacePointer(workspace_ptr)
555
+ .setDataPointers(5, data_ptrs)
556
+ .setUids(5, uids)
557
+ .build();
558
+ DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
559
+ cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
560
+ checkCudnnErr(status);
561
+ cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
562
+ } catch (cudnn_frontend::cudnnException e) {
563
+ std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
564
+ }
565
+ }
566
+
567
+
568
+ void
569
+ run_conv_bias_relu(int64_t* x_dim,
570
+ int64_t* w_dim,
571
+ int64_t* y_dim,
572
+ int64_t* conv_pad,
573
+ int64_t* conv_stride,
574
+ int64_t* conv_dilation,
575
+ cudnnDataType_t dataType,
576
+ at::Half* devPtrX,
577
+ at::Half* devPtrW,
578
+ at::Half* devPtrB,
579
+ at::Half* devPtrY) {
580
+
581
+ cudnnHandle_t handle_ = torch::native::getCudnnHandle();
582
+ std::stringstream log_buf;
583
+
584
+ try {
585
+ int conv_dim = 2;
586
+ float alpha = 1.0f;
587
+ float beta = 0.0f;
588
+ int64_t b_dim[] = {1, y_dim[1], 1, 1};
589
+
590
+ // Creates the necessary tensor descriptors
591
+ int64_t stride[4];
592
+ generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);
593
+ auto xTensor = cudnn_frontend::TensorBuilder()
594
+ .setDim(4, x_dim)
595
+ .setStrides(4, stride)
596
+ .setId('x')
597
+ .setAlignment(16)
598
+ .setDataType(dataType)
599
+ .build();
600
+ DEBUG_CUDNN_MSG(log_buf, xTensor.describe());
601
+
602
+ generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC);
603
+ auto wTensor = cudnn_frontend::TensorBuilder()
604
+ .setDim(4, w_dim)
605
+ .setStrides(4, stride)
606
+ .setId('w')
607
+ .setAlignment(16)
608
+ .setDataType(dataType)
609
+ .build();
610
+ DEBUG_CUDNN_MSG(log_buf, wTensor.describe());
611
+
612
+ generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
613
+ auto afterConvTensor = cudnn_frontend::TensorBuilder()
614
+ .setDim(4, y_dim)
615
+ .setStrides(4, stride)
616
+ .setId('c')
617
+ .setAlignment(16)
618
+ .setDataType(CUDNN_DATA_FLOAT)
619
+ .setVirtual()
620
+ .build();
621
+ DEBUG_CUDNN_MSG(log_buf, afterConvTensor.describe());
622
+
623
+ generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC);
624
+ auto bTensor = cudnn_frontend::TensorBuilder()
625
+ .setDim(4, b_dim)
626
+ .setStrides(4, stride)
627
+ .setId('b')
628
+ .setAlignment(16)
629
+ .setDataType(dataType)
630
+ .build();
631
+ DEBUG_CUDNN_MSG(log_buf, bTensor.describe());
632
+
633
+ generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
634
+ auto afterBiasTensor = cudnn_frontend::TensorBuilder()
635
+ .setDim(4, y_dim)
636
+ .setStrides(4, stride)
637
+ .setId('B')
638
+ .setAlignment(16)
639
+ .setDataType(CUDNN_DATA_FLOAT)
640
+ .setVirtual()
641
+ .build();
642
+ DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe());
643
+
644
+ generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
645
+ auto afterReLUTensor = cudnn_frontend::TensorBuilder()
646
+ .setDim(4, y_dim)
647
+ .setStrides(4, stride)
648
+ .setId('y')
649
+ .setAlignment(16)
650
+ .setDataType(dataType)
651
+ .build();
652
+ DEBUG_CUDNN_MSG(log_buf, afterReLUTensor.describe());
653
+
654
+ // Define the convolution problem
655
+ auto convDesc = cudnn_frontend::ConvDescBuilder()
656
+ .setDataType(CUDNN_DATA_FLOAT)
657
+ .setMathMode(CUDNN_CROSS_CORRELATION)
658
+ .setNDims(conv_dim)
659
+ .setStrides(conv_dim, conv_stride)
660
+ .setPrePadding(conv_dim, conv_pad)
661
+ .setPostPadding(conv_dim, conv_pad)
662
+ .setDilation(conv_dim, conv_dilation)
663
+ .build();
664
+ DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
665
+
666
+ // Define the bias operation
667
+ auto biasDesc = cudnn_frontend::PointWiseDescBuilder()
668
+ .setMode(CUDNN_POINTWISE_ADD)
669
+ .setMathPrecision(CUDNN_DATA_FLOAT)
670
+ .build();
671
+ DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
672
+
673
+ // Define the activation operation
674
+ auto actDesc = cudnn_frontend::PointWiseDescBuilder()
675
+ .setMode(CUDNN_POINTWISE_RELU_FWD)
676
+ .setMathPrecision(CUDNN_DATA_FLOAT)
677
+ .build();
678
+ DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
679
+
680
+ // Create a convolution Node
681
+ auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
682
+ .setxDesc(xTensor)
683
+ .setwDesc(wTensor)
684
+ .setyDesc(afterConvTensor)
685
+ .setcDesc(convDesc)
686
+ .setAlpha(alpha)
687
+ .setBeta(beta)
688
+ .build();
689
+ DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
690
+
691
+ // Create a Bias Node.
692
+ auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
693
+ .setxDesc(conv_op.getOutputTensor())
694
+ .setbDesc(bTensor)
695
+ .setyDesc(afterBiasTensor)
696
+ .setpwDesc(biasDesc)
697
+ .build();
698
+ DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
699
+
700
+ // Create an Activation Node.
701
+ auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
702
+ .setxDesc(bias_op.getOutputTensor())
703
+ .setyDesc(afterReLUTensor)
704
+ .setpwDesc(actDesc)
705
+ .build();
706
+ DEBUG_CUDNN_MSG(log_buf, act_op.describe());
707
+
708
+ // Create an Operation Graph. In this case it is convolution bias activation
709
+ std::array<cudnn_frontend::Operation const*, 3> ops = {&conv_op, &bias_op, &act_op};
710
+
711
+ auto opGraph = cudnn_frontend::OperationGraphBuilder()
712
+ .setHandle(handle_)
713
+ .setOperationGraph(3, ops.data())
714
+ .build();
715
+
716
+ // Create string encoding for plan caching
717
+ auto cache_string = getConvFusionString(x_dim, conv_pad, conv_stride, conv_dilation, w_dim, dataType, opGraph.getTag());
718
+ DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
719
+
720
+ auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
721
+ DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
722
+
723
+ auto workspace_size = plan.getWorkspaceSize();
724
+ DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
725
+
726
+ void* workspace_ptr = nullptr;
727
+ auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
728
+ if (workspace_size > 0) {
729
+ workspace_ptr = workspace_tensor.data_ptr<float>();
730
+ }
731
+ void* data_ptrs[] = {devPtrX, devPtrW, devPtrB, devPtrY};
732
+ int64_t uids[] = {'x', 'w', 'b', 'y'};
733
+ auto variantPack = cudnn_frontend::VariantPackBuilder()
734
+ .setWorkspacePointer(workspace_ptr)
735
+ .setDataPointers(4, data_ptrs)
736
+ .setUids(4, uids)
737
+ .build();
738
+ DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
739
+ cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
740
+ checkCudnnErr(status);
741
+ cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
742
+ } catch (cudnn_frontend::cudnnException e) {
743
+ std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
744
+ }
745
+ }
746
+
747
+
748
+ void
749
+ run_drelu_dbias(int64_t* dy_dim,
750
+ cudnnDataType_t dataType,
751
+ at::Half* devPtrDY,
752
+ at::Half* devPtrR,
753
+ at::Half* devPtrDR,
754
+ float* devPtrDB) {
755
+
756
+ cudnnHandle_t handle_ = torch::native::getCudnnHandle();
757
+ std::stringstream log_buf;
758
+
759
+ try {
760
+ int convDim = 2;
761
+ float alpha = 1.0f;
762
+ float beta = 0.0f;
763
+ int64_t b_dim[] = {1, dy_dim[1], 1, 1};
764
+
765
+ // Creates the necessary tensor descriptors
766
+ int64_t stride[4];
767
+ generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC);
768
+ auto dyTensor = cudnn_frontend::TensorBuilder()
769
+ .setDim(4, dy_dim)
770
+ .setStrides(4, stride)
771
+ .setId('x')
772
+ .setAlignment(16)
773
+ .setDataType(dataType)
774
+ .build();
775
+ DEBUG_CUDNN_MSG(log_buf, dyTensor.describe());
776
+
777
+ generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC);
778
+ auto rTensor = cudnn_frontend::TensorBuilder()
779
+ .setDim(4, dy_dim)
780
+ .setStrides(4, stride)
781
+ .setId('r')
782
+ .setAlignment(16)
783
+ .setDataType(dataType)
784
+ .build();
785
+ DEBUG_CUDNN_MSG(log_buf, rTensor.describe());
786
+
787
+ generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC);
788
+ auto inActGradTensor = cudnn_frontend::TensorBuilder()
789
+ .setDim(4, dy_dim)
790
+ .setStrides(4, stride)
791
+ .setId('R')
792
+ .setAlignment(16)
793
+ .setDataType(dataType)
794
+ .build();
795
+ DEBUG_CUDNN_MSG(log_buf, inActGradTensor.describe());
796
+
797
+ generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC);
798
+ auto biasGradTensor = cudnn_frontend::TensorBuilder()
799
+ .setDim(4, b_dim)
800
+ .setStrides(4, stride)
801
+ .setId('y')
802
+ .setAlignment(16)
803
+ .setDataType(CUDNN_DATA_FLOAT)
804
+ .build();
805
+ DEBUG_CUDNN_MSG(log_buf, biasGradTensor.describe());
806
+
807
+ // Define the activation backward operation
808
+ auto actDesc = cudnn_frontend::PointWiseDescBuilder()
809
+ .setMode(CUDNN_POINTWISE_RELU_BWD)
810
+ .setMathPrecision(CUDNN_DATA_FLOAT)
811
+ .build();
812
+ DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
813
+
814
+ // Define the bias backward operation
815
+ auto biasDesc = cudnn_frontend::ReductionDescBuilder()
816
+ .setMathPrecision(CUDNN_DATA_FLOAT)
817
+ .setReductionOp(CUDNN_REDUCE_TENSOR_ADD)
818
+ .build();
819
+ DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
820
+
821
+ // Create an relu backward Node
822
+ auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
823
+ .setdyDesc(dyTensor)
824
+ .setxDesc(rTensor)
825
+ .setdxDesc(inActGradTensor)
826
+ .setpwDesc(actDesc)
827
+ .build();
828
+ DEBUG_CUDNN_MSG(log_buf, act_op.describe());
829
+
830
+ // Create bias node
831
+ auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)
832
+ .setxDesc(inActGradTensor)
833
+ .setyDesc(biasGradTensor)
834
+ .setreductionDesc(biasDesc)
835
+ .build();
836
+ DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
837
+
838
+ // Create an Operation Graph. In this case it is bias only
839
+ std::array<cudnn_frontend::Operation const*, 2> ops = {&act_op, &bias_op};
840
+
841
+ auto opGraph = cudnn_frontend::OperationGraphBuilder()
842
+ .setHandle(handle_)
843
+ .setOperationGraph(ops.size(), ops.data())
844
+ .build();
845
+
846
+ // Create string encoding for plan caching
847
+ // creating unique dummy values
848
+ int64_t pad_dummy[] = {20, 20};
849
+ int64_t stride_dummy[] = {20, 20};
850
+ int64_t dilation_dummy[] = {20, 20};
851
+ auto cache_string = getConvFusionString(dy_dim, pad_dummy, stride_dummy, dilation_dummy, b_dim, dataType, opGraph.getTag());
852
+ DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
853
+
854
+ auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
855
+ DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
856
+
857
+ auto workspace_size = plan.getWorkspaceSize();
858
+ DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
859
+
860
+ void* workspace_ptr = nullptr;
861
+ auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
862
+ if (workspace_size > 0) {
863
+ workspace_ptr = workspace_tensor.data_ptr<float>();
864
+ }
865
+ void* data_ptrs[] = {devPtrDY, devPtrR, devPtrDR, devPtrDB};
866
+ int64_t uids[] = {'x', 'r', 'R', 'y'};
867
+ auto variantPack = cudnn_frontend::VariantPackBuilder()
868
+ .setWorkspacePointer(workspace_ptr)
869
+ .setDataPointers(4, data_ptrs)
870
+ .setUids(4, uids)
871
+ .build();
872
+ DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
873
+ cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
874
+ checkCudnnErr(status);
875
+ cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
876
+ } catch (cudnn_frontend::cudnnException e) {
877
+ std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
878
+ }
879
+ }
880
+
881
+
882
+ void
883
+ run_dconv_drelu_dbias(int64_t* x_dim,
884
+ int64_t* w_dim,
885
+ int64_t* y_dim,
886
+ int64_t* pad,
887
+ int64_t* convstride,
888
+ int64_t* dilation,
889
+ cudnnDataType_t dataType,
890
+ at::Half* devPtrX,
891
+ at::Half* devPtrW,
892
+ at::Half* devPtrR,
893
+ at::Half* devPtrRg,
894
+ float* devPtrY) {
895
+ cudnnHandle_t handle_ = torch::native::getCudnnHandle();
896
+ std::stringstream log_buf;
897
+ try {
898
+ int convDim = 2;
899
+ float alpha = 1.0f;
900
+ float beta = 0.0f;
901
+ int64_t b_dim[] = {1, x_dim[1], 1, 1};
902
+
903
+ int64_t stride[4];
904
+ generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
905
+ auto outConvGradTensor = cudnn_frontend::TensorBuilder()
906
+ .setDim(4, y_dim)
907
+ .setStrides(4, stride)
908
+ .setId('x')
909
+ .setAlignment(16)
910
+ .setDataType(dataType)
911
+ .build();
912
+ DEBUG_CUDNN_MSG(log_buf, outConvGradTensor.describe());
913
+
914
+ generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC);
915
+ auto wTensor = cudnn_frontend::TensorBuilder()
916
+ .setDim(4, w_dim)
917
+ .setStrides(4, stride)
918
+ .setId('w')
919
+ .setAlignment(16)
920
+ .setDataType(dataType)
921
+ .build();
922
+ DEBUG_CUDNN_MSG(log_buf, wTensor.describe());
923
+
924
+ generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);
925
+ auto inConvGradTensor = cudnn_frontend::TensorBuilder()
926
+ .setDim(4, x_dim)
927
+ .setStrides(4, stride)
928
+ .setId('A')
929
+ .setAlignment(16)
930
+ .setDataType(CUDNN_DATA_FLOAT)
931
+ .setVirtual()
932
+ .build();
933
+ DEBUG_CUDNN_MSG(log_buf, inConvGradTensor.describe());
934
+
935
+ generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);
936
+ auto rTensor = cudnn_frontend::TensorBuilder()
937
+ .setDim(4, x_dim)
938
+ .setStrides(4, stride)
939
+ .setId('r')
940
+ .setAlignment(16)
941
+ .setDataType(dataType)
942
+ .build();
943
+ DEBUG_CUDNN_MSG(log_buf, rTensor.describe());
944
+
945
+ generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);
946
+ auto inReLUGradTensor = cudnn_frontend::TensorBuilder()
947
+ .setDim(4, x_dim)
948
+ .setStrides(4, stride)
949
+ .setId('R')
950
+ .setAlignment(16)
951
+ .setDataType(dataType)
952
+ .build();
953
+ DEBUG_CUDNN_MSG(log_buf, inReLUGradTensor.describe());
954
+
955
+ generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC);
956
+ auto inBiasGradTensor = cudnn_frontend::TensorBuilder()
957
+ .setDim(4, b_dim)
958
+ .setStrides(4, stride)
959
+ .setId('y')
960
+ .setAlignment(16)
961
+ .setDataType(CUDNN_DATA_FLOAT)
962
+ .build();
963
+ DEBUG_CUDNN_MSG(log_buf, inBiasGradTensor.describe());
964
+
965
+ // Define the convolution problem
966
+ auto convDesc = cudnn_frontend::ConvDescBuilder()
967
+ .setDataType(CUDNN_DATA_FLOAT)
968
+ .setMathMode(CUDNN_CROSS_CORRELATION)
969
+ .setNDims(convDim)
970
+ .setStrides(convDim, convstride)
971
+ .setPrePadding(convDim, pad)
972
+ .setPostPadding(convDim, pad)
973
+ .setDilation(convDim, dilation)
974
+ .build();
975
+ DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
976
+
977
+ // Define the activation backward operation
978
+ auto actDesc = cudnn_frontend::PointWiseDescBuilder()
979
+ .setMode(CUDNN_POINTWISE_RELU_BWD)
980
+ .setMathPrecision(CUDNN_DATA_FLOAT)
981
+ .build();
982
+ DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
983
+
984
+ // Define the bias backward operation
985
+ auto biasDesc = cudnn_frontend::ReductionDescBuilder()
986
+ .setMathPrecision(CUDNN_DATA_FLOAT)
987
+ .setReductionOp(CUDNN_REDUCE_TENSOR_ADD)
988
+ .build();
989
+ DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
990
+
991
+ // Create a convolution Node
992
+ auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR)
993
+ .setdyDesc(outConvGradTensor)
994
+ .setwDesc(wTensor)
995
+ .setdxDesc(inConvGradTensor)
996
+ .setcDesc(convDesc)
997
+ .setAlpha(alpha)
998
+ .setBeta(beta)
999
+ .build();
1000
+ DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
1001
+
1002
+ // Create an relu backward Node
1003
+ auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
1004
+ .setdyDesc(inConvGradTensor)
1005
+ .setxDesc(rTensor)
1006
+ .setdxDesc(inReLUGradTensor)
1007
+ .setpwDesc(actDesc)
1008
+ .build();
1009
+ DEBUG_CUDNN_MSG(log_buf, act_op.describe());
1010
+
1011
+ // Create bias node
1012
+ auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)
1013
+ .setxDesc(inReLUGradTensor)
1014
+ .setyDesc(inBiasGradTensor)
1015
+ .setreductionDesc(biasDesc)
1016
+ .build();
1017
+ DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
1018
+
1019
+ // Create an Operation Graph. In this case it is bias only
1020
+ std::array<cudnn_frontend::Operation const*, 3> ops = {&conv_op, &act_op, &bias_op};
1021
+
1022
+ auto opGraph = cudnn_frontend::OperationGraphBuilder()
1023
+ .setHandle(handle_)
1024
+ .setOperationGraph(ops.size(), ops.data())
1025
+ .build();
1026
+
1027
+ // Create string encoding for plan caching
1028
+ auto cache_string = getConvFusionString(x_dim, pad, convstride, dilation, w_dim, dataType, opGraph.getTag());
1029
+ DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
1030
+
1031
+ auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
1032
+ DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
1033
+
1034
+ auto workspace_size = plan.getWorkspaceSize();
1035
+ DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
1036
+
1037
+ void* workspace_ptr = nullptr;
1038
+ auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
1039
+ if (workspace_size > 0) {
1040
+ workspace_ptr = workspace_tensor.data_ptr<float>();
1041
+ }
1042
+ void* data_ptrs[] = {devPtrX, devPtrW, devPtrR, devPtrRg, devPtrY};
1043
+ int64_t uids[] = {'x', 'w', 'r', 'R', 'y'};
1044
+ auto variantPack = cudnn_frontend::VariantPackBuilder()
1045
+ .setWorkspacePointer(workspace_ptr)
1046
+ .setDataPointers(5, data_ptrs)
1047
+ .setUids(5, uids)
1048
+ .build();
1049
+ DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
1050
+ cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
1051
+ checkCudnnErr(status);
1052
+ cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
1053
+ } catch (cudnn_frontend::cudnnException e) {
1054
+ std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
1055
+ }
1056
+
1057
+ }
1058
+
1059
+
1060
+ void
1061
+ run_dconv(int64_t* x_dim,
1062
+ int64_t* w_dim,
1063
+ int64_t* y_dim,
1064
+ int64_t* conv_pad,
1065
+ int64_t* conv_stride,
1066
+ int64_t* conv_dilation,
1067
+ cudnnDataType_t dataType,
1068
+ at::Half* devPtrX,
1069
+ at::Half* devPtrW,
1070
+ at::Half* devPtrY,
1071
+ cudnnBackendDescriptorType_t mode) {
1072
+
1073
+ cudnnHandle_t handle_ = torch::native::getCudnnHandle();
1074
+ std::stringstream log_buf;
1075
+
1076
+ try {
1077
+ int conv_dim = 2;
1078
+ float alpha = 1.0f;
1079
+ float beta = 0.0f;
1080
+
1081
+ // Define the convolution problem
1082
+ int64_t stride[4];
1083
+ generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);
1084
+ auto xTensor = cudnn_frontend::TensorBuilder()
1085
+ .setDim(4, x_dim)
1086
+ .setStrides(4, stride)
1087
+ .setId('x')
1088
+ .setAlignment(16)
1089
+ .setDataType(dataType)
1090
+ .build();
1091
+ DEBUG_CUDNN_MSG(log_buf, xTensor.describe());
1092
+
1093
+ generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC);
1094
+ auto wTensor = cudnn_frontend::TensorBuilder()
1095
+ .setDim(4, w_dim)
1096
+ .setStrides(4, stride)
1097
+ .setId('w')
1098
+ .setAlignment(16)
1099
+ .setDataType(dataType)
1100
+ .build();
1101
+ DEBUG_CUDNN_MSG(log_buf, wTensor.describe());
1102
+
1103
+ generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
1104
+ auto yTensor = cudnn_frontend::TensorBuilder()
1105
+ .setDim(4, y_dim)
1106
+ .setStrides(4, stride)
1107
+ .setId('y')
1108
+ .setAlignment(16)
1109
+ .setDataType(dataType)
1110
+ .build();
1111
+ DEBUG_CUDNN_MSG(log_buf, yTensor.describe());
1112
+
1113
+
1114
+ // Define the convolution problem
1115
+ auto convDesc = cudnn_frontend::ConvDescBuilder()
1116
+ .setDataType(CUDNN_DATA_FLOAT)
1117
+ .setMathMode(CUDNN_CROSS_CORRELATION)
1118
+ .setNDims(conv_dim)
1119
+ .setStrides(conv_dim, conv_stride)
1120
+ .setPrePadding(conv_dim, conv_pad)
1121
+ .setPostPadding(conv_dim, conv_pad)
1122
+ .setDilation(conv_dim, conv_dilation)
1123
+ .build();
1124
+ DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
1125
+
1126
+ // Create a convolution node
1127
+ // mode should be one of following
1128
+ // CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR
1129
+ // CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR
1130
+ auto conv_op_builder = cudnn_frontend::OperationBuilder(mode);
1131
+ if (mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) {
1132
+ conv_op_builder.setdxDesc(xTensor)
1133
+ .setwDesc(wTensor)
1134
+ .setdyDesc(yTensor)
1135
+ .setcDesc(convDesc);
1136
+ }
1137
+ else {
1138
+ conv_op_builder.setxDesc(xTensor)
1139
+ .setdwDesc(wTensor)
1140
+ .setdyDesc(yTensor)
1141
+ .setcDesc(convDesc);
1142
+ }
1143
+ auto conv_op = conv_op_builder
1144
+ .setAlpha(alpha)
1145
+ .setBeta(beta)
1146
+ .build();
1147
+ DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
1148
+
1149
+ // Create an Operation Graph. In this case it is convolution add bias activation
1150
+ std::array<cudnn_frontend::Operation const*, 1> ops = {&conv_op};
1151
+
1152
+ auto opGraph = cudnn_frontend::OperationGraphBuilder()
1153
+ .setHandle(handle_)
1154
+ .setOperationGraph(ops.size(), ops.data())
1155
+ .build();
1156
+
1157
+ // Create string encoding for plan caching
1158
+ auto cache_string = getConvFusionString(x_dim, conv_pad, conv_stride, conv_dilation, w_dim, dataType, opGraph.getTag());
1159
+ DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
1160
+
1161
+ auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
1162
+ DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
1163
+
1164
+ auto workspace_size = plan.getWorkspaceSize();
1165
+ DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
1166
+
1167
+ void* workspace_ptr = nullptr;
1168
+ auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
1169
+ if (workspace_size > 0) {
1170
+ workspace_ptr = workspace_tensor.data_ptr<float>();
1171
+ }
1172
+ void* data_ptrs[] = {devPtrX, devPtrW, devPtrY};
1173
+ int64_t uids[] = {'x', 'w', 'y'};
1174
+ auto variantPack = cudnn_frontend::VariantPackBuilder()
1175
+ .setWorkspacePointer(workspace_ptr)
1176
+ .setDataPointers(3, data_ptrs)
1177
+ .setUids(3, uids)
1178
+ .build();
1179
+ DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
1180
+ cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
1181
+ checkCudnnErr(status);
1182
+ cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
1183
+ } catch (cudnn_frontend::cudnnException e) {
1184
+ std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
1185
+ }
1186
+ }
1187
+
1188
+
1189
+ void
1190
+ run_dbias(int64_t* x_dim,
1191
+ cudnnDataType_t dataType,
1192
+ at::Half* devPtrX,
1193
+ float* devPtrY) {
1194
+ cudnnHandle_t handle_ = torch::native::getCudnnHandle();
1195
+ std::stringstream log_buf;
1196
+ try {
1197
+ int convDim = 2;
1198
+ int64_t b_dim[] = {1, x_dim[1], 1, 1};
1199
+
1200
+ int64_t stride[4];
1201
+ generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);
1202
+ auto xTensor = cudnn_frontend::TensorBuilder()
1203
+ .setDim(4, x_dim)
1204
+ .setStrides(4, stride)
1205
+ .setId('x')
1206
+ .setAlignment(16)
1207
+ .setDataType(dataType)
1208
+ .build();
1209
+ DEBUG_CUDNN_MSG(log_buf, xTensor.describe());
1210
+
1211
+ generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC);
1212
+ auto yTensor = cudnn_frontend::TensorBuilder()
1213
+ .setDim(4, b_dim)
1214
+ .setStrides(4, stride)
1215
+ .setId('y')
1216
+ .setAlignment(16)
1217
+ .setDataType(CUDNN_DATA_FLOAT)
1218
+ .build();
1219
+ DEBUG_CUDNN_MSG(log_buf, yTensor.describe());
1220
+
1221
+ // Define the bias backward operation
1222
+ auto biasDesc = cudnn_frontend::ReductionDescBuilder()
1223
+ .setMathPrecision(CUDNN_DATA_FLOAT)
1224
+ .setReductionOp(CUDNN_REDUCE_TENSOR_ADD)
1225
+ .build();
1226
+ DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
1227
+
1228
+ // Create bias node
1229
+ auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)
1230
+ .setxDesc(xTensor)
1231
+ .setyDesc(yTensor)
1232
+ .setreductionDesc(biasDesc)
1233
+ .build();
1234
+ DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
1235
+
1236
+ // Create an Operation Graph. In this case it is bias only
1237
+ std::array<cudnn_frontend::Operation const*, 1> ops = {&bias_op};
1238
+
1239
+ auto opGraph = cudnn_frontend::OperationGraphBuilder()
1240
+ .setHandle(handle_)
1241
+ .setOperationGraph(ops.size(), ops.data())
1242
+ .build();
1243
+
1244
+ // Create string encoding for plan caching
1245
+ int64_t pad_dummy[] = {10, 10};
1246
+ int64_t stride_dummy[] = {10, 10};
1247
+ int64_t dilation_dummy[] = {10, 10};
1248
+ auto cache_string = getConvFusionString(x_dim, pad_dummy, stride_dummy, dilation_dummy, b_dim, dataType, opGraph.getTag());
1249
+ DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
1250
+
1251
+ auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
1252
+ DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
1253
+
1254
+ auto workspace_size = plan.getWorkspaceSize();
1255
+ DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
1256
+
1257
+ void* workspace_ptr = nullptr;
1258
+ auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
1259
+ if (workspace_size > 0) {
1260
+ workspace_ptr = workspace_tensor.data_ptr<float>();
1261
+ }
1262
+ void* data_ptrs[] = {devPtrX, devPtrY};
1263
+ int64_t uids[] = {'x', 'y'};
1264
+ auto variantPack = cudnn_frontend::VariantPackBuilder()
1265
+ .setWorkspacePointer(workspace_ptr)
1266
+ .setDataPointers(2, data_ptrs)
1267
+ .setUids(2, uids)
1268
+ .build();
1269
+ DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
1270
+ cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
1271
+ checkCudnnErr(status);
1272
+ cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
1273
+ } catch (cudnn_frontend::cudnnException e) {
1274
+ std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
1275
+ }
1276
+
1277
+ }
1278
+
1279
+
1280
+ std::vector<at::Tensor> conv_bias_mask_relu_forward(std::vector<at::Tensor> inputs, int64_t padding, int64_t stride) {
1281
+ std::cout << std::fixed;
1282
+
1283
+ // create output vector
1284
+ std::vector<at::Tensor> outputs;
1285
+ auto output_format = at::MemoryFormat::ChannelsLast;
1286
+
1287
+ // setup dimensions
1288
+ int64_t x_dim[] = {0, 0, 0, 0};
1289
+ int64_t w_dim[] = {0, 0, 0, 0};
1290
+
1291
+ // All dim calculation after this order of n,c,h,w
1292
+ int axis[] = {0, 1, 2, 3};
1293
+ for (int dim = 0; dim < 4; dim++) {
1294
+ x_dim[dim] = inputs[0].size(axis[dim]);
1295
+ w_dim[dim] = inputs[1].size(axis[dim]);
1296
+ }
1297
+
1298
+ // output dim in n,c,h,w used by backend
1299
+ int64_t y_dim[] = {0, 0, 0, 0};
1300
+
1301
+ // use these fixed values
1302
+ int64_t conv_pad[] = {padding, padding};
1303
+ int64_t conv_stride[] = {stride, stride};
1304
+ int64_t conv_dilation[] = {1, 1};
1305
+
1306
+ // compute output from pad/stride/dilation
1307
+ y_dim[0] = x_dim[0];
1308
+ y_dim[1] = w_dim[0];
1309
+ for (int dim = 0; dim < 2; dim++) {
1310
+ y_dim[dim + 2] = getFwdConvOutputDim(x_dim[dim + 2], conv_pad[dim], w_dim[dim + 2], conv_stride[dim], conv_dilation[dim]);
1311
+ }
1312
+
1313
+ // run
1314
+ at::Half* x = inputs[0].data_ptr<at::Half>();
1315
+ at::Half* w = inputs[1].data_ptr<at::Half>();
1316
+ at::Half* b = inputs[2].data_ptr<at::Half>();
1317
+ int8_t* m = inputs[3].data_ptr<int8_t>();
1318
+ auto out = at::empty(y_dim, inputs[0].type(), output_format);
1319
+ at::Half* y = out.data_ptr<at::Half>();
1320
+
1321
+ run_conv_bias_mask_relu(x_dim,
1322
+ w_dim,
1323
+ y_dim,
1324
+ conv_pad,
1325
+ conv_stride,
1326
+ conv_dilation,
1327
+ CUDNN_DATA_HALF,
1328
+ x,
1329
+ w,
1330
+ b,
1331
+ m,
1332
+ y);
1333
+
1334
+ DEBUG_MSG("[DEBUG] conv-bias-mask-relu : " << y.to(at::kFloat).sum().item<float>());
1335
+
1336
+ outputs.push_back(out);
1337
+
1338
+ return outputs;
1339
+ }
1340
+
1341
+
1342
+ std::vector<at::Tensor> conv_bias_relu_forward(std::vector<at::Tensor> inputs, int64_t padding, int64_t stride) {
1343
+ std::cout << std::fixed;
1344
+
1345
+ // create output vector
1346
+ std::vector<at::Tensor> outputs;
1347
+ auto output_format = at::MemoryFormat::ChannelsLast;
1348
+
1349
+ // setup dimensions
1350
+ int64_t x_dim[] = {0, 0, 0, 0};
1351
+ int64_t w_dim[] = {0, 0, 0, 0};
1352
+
1353
+ // All dim calculation after this order of n,c,h,w
1354
+ int axis[] = {0, 1, 2, 3};
1355
+ for (int dim = 0; dim < 4; dim++) {
1356
+ x_dim[dim] = inputs[0].size(axis[dim]);
1357
+ w_dim[dim] = inputs[1].size(axis[dim]);
1358
+ }
1359
+
1360
+ // output dim in n,c,h,w used by backend
1361
+ int64_t y_dim[] = {0, 0, 0, 0};
1362
+
1363
+ // use these fixed values
1364
+ int64_t conv_pad[] = {padding, padding};
1365
+ int64_t conv_stride[] = {stride, stride};
1366
+ int64_t conv_dilation[] = {1, 1};
1367
+
1368
+ // compute output from pad/stride/dilation
1369
+ y_dim[0] = x_dim[0];
1370
+ y_dim[1] = w_dim[0];
1371
+ for (int dim = 0; dim < 2; dim++) {
1372
+ y_dim[dim + 2] = getFwdConvOutputDim(x_dim[dim + 2], conv_pad[dim], w_dim[dim + 2], conv_stride[dim], conv_dilation[dim]);
1373
+ }
1374
+
1375
+ // run
1376
+ at::Half* x = inputs[0].data_ptr<at::Half>();
1377
+ at::Half* w = inputs[1].data_ptr<at::Half>();
1378
+ at::Half* b = inputs[2].data_ptr<at::Half>();
1379
+ auto out = at::empty(y_dim, inputs[0].type(), output_format);
1380
+ at::Half* y = out.data_ptr<at::Half>();
1381
+
1382
+ run_conv_bias_relu(x_dim,
1383
+ w_dim,
1384
+ y_dim,
1385
+ conv_pad,
1386
+ conv_stride,
1387
+ conv_dilation,
1388
+ CUDNN_DATA_HALF,
1389
+ x,
1390
+ w,
1391
+ b,
1392
+ y);
1393
+
1394
+ DEBUG_MSG("[DEBUG] conv-bias-relu : " << y.to(at::kFloat).sum().item<float>());
1395
+
1396
+ outputs.push_back(out);
1397
+
1398
+ return outputs;
1399
+ }
1400
+
1401
+
1402
+ std::vector<at::Tensor> conv_bias_relu_backward(std::vector<at::Tensor> inputs, int64_t padding, int64_t stride) {
1403
+ bool requires_grad = inputs[0].requires_grad();
1404
+
1405
+ for (int i = 0; i <= 3; i++) {
1406
+ CHECK_INPUT(inputs[i]);
1407
+ }
1408
+
1409
+ std::cout << std::fixed;
1410
+
1411
+ // create output vector
1412
+ std::vector<at::Tensor> outputs;
1413
+ auto output_format = at::MemoryFormat::ChannelsLast;
1414
+
1415
+ // setup dimensions
1416
+ int64_t x_dim[] = {0, 0, 0, 0};
1417
+ int64_t w_dim[] = {0, 0, 0, 0};
1418
+ int64_t y_dim[] = {0, 0, 0, 0};
1419
+
1420
+ // All dim calculation after this order of n,c,h,w
1421
+ int axis[] = {0, 1, 2, 3};
1422
+ for (int dim = 0; dim < 4; dim++) {
1423
+ x_dim[dim] = inputs[0].size(axis[dim]);
1424
+ w_dim[dim] = inputs[1].size(axis[dim]);
1425
+ y_dim[dim] = inputs[3].size(axis[dim]);
1426
+ }
1427
+
1428
+ int64_t b_dim[] = {1, y_dim[1], 1, 1};
1429
+
1430
+ int64_t conv_pad[] = {padding, padding};
1431
+ int64_t conv_stride[] = {stride, stride};
1432
+ int64_t conv_dilation[] = {1, 1};
1433
+
1434
+ // run
1435
+ // drelu-dbias
1436
+ at::Half* dy = inputs[3].data_ptr<at::Half>();
1437
+ at::Half* r = inputs[2].data_ptr<at::Half>();
1438
+ auto drelu = at::empty_like(inputs[2]);
1439
+ at::Half* dr = drelu.data_ptr<at::Half>();
1440
+ auto options = at::TensorOptions().dtype(at::kFloat).layout(inputs[0].layout()).device(inputs[0].device()).requires_grad(false);
1441
+ auto bgrad = at::empty(b_dim, options, output_format);
1442
+ float* db = bgrad.data_ptr<float>();
1443
+ run_drelu_dbias(y_dim,
1444
+ CUDNN_DATA_HALF,
1445
+ dy,
1446
+ r,
1447
+ dr,
1448
+ db);
1449
+
1450
+ // conv wgrad
1451
+ at::Half* x = inputs[0].data_ptr<at::Half>();
1452
+ auto wgrad = at::empty_like(inputs[1]);
1453
+ at::Half* dw = wgrad.data_ptr<at::Half>();
1454
+ run_dconv(x_dim,
1455
+ w_dim,
1456
+ y_dim,
1457
+ conv_pad,
1458
+ conv_stride,
1459
+ conv_dilation,
1460
+ CUDNN_DATA_HALF,
1461
+ x,
1462
+ dw,
1463
+ dr,
1464
+ CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
1465
+
1466
+ // conv dgrad
1467
+ at::Half* w = inputs[1].data_ptr<at::Half>();
1468
+ auto dgrad = at::empty_like(inputs[0]);
1469
+ at::Half* dx = dgrad.data_ptr<at::Half>();
1470
+ run_dconv(x_dim,
1471
+ w_dim,
1472
+ y_dim,
1473
+ conv_pad,
1474
+ conv_stride,
1475
+ conv_dilation,
1476
+ CUDNN_DATA_HALF,
1477
+ dx,
1478
+ w,
1479
+ dr,
1480
+ CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
1481
+
1482
+ outputs.push_back(dgrad);
1483
+ outputs.push_back(wgrad);
1484
+ outputs.push_back(bgrad);
1485
+
1486
+ return outputs;
1487
+
1488
+ }
1489
+
1490
+ std::vector<at::Tensor> conv_bias_forward(std::vector<at::Tensor> inputs, int64_t padding, int64_t stride) {
1491
+ std::cout << std::fixed;
1492
+
1493
+ // create output vector
1494
+ std::vector<at::Tensor> outputs;
1495
+ auto output_format = at::MemoryFormat::ChannelsLast;
1496
+
1497
+ // setup dimensions
1498
+ int64_t x_dim[] = {0, 0, 0, 0};
1499
+ int64_t w_dim[] = {0, 0, 0, 0};
1500
+
1501
+ // All dim calculation after this order of n,c,h,w
1502
+ int axis[] = {0, 1, 2, 3};
1503
+ for (int dim = 0; dim < 4; dim++) {
1504
+ x_dim[dim] = inputs[0].size(axis[dim]);
1505
+ w_dim[dim] = inputs[1].size(axis[dim]);
1506
+ }
1507
+
1508
+ // output dim in n,c,h,w used by backend
1509
+ int64_t y_dim[] = {0, 0, 0, 0};
1510
+
1511
+ // use these fixed values
1512
+ int64_t conv_pad[] = {padding, padding};
1513
+ int64_t conv_stride[] = {stride, stride};
1514
+ int64_t conv_dilation[] = {1, 1};
1515
+
1516
+ // compute output from pad/stride/dilation
1517
+ y_dim[0] = x_dim[0];
1518
+ y_dim[1] = w_dim[0];
1519
+ for (int dim = 0; dim < 2; dim++) {
1520
+ y_dim[dim + 2] = getFwdConvOutputDim(x_dim[dim + 2], conv_pad[dim], w_dim[dim + 2], conv_stride[dim], conv_dilation[dim]);
1521
+ }
1522
+
1523
+ // run
1524
+ at::Half* x = inputs[0].data_ptr<at::Half>();
1525
+ at::Half* w = inputs[1].data_ptr<at::Half>();
1526
+ at::Half* b = inputs[2].data_ptr<at::Half>();
1527
+ auto out = at::empty(y_dim, inputs[0].type(), output_format);
1528
+ at::Half* y = out.data_ptr<at::Half>();
1529
+
1530
+ run_conv_bias(x_dim,
1531
+ w_dim,
1532
+ y_dim,
1533
+ conv_pad,
1534
+ conv_stride,
1535
+ conv_dilation,
1536
+ CUDNN_DATA_HALF,
1537
+ x,
1538
+ w,
1539
+ b,
1540
+ y);
1541
+
1542
+ DEBUG_MSG("[DEBUG] conv-bias : " << y.to(at::kFloat).sum().item<float>());
1543
+
1544
+ outputs.push_back(out);
1545
+
1546
+ return outputs;
1547
+ }
1548
+
1549
+
1550
+ std::vector<at::Tensor> conv_bias_backward(std::vector<at::Tensor> inputs, int64_t padding, int64_t stride) {
1551
+ bool requires_grad = inputs[0].requires_grad();
1552
+
1553
+ for (int i = 0; i <= 2; i++) {
1554
+ CHECK_INPUT(inputs[i]);
1555
+ }
1556
+
1557
+ std::cout << std::fixed;
1558
+
1559
+ // create output vector
1560
+ std::vector<at::Tensor> outputs;
1561
+ auto output_format = at::MemoryFormat::ChannelsLast;
1562
+
1563
+ // setup dimensions
1564
+ int64_t x_dim[] = {0, 0, 0, 0};
1565
+ int64_t w_dim[] = {0, 0, 0, 0};
1566
+ int64_t y_dim[] = {0, 0, 0, 0};
1567
+
1568
+ // All dim calculation after this order of n,c,h,w
1569
+ int axis[] = {0, 1, 2, 3};
1570
+ for (int dim = 0; dim < 4; dim++) {
1571
+ x_dim[dim] = inputs[0].size(axis[dim]);
1572
+ w_dim[dim] = inputs[1].size(axis[dim]);
1573
+ y_dim[dim] = inputs[2].size(axis[dim]);
1574
+ }
1575
+
1576
+ int64_t b_dim[] = {1, y_dim[1], 1, 1};
1577
+
1578
+ int64_t conv_pad[] = {padding, padding};
1579
+ int64_t conv_stride[] = {stride, stride};
1580
+ int64_t conv_dilation[] = {1, 1};
1581
+
1582
+ // run
1583
+ // dbias
1584
+ at::Half* dy = inputs[2].data_ptr<at::Half>();
1585
+ auto options = at::TensorOptions().dtype(at::kFloat).layout(inputs[0].layout()).device(inputs[0].device()).requires_grad(false);
1586
+ auto bgrad = at::empty(b_dim, options, output_format);
1587
+ float* db = bgrad.data_ptr<float>();
1588
+ run_dbias(y_dim,
1589
+ CUDNN_DATA_HALF,
1590
+ dy,
1591
+ db);
1592
+
1593
+ // conv wgrad
1594
+ at::Half* x = inputs[0].data_ptr<at::Half>();
1595
+ auto wgrad = at::empty_like(inputs[1]);
1596
+ at::Half* dw = wgrad.data_ptr<at::Half>();
1597
+ run_dconv(x_dim,
1598
+ w_dim,
1599
+ y_dim,
1600
+ conv_pad,
1601
+ conv_stride,
1602
+ conv_dilation,
1603
+ CUDNN_DATA_HALF,
1604
+ x,
1605
+ dw,
1606
+ dy,
1607
+ CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
1608
+
1609
+ // conv dgrad
1610
+ at::Half* w = inputs[1].data_ptr<at::Half>();
1611
+ auto dgrad = at::empty_like(inputs[0]);
1612
+ at::Half* dx = dgrad.data_ptr<at::Half>();
1613
+ run_dconv(x_dim,
1614
+ w_dim,
1615
+ y_dim,
1616
+ conv_pad,
1617
+ conv_stride,
1618
+ conv_dilation,
1619
+ CUDNN_DATA_HALF,
1620
+ dx,
1621
+ w,
1622
+ dy,
1623
+ CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
1624
+
1625
+ outputs.push_back(dgrad);
1626
+ outputs.push_back(wgrad);
1627
+ outputs.push_back(bgrad);
1628
+
1629
+ return outputs;
1630
+ }
1631
+
1632
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
1633
+ m.def("forward", &conv_bias_relu_forward, "Fused Conv-Bias-ReLU forward");
1634
+ m.def("backward", &conv_bias_relu_backward, "Fused Conv-Bias-ReLU backward");
1635
+ m.def("forward_no_relu", &conv_bias_forward, "Fused Conv-Bias forward");
1636
+ m.def("backward_no_relu", &conv_bias_backward, "Fused Conv-Bias backward");
1637
+ m.def("forward_mask", &conv_bias_mask_relu_forward, "Fused Conv-Bias-Mask-ReLU forward");
1638
+ }
1639
+
apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/ATen.h>
2
+ #include <torch/extension.h>
3
+ #include <torch/torch.h>
4
+ #include <vector>
5
+
6
+ #include <iostream>
7
+
8
+ #include "norm_sample.h"
9
+
10
+ at::Tensor gbn_forward(const at::Tensor& x,
11
+ const at::Tensor& scale,
12
+ const at::Tensor& bias,
13
+ const at::Tensor& running_mean,
14
+ const at::Tensor& running_var,
15
+ const at::Tensor& minibatch_mean,
16
+ const at::Tensor& minibatch_inv_var,
17
+ const float momentum,
18
+ const float epsilon,
19
+ const int64_t bn_group,
20
+ const int rank_id,
21
+ const std::vector<int64_t> &peer_buffers) {
22
+
23
+ int64_t N = x.size(0);
24
+ int64_t C = x.size(1);
25
+ int64_t H = x.size(2);
26
+ int64_t W = x.size(3);
27
+
28
+ int64_t tensorDims[] = {N, C, H, W};
29
+ int64_t peerDims[] = {bn_group, 4*C, 1, 1};
30
+ int64_t perChannelDims[] = {1, C, 1, 1};
31
+ int64_t epsilonDims[] = {1, 1, 1, 1};
32
+
33
+ // Allocate output tensor
34
+ at::Tensor y = at::empty_like(x);
35
+
36
+ std::vector<void*> void_peer_buffers;
37
+ for (int64_t addr : peer_buffers) {
38
+ void_peer_buffers.push_back((void*)addr);
39
+ }
40
+
41
+ assert(bn_group == void_peer_buffers.size());
42
+ run_batch_norm_forward(
43
+ perChannelDims,
44
+ epsilonDims,
45
+ tensorDims,
46
+ peerDims,
47
+ x.data_ptr(),
48
+ y.data_ptr(),
49
+ scale.data_ptr(),
50
+ bias.data_ptr(),
51
+ running_mean.data_ptr(),
52
+ running_var.data_ptr(),
53
+ running_mean.data_ptr(),
54
+ running_var.data_ptr(),
55
+ minibatch_mean.data_ptr(),
56
+ minibatch_inv_var.data_ptr(),
57
+ void_peer_buffers,
58
+ epsilon,
59
+ momentum,
60
+ rank_id
61
+ );
62
+
63
+ return y;
64
+ }
65
+
66
+ std::vector<at::Tensor> gbn_backward(
67
+ const at::Tensor& x,
68
+ const at::Tensor& dy,
69
+ const at::Tensor& scale,
70
+ const at::Tensor& minibatch_mean,
71
+ const at::Tensor& minibatch_inv_var,
72
+ const float epsilon,
73
+ const int64_t bn_group,
74
+ const int rank_id,
75
+ const std::vector<int64_t> &peer_buffers) {
76
+
77
+ int64_t N = x.size(0);
78
+ int64_t C = x.size(1);
79
+ int64_t H = x.size(2);
80
+ int64_t W = x.size(3);
81
+
82
+ int64_t tensorDims[] = {N, C, H, W};
83
+ int64_t peerDims[] = {bn_group, 4*C, 1, 1};
84
+ int64_t perChannelDims[] = {1, C, 1, 1};
85
+ int64_t epsilonDims[] = {1, 1, 1, 1};
86
+
87
+ // Allocate output tensor
88
+ // outputs
89
+ at::Tensor x_grad, scale_grad, bias_grad;
90
+
91
+ // Allocate outputs
92
+ x_grad = at::empty_like(x);
93
+ scale_grad = at::empty_like(scale);
94
+ bias_grad = at::empty_like(scale);
95
+
96
+ std::vector<void*> void_peer_buffers;
97
+ for (int64_t addr : peer_buffers) {
98
+ void_peer_buffers.push_back((void*)addr);
99
+ }
100
+
101
+ assert(bn_group == void_peer_buffers.size());
102
+
103
+ run_batch_norm_backward(
104
+ perChannelDims,
105
+ epsilonDims,
106
+ tensorDims,
107
+ peerDims,
108
+ x.data_ptr(),
109
+ dy.data_ptr(),
110
+ scale.data_ptr(),
111
+ minibatch_mean.data_ptr(),
112
+ minibatch_inv_var.data_ptr(),
113
+ x_grad.data_ptr(),
114
+ scale_grad.data_ptr(),
115
+ bias_grad.data_ptr(),
116
+ void_peer_buffers,
117
+ epsilon,
118
+ rank_id);
119
+
120
+
121
+
122
+ return std::vector<at::Tensor>{x_grad, scale_grad, bias_grad};
123
+ }
124
+
125
+
126
+
127
+
128
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
129
+ m.def("forward", &gbn_forward, "Group batch norm forward");
130
+ m.def("backward", &gbn_backward, "Group batch backward");
131
+ }