Ubuntu committed
Commit a3d4355
Parent(s): 506a7e5
feat: torch 2.5.1+cu124
This view is limited to 50 files because the commit contains too many changes.
- .gitattributes +1 -0
- .gitignore +1 -0
- README.md +20 -0
- api.py +800 -0
- build/lib.linux-x86_64-3.10/dropout_layer_norm.cpython-310-x86_64-linux-gnu.so +3 -0
- ln.h +281 -0
- ln_api.cpp +850 -0
- ln_bwd_1024.cu +15 -0
- ln_bwd_1280.cu +15 -0
- ln_bwd_1536.cu +15 -0
- ln_bwd_2048.cu +15 -0
- ln_bwd_256.cu +15 -0
- ln_bwd_2560.cu +15 -0
- ln_bwd_3072.cu +15 -0
- ln_bwd_4096.cu +15 -0
- ln_bwd_512.cu +15 -0
- ln_bwd_5120.cu +15 -0
- ln_bwd_6144.cu +15 -0
- ln_bwd_7168.cu +15 -0
- ln_bwd_768.cu +15 -0
- ln_bwd_8192.cu +15 -0
- ln_bwd_kernels.cuh +534 -0
- ln_fwd_1024.cu +15 -0
- ln_fwd_1280.cu +15 -0
- ln_fwd_1536.cu +15 -0
- ln_fwd_2048.cu +15 -0
- ln_fwd_256.cu +15 -0
- ln_fwd_2560.cu +15 -0
- ln_fwd_3072.cu +15 -0
- ln_fwd_4096.cu +15 -0
- ln_fwd_512.cu +15 -0
- ln_fwd_5120.cu +15 -0
- ln_fwd_6144.cu +15 -0
- ln_fwd_7168.cu +15 -0
- ln_fwd_768.cu +15 -0
- ln_fwd_8192.cu +15 -0
- ln_fwd_kernels.cuh +272 -0
- ln_kernel_traits.h +172 -0
- ln_parallel_bwd_1024.cu +15 -0
- ln_parallel_bwd_1280.cu +15 -0
- ln_parallel_bwd_1536.cu +15 -0
- ln_parallel_bwd_2048.cu +15 -0
- ln_parallel_bwd_256.cu +15 -0
- ln_parallel_bwd_2560.cu +15 -0
- ln_parallel_bwd_3072.cu +15 -0
- ln_parallel_bwd_4096.cu +17 -0
- ln_parallel_bwd_512.cu +15 -0
- ln_parallel_bwd_5120.cu +17 -0
- ln_parallel_bwd_6144.cu +15 -0
- ln_parallel_bwd_7168.cu +15 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.so filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1 @@
+/build/temp*
README.md
ADDED
@@ -0,0 +1,20 @@

This CUDA extension implements fused dropout + residual + LayerNorm, building on
Apex's [FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm).
Major changes:
- Add dropout and residual.
- Make it work for both pre-norm and post-norm architectures.
- Support more hidden dimensions (all dimensions divisible by 8, up to 8192).
- Implement RMSNorm as an option.
- Support layer norm with parallel residual (e.g., GPT-J, GPT-NeoX, PaLM).

If you want to use it for dimensions larger than 8k, please file an issue.

This extension has only been tested on A100s.

```sh
cd csrc/layer_norm && pip install .
```

As of 2024-01-05, this extension is no longer used in the FlashAttention repo.
We've instead switched to a Triton-based
[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py).
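For orientation before the diff itself, here is a minimal usage sketch of the `dropout_add_layer_norm` function that `api.py` below defines. The import path, shapes, and values are illustrative only; the extension must be built for CUDA first.

```python
# Illustrative only: exercises the prenorm path of dropout_add_layer_norm
# (defined in api.py below). Requires a CUDA build of this extension.
import torch
from api import dropout_add_layer_norm  # import path is an assumption

hidden = 1024
x0 = torch.randn(8, 512, hidden, device="cuda", dtype=torch.float16)
residual = torch.randn_like(x0)
weight = torch.ones(hidden, device="cuda", dtype=torch.float32)
bias = torch.zeros(hidden, device="cuda", dtype=torch.float32)

# prenorm=True additionally returns x = dropout(x0) + residual,
# which a pre-norm transformer feeds to the next block.
z, x = dropout_add_layer_norm(
    x0, residual, weight, bias, dropout_p=0.1, epsilon=1e-5,
    prenorm=True, residual_in_fp32=False,
)
```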
api.py
ADDED
@@ -0,0 +1,800 @@

```python
# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py

import dropout_layer_norm
import torch
from torch.nn import init


def maybe_align(x, alignment_in_bytes=16):
    """Assume that x already has last dim divisible by alignment_in_bytes"""
    # TD [2023-07-04] I'm not 100% sure that clone will align the memory
    # https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440
    return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone()
```
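A quick hedged illustration of what `maybe_align` guards against (not part of the file): slicing can shift a tensor's `data_ptr` off the 16-byte alignment the kernels assume.

```python
# Hypothetical illustration (not part of api.py): the kernels assume 16-byte
# aligned pointers, and slicing can shift data_ptr off that alignment;
# maybe_align() clones in exactly that case.
import torch

base = torch.randn(4, 24, dtype=torch.float16)  # fresh allocations are >=16B aligned
ok = base[:, 8:]    # +16 bytes (8 fp16 elements): still 16-byte aligned
bad = base[:, 4:]   # +8 bytes: data_ptr % 16 == 8, so maybe_align() would clone
print(ok.data_ptr() % 16, bad.data_ptr() % 16)  # 0 8
```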
```python
def _dropout_add_layer_norm_forward(
    x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
    residual_in_fp32=False, is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes"""
    hidden_size = gamma.numel()
    x0mat = x0.view((-1, hidden_size))
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat, residualmat, gamma, beta, rowscale, colscale, None, None,
        dropout_p, epsilon, 1.0, 0, None, residual_in_fp32, is_rms_norm,
    )
    # dmask is None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma


def _dropout_add_layer_norm_backward(
    dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p,
    has_residual, is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    x0 must not be None if we have colscale.
    """
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    dzmat = dz.view(xmat.shape)
    dxmat = dx.view(xmat.shape) if dx is not None else None
    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    if colscale is not None:
        assert x0 is not None, "x0 is required to compute the gradient of colscale"
    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale,
        None, None, dropout_p, 1.0, 0, has_residual, is_rms_norm,
    )
    # dresidualmat is None if not has_residual
    if colscale is None:
        return dx0mat, dresidualmat, dgamma, dbeta
    else:
        dcolscale = rest[0]
        return dx0mat, dresidualmat, dgamma, dbeta, dcolscale


def _dropout_add_layer_norm_subset_forward(
    x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p,
    epsilon, rowscale_const, out_numrows, residual_in_fp32=False, is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes"""
    hidden_size = gamma.numel()
    x0mat = x0.view((-1, hidden_size))
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    x0_subset = x0_subset.view(-1) if x0_subset is not None else None
    out_subset = out_subset.view(-1) if out_subset is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat, residualmat, gamma, beta, None, colscale, x0_subset, out_subset,
        dropout_p, epsilon, rowscale_const, out_numrows, None, residual_in_fp32,
        is_rms_norm,
    )
    # dmask is None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma


def _dropout_add_layer_norm_subset_backward(
    dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale, x0_subset, out_subset,
    dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    x0 must not be None if we have colscale.
    """
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    dzmat = dz.view(-1, hidden_size)
    dxmat = dx.view(xmat.shape) if dx is not None else None
    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
    x0_subset = x0_subset.view(-1) if x0_subset is not None else None
    out_subset = out_subset.view(-1) if out_subset is not None else None
    if colscale is not None:
        assert x0 is not None, "x0 is required to compute the gradient of colscale"
    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, None, colscale,
        x0_subset, out_subset, dropout_p, rowscale_const, x0_numrows,
        has_residual, is_rms_norm,
    )
    # dresidualmat is None if not has_residual
    if colscale is None:
        return dx0mat, dresidualmat, dgamma, dbeta
    else:
        dcolscale = rest[0]
        return dx0mat, dresidualmat, dgamma, dbeta, dcolscale


def _dropout_add_layer_norm_parallel_residual_forward(
    x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
    residual_in_fp32=False, is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes"""
    hidden_size = gamma0.numel()
    x0mat = x0.view((-1, hidden_size))
    x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    (
        z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma,
    ) = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(
        x0mat, x1mat, residualmat, gamma0, beta0, gamma1, beta1,
        dropout_p, epsilon, None, residual_in_fp32, is_rms_norm,
    )
    # dmask0 and dmask1 are None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma


def _dropout_add_layer_norm_parallel_residual_backward(
    dz0, dz1, dx, x, dmask0, dmask1, mu, rsigma, gamma0, gamma1, dropout_p,
    has_x1, has_residual, is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    """
    hidden_size = gamma0.numel()
    xmat = x.view((-1, hidden_size))
    dz0mat = dz0.view(xmat.shape)
    dz1mat = dz1.view(xmat.shape) if dz1 is not None else None
    dxmat = dx.view(xmat.shape) if dx is not None else None
    (
        dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1, *rest,
    ) = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd(
        dz0mat, dz1mat, dxmat, xmat, dmask0, dmask1, mu, rsigma, gamma0, gamma1,
        dropout_p, has_x1, has_residual, is_rms_norm,
    )
    # dresidualmat is None if not has_residual
    return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1
```
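A hedged shape walk-through (not part of the file): every helper above flattens leading dimensions into rows before calling the CUDA kernel, which is why the kernels only ever see 2-D matrices.

```python
# Hypothetical shape walk-through (not part of api.py): a (batch, seqlen,
# hidden) input reaches the CUDA kernel as (batch * seqlen, hidden).
import torch

batch, seqlen, hidden = 2, 3, 8
x0 = torch.randn(batch, seqlen, hidden)
x0mat = x0.view(-1, hidden)  # same flattening as x0.view((-1, hidden_size))
assert x0mat.shape == (batch * seqlen, hidden)
```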
```python
class DropoutAddLayerNormFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx, x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
        residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False,
    ):
        x0 = maybe_align(x0.contiguous(), 16)
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma = maybe_align(gamma.contiguous(), 16)
        beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
        rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None
        colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
            x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
            residual_in_fp32, is_rms_norm,
        )
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = x0 if colscale is not None else None
        ctx.save_for_backward(
            xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale
        )
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None
        if not return_dmask:
            return (
                zmat.view(x0.shape) if not prenorm else (zmat.view(x0.shape), xmat.view(x0.shape))
            )
        else:
            dmask = (
                dmask.view(x0.shape)
                if dropout_p > 0.0
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            ctx.mark_non_differentiable(dmask)
            return (
                (zmat.view(x0.shape), dmask)
                if not prenorm
                else (zmat.view(x0.shape), xmat.view(x0.shape), dmask)
            )

    @staticmethod
    def backward(ctx, dz, *args):
        # assert dz.is_contiguous()
        dz = maybe_align(dz.contiguous(), 16)  # this happens!
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
        # x0 is None if colscale is None
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
            dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale,
            dropout_p, has_residual, ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(x.shape)
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        dcolscale = rest[0] if colscale is not None else None
        return (
            dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, None,
            dcolscale, None, None, None, None, None, None,
        )


class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx, x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p,
        epsilon, rowscale_const, out_numrows, residual_in_fp32=False, prenorm=False,
        is_rms_norm=False, return_dmask=False,
    ):
        x0 = maybe_align(x0.contiguous(), 16)
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma = maybe_align(gamma.contiguous(), 16)
        beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
        colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
            x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p,
            epsilon, rowscale_const, out_numrows, residual_in_fp32, is_rms_norm,
        )
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = x0 if colscale is not None else None
        x_shape = (-1, *x0.shape[1:])
        ctx.save_for_backward(
            xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale,
            x0_subset, out_subset
        )
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.rowscale_const = rowscale_const
        ctx.x0_numrows = x0.shape[:-1].numel()
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None
        z_shape = (-1, *x0.shape[1:])
        if not return_dmask:
            return zmat.view(z_shape) if not prenorm else (zmat.view(z_shape), xmat.view(x0.shape))
        else:
            z = zmat.view(z_shape)
            dmask = (
                dmask.view(x0.shape)
                if dropout_p > 0.0
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            ctx.mark_non_differentiable(dmask)
            return (z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask)

    @staticmethod
    def backward(ctx, dz, *args):
        # assert dz.is_contiguous()
        dz = maybe_align(dz.contiguous(), 16)  # this happens!
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors
        # x0 is None if colscale is None
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
            dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale, x0_subset, out_subset,
            dropout_p, ctx.rowscale_const, ctx.x0_numrows, has_residual, ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(-1, *x.shape[1:])
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        dcolscale = rest[0] if colscale is not None else None
        return (
            dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, dcolscale,
            None, None, None, None, None, None, None, None, None, None,
        )


class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx, x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
        residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False,
    ):
        x0 = maybe_align(x0.contiguous(), 16)
        x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma0 = maybe_align(gamma0.contiguous(), 16)
        beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None
        gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None
        beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None
        (
            z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma,
        ) = _dropout_add_layer_norm_parallel_residual_forward(
            x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
            residual_in_fp32, is_rms_norm,
        )
        ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma)
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_x1 = x1 is not None
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta0 is not None
        z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None)
        if not return_dmask:
            return z if not prenorm else (*z, xmat.view(x0.shape))
        else:
            dmask0 = (
                dmask0.view(x0.shape)
                if dropout_p > 0.0
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            dmask1 = (
                dmask1.view(x0.shape)
                if dropout_p > 0.0 and x1 is not None
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            ctx.mark_non_differentiable(dmask0)
            ctx.mark_non_differentiable(dmask1)
            return (
                (*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1)
            )

    @staticmethod
    def backward(ctx, dz0, dz1, *args):
        dz0 = maybe_align(dz0.contiguous(), 16)  # this happens!
        dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors
        dropout_p = ctx.dropout_p
        has_x1 = ctx.has_x1
        has_residual = ctx.has_residual
        (
            dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1,
        ) = _dropout_add_layer_norm_parallel_residual_backward(
            dz0, dz1, dx, x, dmask0, dmask1, mu, rsigma, gamma0, gamma1,
            dropout_p, has_x1, has_residual, ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(x.shape)
        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        return (
            dx0, dx1, dresidual, dgamma0, dbeta0 if ctx.has_beta else None,
            dgamma1, dbeta1 if ctx.has_beta else None,
            None, None, None, None, None, None,
        )
```
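The long runs of `None` in the `backward` methods above follow from an autograd rule: `backward()` must return exactly one gradient slot per `forward()` input, including the non-differentiable flags. A hypothetical mini-example (not from this file) of the same pattern:

```python
# Hypothetical mini-example (not from api.py): one gradient slot per
# forward() input, with None for non-differentiable arguments.
import torch

class ScaleFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale, some_flag):
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, dz):
        # Slots for (x, scale, some_flag); only x is differentiable here.
        return dz * ctx.scale, None, None

x = torch.randn(3, requires_grad=True)
y = ScaleFn.apply(x, 2.0, False)
y.sum().backward()
```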
```python
def layer_norm(x, weight, bias, epsilon):
    return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)


def dropout_add_layer_norm(
    x0, residual, weight, bias, dropout_p, epsilon, rowscale=None, layerscale=None,
    prenorm=False, residual_in_fp32=False, return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormFn.apply(
        x0, residual, weight, bias, rowscale, layerscale, dropout_p, epsilon,
        residual_in_fp32, prenorm, False, return_dropout_mask,
    )


def dropout_add_layer_norm_subset(
    x0, residual, weight, bias, dropout_p, epsilon, layerscale=None, x0_subset=None,
    out_subset=None, rowscale_const=1.0, out_numrows=0, prenorm=False,
    residual_in_fp32=False, return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormSubsetFn.apply(
        x0, residual, weight, bias, layerscale, x0_subset, out_subset, dropout_p,
        epsilon, rowscale_const, out_numrows, residual_in_fp32, prenorm, False,
        return_dropout_mask,
    )


def dropout_add_layer_norm_parallel_residual(
    x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon,
    prenorm=False, residual_in_fp32=False, return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormParallelResidualFn.apply(
        x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon,
        residual_in_fp32, prenorm, False, return_dropout_mask,
    )


class DropoutAddLayerNorm(torch.nn.Module):
    def __init__(
        self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
        device=None, dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.eps = eps
        self.residual_in_fp32 = residual_in_fp32
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, x0, residual=None):
        return dropout_add_layer_norm(
            x0, residual, self.weight, self.bias, self.p if self.training else 0.0,
            self.eps, prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32,
        )
```
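A hedged sketch of how the `DropoutAddLayerNorm` module above slots into a pre-norm residual block (sizes are illustrative; requires the CUDA build):

```python
# Illustrative only: DropoutAddLayerNorm as the fused dropout + add + LN step
# of a pre-norm transformer block.
import torch

norm = DropoutAddLayerNorm(
    1024, prenorm=True, p=0.1, eps=1e-5, device="cuda", dtype=torch.float16
).train()
x0 = torch.randn(8, 512, 1024, device="cuda", dtype=torch.float16)
residual = torch.randn_like(x0)
# z feeds the next sublayer; x carries the updated residual stream.
z, x = norm(x0, residual)
```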
build/lib.linux-x86_64-3.10/dropout_layer_norm.cpython-310-x86_64-linux-gnu.so
ADDED
@@ -0,0 +1,3 @@

```text
version https://git-lfs.github.com/spec/v1
oid sha256:aba674c175147bfdff6acb354745749070519df35f66522d71b2743aedc3b5a9
size 26705096
```
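The shared object is stored via Git LFS (matching the `*.so` rule added to `.gitattributes` above), so the three lines in the diff are the pointer file, not the binary. A minimal sketch of materializing it, assuming `git-lfs` is installed:

```sh
git lfs install   # one-time: enable the LFS smudge/clean filters
git lfs pull      # replace pointer files with the real ~26 MB .so
```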
ln.h
ADDED
@@ -0,0 +1,281 @@

```cpp
#pragma once

#include <unordered_map>
#include <functional>   // std::function (used by Fwd/BwdFunction below)
#include <cstdint>      // uint32_t / uint64_t
#include <cuda_fp16.h>
#include <cuda_bf16.h>

#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif

namespace layer_norm {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Params>
struct LaunchParams{

    size_t elts_per_thread;
    size_t workspace_bytes;
    size_t barrier_size;

    cudaDeviceProp * props;

    cudaStream_t stream;

    Params params;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct ParamsBase {
    ParamsBase()
        : ctas_per_col(0)
        , rows(0)
        , cols(0)
        , x(nullptr)
        , mu(nullptr)
        , rs(nullptr)
        , gamma(nullptr)
        , gamma1(nullptr)
        , rowscale(nullptr)
        , colscale(nullptr)
        , dropout_keep_p(1.f)
        , dropout_scale(1.f)
        , is_rms_norm(false)
        , workspace(nullptr)
        , barrier(nullptr)
    {
    }

    // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x.
    int ctas_per_col;

    // Input is interpreted as matrix. We normalize across columns.
    int rows;
    int cols;

    // Common data pointers.
    void *x0;
    void *x1;
    void *residual;
    void *x;
    void *dmask;
    void *dmask1;
    void *mu;
    void *rs;
    void *gamma;
    void *gamma1;
    void *rowscale;
    void *colscale;
    void *x0_subset;
    void *z_subset;

    float inverse_cols;

    float dropout_keep_p;
    float dropout_scale;
    float rowscale_const;

    bool is_rms_norm;

    // Multi-CTA workspace in gmem.
    void *workspace;

    // Multi-CTA sync barriers in gmem.
    int *barrier;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct FwdParams : public ParamsBase {
    FwdParams()
        : ParamsBase()
        , z(nullptr)
        , z1(nullptr)
        , beta(nullptr)
        , beta1(nullptr)
        , epsilon(0.f)
    {
    }

    // Output of LN FWD.
    void *z;
    void *z1;
    void *beta;
    void *beta1;
    float epsilon;

    // Random state.
    at::PhiloxCudaState philox_args;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct BwdParams : public ParamsBase {
    BwdParams()
        : ParamsBase()
        , dz(nullptr)
        , dz1(nullptr)
        , dx(nullptr)
        , dbeta_part(nullptr)
        , dgamma_part(nullptr)
        , dbeta1_part(nullptr)
        , dgamma1_part(nullptr)
        , dcolscale_part(nullptr)
        , dx0(nullptr)
        , dx1(nullptr)
        , dresidual(nullptr)
        , dbeta(nullptr)
        , dgamma(nullptr)
        , dbeta1(nullptr)
        , dgamma1(nullptr)
        , dcolscale(nullptr)
    {
    }

    // Input: gradient wrt. LN FWD output.
    void *dz;
    void *dz1;
    // Input: gradient wrt residual.
    void *dx;

    // Workspace for Wgrad pre-reduction.
    void *dbeta_part;
    void *dgamma_part;
    void *dbeta1_part;
    void *dgamma1_part;
    void *dcolscale_part;

    // Output: Dgrad.
    void *dx0;
    void *dx1;
    void *dresidual;
    // Output: Wgrad.
    void *dbeta;
    void *dgamma;
    void *dbeta1;
    void *dgamma1;
    void *dcolscale;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

using FwdFunction = std::function<void(LaunchParams<FwdParams>&, const bool)>;
using BwdFunction = std::function<void(LaunchParams<BwdParams>&, const bool)>;
using FunctionKey = uint64_t;
using FwdRegistry = std::unordered_map<FunctionKey, FwdFunction>;
using BwdRegistry = std::unordered_map<FunctionKey, BwdFunction>;

extern FwdRegistry FWD_FUNCS, PARALLEL_FWD_FUNCS;
extern BwdRegistry BWD_FUNCS, PARALLEL_BWD_FUNCS;

////////////////////////////////////////////////////////////////////////////////////////////////////

using fp32 = float;
using fp16 = half;
using bf16 = nv_bfloat16;

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct TypeId{};

template<>
struct TypeId<fp16>{
    constexpr static uint32_t Value = 0;
};

template<>
struct TypeId<bf16>{
    constexpr static uint32_t Value = 1;
};

template<>
struct TypeId<fp32>{
    constexpr static uint32_t Value = 2;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int S>
struct Type2Key{
    constexpr static uint32_t Value = TypeId<T>::Value << S;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct WeightType2Key : public Type2Key<T, 0>{};

template<typename T>
struct InputType2Key : public Type2Key<T, 2>{};

template<typename T>
struct ResidualType2Key : public Type2Key<T, 4>{};

template<typename T>
struct OutputType2Key : public Type2Key<T, 6>{};

template<typename T>
struct ComputeType2Key : public Type2Key<T, 8>{};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename W, typename I, typename R, typename O, typename C>
struct Types2Key{
    constexpr static uint32_t Value = WeightType2Key<W>::Value | InputType2Key<I>::Value | ResidualType2Key<R>::Value | OutputType2Key<O>::Value | ComputeType2Key<C>::Value;
    constexpr static inline uint64_t get(const uint64_t hidden_size){
        constexpr uint64_t type_key = Value;
        return (type_key << 32) | hidden_size;
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdRegistrar{
    FwdRegistrar(FwdFunction f){
        uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
        FWD_FUNCS.insert({ key, f });
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
struct BwdRegistrar{
    BwdRegistrar(BwdFunction f){
        uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
        BWD_FUNCS.insert({ key, f });
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdParallelRegistrar{
    FwdParallelRegistrar(FwdFunction f){
        uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
        PARALLEL_FWD_FUNCS.insert({ key, f });
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
struct BwdParallelRegistrar{
    BwdParallelRegistrar(BwdFunction f){
        uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
        PARALLEL_BWD_FUNCS.insert({ key, f });
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace layer_norm
```
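The `Types2Key` machinery above packs five 2-bit type ids (weight, input, residual, output, compute) into a 32-bit type key and appends the hidden size in the low 32 bits; kernels register themselves under that key via the registrar structs. A small hedged sketch of the packing, using the `TypeId` values defined above:

```cpp
// Hypothetical illustration (not part of ln.h): how a registry key is packed.
// TypeId: fp16 = 0, bf16 = 1, fp32 = 2; fields sit at bit offsets 0,2,4,6,8.
#include <cstdint>
#include <cstdio>

int main() {
    uint64_t fp16_id = 0, fp32_id = 2;
    // weight=fp32, input=fp16, residual=fp16, output=fp16, compute=fp32
    uint64_t type_key = (fp32_id << 0) | (fp16_id << 2) | (fp16_id << 4)
                      | (fp16_id << 6) | (fp32_id << 8);
    uint64_t key = (type_key << 32) | 4096u;  // hidden_size = 4096
    std::printf("type_key=%llu key=0x%llx\n",
                (unsigned long long)type_key, (unsigned long long)key);
    return 0;
}
```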
ln_api.cpp
ADDED
|
@@ -0,0 +1,850 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <torch/extension.h>
|
| 2 |
+
#include "ATen/cuda/CUDAContext.h"
|
| 3 |
+
#include <c10/cuda/CUDAGuard.h>
|
| 4 |
+
|
| 5 |
+
#include "ln.h"
|
| 6 |
+
|
| 7 |
+
/*
|
| 8 |
+
|
| 9 |
+
Supported Type combinations:
|
| 10 |
+
|
| 11 |
+
input residual compute weights output
|
| 12 |
+
============================================
|
| 13 |
+
fp32 fp32 fp32 fp32 fp32
|
| 14 |
+
fp16 fp32 fp32 fp32 fp16
|
| 15 |
+
fp16 fp16 fp32 fp32 fp16
|
| 16 |
+
bf16 fp32 fp32 fp32 bf16
|
| 17 |
+
bf16 bf16 fp32 fp32 bf16
|
| 18 |
+
fp16 fp16 fp32 fp16 fp16
|
| 19 |
+
bf16 bf16 fp32 bf16 bf16
|
| 20 |
+
|
| 21 |
+
Remarks:
|
| 22 |
+
Output type = Input type
|
| 23 |
+
Compute always in FP32
|
| 24 |
+
|
| 25 |
+
*/
|
| 26 |
+
|
| 27 |
+
namespace layer_norm {
|
| 28 |
+
|
| 29 |
+
// Create registries and provide runtime versions of config hash functions.
|
| 30 |
+
|
| 31 |
+
FwdRegistry FWD_FUNCS, PARALLEL_FWD_FUNCS;
|
| 32 |
+
BwdRegistry BWD_FUNCS, PARALLEL_BWD_FUNCS;
|
| 33 |
+
|
| 34 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 35 |
+
|
| 36 |
+
uint32_t get_type_id(torch::Dtype dtype){
|
| 37 |
+
if( dtype == torch::kFloat16 ) {
|
| 38 |
+
return TypeId<fp16>::Value;
|
| 39 |
+
} else if( dtype == torch::kBFloat16 ) {
|
| 40 |
+
return TypeId<bf16>::Value;
|
| 41 |
+
} else if( dtype == torch::kFloat32 ) {
|
| 42 |
+
return TypeId<fp32>::Value;
|
| 43 |
+
} else {
|
| 44 |
+
TORCH_CHECK(false, "Type not supported: ", dtype);
|
| 45 |
+
}
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 49 |
+
|
| 50 |
+
uint64_t get_key(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint64_t hidden_size) {
|
| 51 |
+
using namespace layer_norm;
|
| 52 |
+
uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) | (get_type_id(rtype) << 4) | (get_type_id(otype) << 6) | (get_type_id(ctype) << 8);
|
| 53 |
+
uint64_t launcher_key = (type_key << 32) | hidden_size;
|
| 54 |
+
return launcher_key;
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
} // namespace layer_norm
|
| 58 |
+
|
| 59 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 60 |
+
|
| 61 |
+
layer_norm::FwdFunction & get_fwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
|
| 62 |
+
auto iter = layer_norm::FWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size));
|
| 63 |
+
if( iter != layer_norm::FWD_FUNCS.end() ) {
|
| 64 |
+
return iter->second;
|
| 65 |
+
} else {
|
| 66 |
+
TORCH_CHECK(false, "FWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype);
|
| 67 |
+
}
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 71 |
+
|
| 72 |
+
layer_norm::BwdFunction & get_bwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
|
| 73 |
+
auto iter = layer_norm::BWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size));
|
| 74 |
+
if( iter != layer_norm::BWD_FUNCS.end() ) {
|
| 75 |
+
return iter->second;
|
| 76 |
+
} else {
|
| 77 |
+
TORCH_CHECK(false, "BWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype);
|
| 78 |
+
}
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 82 |
+
|
| 83 |
+
layer_norm::FwdFunction & get_parallel_fwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
|
| 84 |
+
auto iter = layer_norm::PARALLEL_FWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size));
|
| 85 |
+
if( iter != layer_norm::PARALLEL_FWD_FUNCS.end() ) {
|
| 86 |
+
return iter->second;
|
| 87 |
+
} else {
|
| 88 |
+
TORCH_CHECK(false, "FWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype);
|
| 89 |
+
}
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 93 |
+
|
| 94 |
+
layer_norm::BwdFunction & get_parallel_bwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
|
| 95 |
+
auto iter = layer_norm::PARALLEL_BWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size));
|
| 96 |
+
if( iter != layer_norm::PARALLEL_BWD_FUNCS.end() ) {
|
| 97 |
+
return iter->second;
|
| 98 |
+
} else {
|
| 99 |
+
TORCH_CHECK(false, "BWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype);
|
| 100 |
+
}
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 104 |
+
|
+std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input: BxSxhidden_size
+                                           c10::optional<const at::Tensor> &residual_,  // Residual: BxSxhidden_size
+                                           const at::Tensor &gamma,   // hidden_size
+                                           c10::optional<const at::Tensor> &beta_,      // hidden_size
+                                           c10::optional<const at::Tensor> &rowscale_,  // BxS
+                                           c10::optional<const at::Tensor> &colscale_,  // hidden_size
+                                           c10::optional<const at::Tensor> &x0_subset_, // BxS
+                                           c10::optional<const at::Tensor> &z_subset_,  // BxS
+                                           const float dropout_p,
+                                           const float epsilon,
+                                           const float rowscale_const,
+                                           const int64_t z_numrows,
+                                           c10::optional<at::Generator> gen_,
+                                           bool residual_in_fp32=false,
+                                           bool is_rms_norm=false
+) {
+    auto itype = x0.scalar_type();
+    auto rtype = residual_.has_value()
+        ? residual_.value().scalar_type()
+        : (residual_in_fp32 ? torch::kFloat32 : x0.scalar_type());
+    auto wtype = gamma.scalar_type();
+    auto otype = itype;
+    auto ctype = torch::kFloat32;
+    auto mtype = torch::kUInt8;
+
+    TORCH_CHECK(x0.is_cuda());
+    TORCH_CHECK(gamma.is_cuda());
+
+    TORCH_CHECK(x0.is_contiguous());
+    // c10::IntArrayRef does not own the storage, so we need to construct a vector.
+    // Otherwise just constructing IntArrayRef({blah}) will cause uninitialized memory because
+    // blah is then deallocated.
+    std::vector<int64_t> sizes_vec {!x0_subset_.has_value() ? x0.size(0) : x0_subset_.value().size(0), x0.size(1)};
+    auto sizes = c10::IntArrayRef(sizes_vec);
+    TORCH_CHECK(x0.dim() == 2);
+    TORCH_CHECK(sizes.size() == 2);
+
+    const int rows = sizes[0];
+    const int cols = sizes[1];
+    auto hidden_size = gamma.numel();
+    TORCH_CHECK(hidden_size == cols);
+
+    if (beta_.has_value()) {
+        auto beta = beta_.value();
+        TORCH_CHECK(beta.dtype() == wtype);
+        TORCH_CHECK(beta.is_cuda());
+        TORCH_CHECK(beta.is_contiguous());
+        TORCH_CHECK(beta.sizes() == gamma.sizes());
+    }
+
+    if (residual_.has_value()) {
+        auto residual = residual_.value();
+        TORCH_CHECK(residual.is_cuda());
+        TORCH_CHECK(residual.is_contiguous());
+        TORCH_CHECK(residual.sizes() == sizes);
+    }
+
+    if (rowscale_.has_value()) {
+        auto rowscale = rowscale_.value();
+        TORCH_CHECK(rowscale.is_cuda());
+        TORCH_CHECK(rowscale.is_contiguous());
+        TORCH_CHECK(rowscale.sizes() == c10::IntArrayRef{rows});
+        TORCH_CHECK(rowscale.dtype() == itype);
+    }
+
+    if (colscale_.has_value()) {
+        auto colscale = colscale_.value();
+        TORCH_CHECK(colscale.is_cuda());
+        TORCH_CHECK(colscale.is_contiguous());
+        TORCH_CHECK(colscale.sizes() == c10::IntArrayRef{cols});
+        TORCH_CHECK(colscale.dtype() == wtype);
+    }
+
+    if (x0_subset_.has_value()) {
+        auto x0_subset = x0_subset_.value();
+        TORCH_CHECK(x0_subset.is_cuda());
+        TORCH_CHECK(x0_subset.is_contiguous());
+        TORCH_CHECK(x0_subset.sizes() == c10::IntArrayRef{rows});
+        TORCH_CHECK(x0_subset.dtype() == torch::kInt32);
+
+        TORCH_CHECK(z_subset_.has_value());
+        auto z_subset = z_subset_.value();
+        TORCH_CHECK(z_subset.is_cuda());
+        TORCH_CHECK(z_subset.is_contiguous());
+        TORCH_CHECK(z_subset.sizes() == c10::IntArrayRef{rows});
+        TORCH_CHECK(z_subset.dtype() == torch::kInt32);
+    }
+
+    TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192));
+    TORCH_CHECK(epsilon >= 0.f);
+
+    // Otherwise the kernel will be launched from cuda:0 device
+    // Cast to char to avoid compiler warning about narrowing
+    at::cuda::CUDAGuard device_guard{(char)x0.get_device()};
+
+    auto opts = x0.options();
+
+    bool save_x = residual_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || x0_subset_.has_value() || (itype != rtype);
+    at::Tensor x;
+    if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); }
+    at::Tensor dmask;
+    if (dropout_p > 0.f) { dmask = torch::empty(x0.sizes(), opts.dtype(mtype)); };
+    auto z = torch::empty(z_subset_.has_value() ? c10::IntArrayRef{z_numrows, cols} : sizes, opts.dtype(otype));
+
+    auto mu = torch::empty({ rows }, opts.dtype(ctype));
+    auto rsigma = torch::empty({ rows }, opts.dtype(ctype));
+
+    layer_norm::LaunchParams<layer_norm::FwdParams> launch_params;
+
+    launch_params.props = at::cuda::getCurrentDeviceProperties();
+    launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
+    TORCH_CHECK(dropout_p < 1.f);
+    launch_params.params.dropout_keep_p = 1.f - dropout_p;
+    launch_params.params.residual = residual_.has_value() ? residual_.value().data_ptr() : nullptr;
+    launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
+    launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
+    launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr;
+    launch_params.params.z_subset = z_subset_.has_value() ? z_subset_.value().data_ptr() : nullptr;
+
+    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
+        gen_, at::cuda::detail::getDefaultCUDAGenerator());
+
+    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+    const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
+    // Request the kernel launcher.
+    auto launcher = get_fwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
+
+    // Set the kernel runtime parameters.
+    layer_norm::FwdParams &params = launch_params.params;
+    params.rows = rows;
+    params.cols = cols;
+    params.x0 = x0.data_ptr();
+    params.x = save_x ? x.data_ptr() : nullptr;
+    params.dmask = dropout_p > 0.f ? dmask.data_ptr() : nullptr;
+    params.mu = mu.data_ptr();
+    params.rs = rsigma.data_ptr();
+    params.gamma = gamma.data_ptr();
+    params.beta = beta_.has_value() ? beta_.value().data_ptr() : nullptr;
+    params.z = z.data_ptr();
+    params.epsilon = epsilon;
+    params.dropout_scale = 1.f / (1.f - dropout_p);
+    params.inverse_cols = 1.f / float(params.cols);
+    params.rowscale_const = rowscale_const;
+    params.is_rms_norm = is_rms_norm;
+
+    // Query the kernel-specific launch parameters.
+    launcher(launch_params, true);
+
+    at::Tensor workspace, barrier;
+
+    if (dropout_p > 0.f) {
+        // number of times random will be generated per thread, to offset philox counter in thc random
+        // state
+        int64_t counter_offset = launch_params.elts_per_thread;
+
+        // See Note [Acquire lock when using random generators]
+        {
+            std::lock_guard<std::mutex> lock(gen->mutex_);
+            params.philox_args = gen->philox_cuda_state(counter_offset);
+        }
+    }
+
+    if( launch_params.barrier_size > 0 ) {
+        auto options = x0.options();
+        barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32));
+        workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar));
+        params.workspace = workspace.data_ptr();
+        params.barrier = barrier.data_ptr<int>();
+    }
+
+    // Launch the kernel.
+    launcher(launch_params, false);
+
+    return { z, x, dmask, mu, rsigma };
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
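Before requesting a launcher, `dropout_add_ln_fwd` rounds the hidden size up to the nearest size that has registered kernels: multiples of 256 up to 1536, of 512 up to 3072, and of 1024 above that. Hidden sizes that are not an exact match still work because the kernels mask the padding columns (see the `Is_even_cols`/`num_valid_ldgs` handling in `ln_bwd_kernels.cuh` below). A worked sketch of the rounding, with the lambda and thresholds copied from the function above:

```cpp
#include <cstdio>

// Round x up to the next multiple of m (same lambda as in dropout_add_ln_fwd).
int round_multiple(int x, int m) { return (x + m - 1) / m * m; }

// Pick the registered kernel size for a given hidden size.
int rounded_hidden_size(int hidden_size) {
    const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
    return round_multiple(hidden_size, multiple);
}

int main() {
    printf("%d %d %d\n",
           rounded_hidden_size(1000),   // 1024
           rounded_hidden_size(2500),   // 2560
           rounded_hidden_size(4100));  // 5120
    return 0;
}
```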
+std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz,      // BxSxhidden_size
+                                           c10::optional<const at::Tensor> &dx_,        // BxSxhidden_size
+                                           const at::Tensor &x,       // BxSxhidden_size
+                                           c10::optional<const at::Tensor> &x0_,        // BxSxhidden_size
+                                           c10::optional<const at::Tensor> &dmask_,     // BxSxhidden_size
+                                           const at::Tensor &mu,      // BxS, FP32!
+                                           const at::Tensor &rsigma,  // BxS, FP32!
+                                           const at::Tensor &gamma,   // hidden_size
+                                           c10::optional<const at::Tensor> &rowscale_,  // BxS
+                                           c10::optional<const at::Tensor> &colscale_,  // hidden_size
+                                           c10::optional<const at::Tensor> &x0_subset_, // BxS
+                                           c10::optional<const at::Tensor> &z_subset_,  // BxS
+                                           const float dropout_p,
+                                           const float rowscale_const,
+                                           const int64_t x0_numrows,
+                                           const bool has_residual,
+                                           bool is_rms_norm=false
+) {
+
+    auto itype = dz.scalar_type();
+    auto rtype = x.scalar_type();
+    auto wtype = gamma.scalar_type();
+    auto otype = itype;
+    auto ctype = torch::kFloat32;
+    auto mtype = torch::kUInt8;
+
+    if (dropout_p > 0.f) { TORCH_CHECK(dmask_.has_value()); }
+
+    TORCH_CHECK(dz.dtype() == otype);
+    TORCH_CHECK(mu.dtype() == ctype);
+    TORCH_CHECK(rsigma.dtype() == ctype);
+
+    TORCH_CHECK(x.is_cuda());
+    TORCH_CHECK(dz.is_cuda());
+    TORCH_CHECK(mu.is_cuda());
+    TORCH_CHECK(rsigma.is_cuda());
+    TORCH_CHECK(gamma.is_cuda());
+
+    TORCH_CHECK(x.is_contiguous());
+    TORCH_CHECK(dz.is_contiguous());
+
+    auto sizes = x.sizes();
+    TORCH_CHECK(sizes.size() == 2);
+    auto rows = sizes[0];
+    auto cols = sizes[1];
+    TORCH_CHECK(dz.dim() == 2);
+    TORCH_CHECK(dz.size(1) == cols);
+    auto hidden_size = gamma.numel();
+    TORCH_CHECK(hidden_size == cols);
+
+    // c10::IntArrayRef does not own the storage, so we need to construct a vector.
+    // Otherwise just constructing IntArrayRef({blah}) will cause uninitialized memory because
+    // blah is then deallocated.
+    std::vector<int64_t> x0_sizes_vec {!x0_subset_.has_value() ? rows : x0_numrows, cols};
+    auto x0_sizes = c10::IntArrayRef(x0_sizes_vec);
+
+    if (dx_.has_value()) {
+        auto dx = dx_.value();
+        TORCH_CHECK(dx.dtype() == rtype);
+        TORCH_CHECK(dx.is_cuda());
+        TORCH_CHECK(dx.is_contiguous());
+        TORCH_CHECK(dx.sizes() == sizes);
+    }
+
+    if (dmask_.has_value()) {
+        auto dmask = dmask_.value();
+        TORCH_CHECK(dmask.dtype() == mtype);
+        TORCH_CHECK(dmask.is_cuda());
+        TORCH_CHECK(dmask.is_contiguous());
+        TORCH_CHECK(dmask.sizes() == x0_sizes);
+    }
+
+    if (rowscale_.has_value()) {
+        auto rowscale = rowscale_.value();
+        TORCH_CHECK(rowscale.is_cuda());
+        TORCH_CHECK(rowscale.is_contiguous());
+        TORCH_CHECK(rowscale.sizes() == c10::IntArrayRef{rows});
+        TORCH_CHECK(rowscale.dtype() == itype);
+    }
+
+    if (colscale_.has_value()) {
+        auto colscale = colscale_.value();
+        TORCH_CHECK(colscale.is_cuda());
+        TORCH_CHECK(colscale.is_contiguous());
+        TORCH_CHECK(colscale.sizes() == c10::IntArrayRef{cols});
+        TORCH_CHECK(colscale.dtype() == wtype);
+
+        TORCH_CHECK(x0_.has_value());
+        auto x0 = x0_.value();
+        TORCH_CHECK(x0.is_cuda());
+        TORCH_CHECK(x0.is_contiguous());
+        TORCH_CHECK(x0.sizes() == x0_sizes);
+        TORCH_CHECK(x0.dtype() == itype);
+    }
+
+    if (x0_subset_.has_value()) {
+        auto x0_subset = x0_subset_.value();
+        TORCH_CHECK(x0_subset.is_cuda());
+        TORCH_CHECK(x0_subset.is_contiguous());
+        TORCH_CHECK(x0_subset.sizes() == c10::IntArrayRef{rows});
+        TORCH_CHECK(x0_subset.dtype() == torch::kInt32);
+
+        TORCH_CHECK(z_subset_.has_value());
+        auto z_subset = z_subset_.value();
+        TORCH_CHECK(z_subset.is_cuda());
+        TORCH_CHECK(z_subset.is_contiguous());
+        TORCH_CHECK(z_subset.sizes() == c10::IntArrayRef{rows});
+        TORCH_CHECK(z_subset.dtype() == torch::kInt32);
+    }
+
+    TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192));
+
+    TORCH_CHECK(mu.numel() == rows);
+    TORCH_CHECK(mu.sizes() == rsigma.sizes());
+
+    TORCH_CHECK(gamma.numel() == cols);
+
+    // Otherwise the kernel will be launched from cuda:0 device
+    // Cast to char to avoid compiler warning about narrowing
+    at::cuda::CUDAGuard device_guard{(char)dz.get_device()};
+
+    auto opts = x.options();
+
+    auto dx0 = torch::empty(x0_sizes, opts.dtype(itype));
+    at::Tensor dresidual;
+    if (has_residual) { dresidual = torch::empty_like(x, opts.dtype(rtype)); }
+    auto dgamma = torch::empty_like(gamma);
+    auto dbeta = torch::empty_like(gamma);
+    at::Tensor dcolscale;
+    if (colscale_.has_value()) {
+        dcolscale = torch::empty_like(colscale_.value());
+    }
+
+    layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
+    launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
+    launch_params.props = at::cuda::getCurrentDeviceProperties();
+    TORCH_CHECK(dropout_p < 1.f);
+    launch_params.params.dropout_keep_p = 1.f - dropout_p;
+    launch_params.params.dresidual = has_residual ? dresidual.data_ptr() : nullptr;
+    launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
+    launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
+    launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr;
+    launch_params.params.z_subset = z_subset_.has_value() ? z_subset_.value().data_ptr() : nullptr;
+
+    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+    const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
+    auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
+
+    launcher(launch_params, true);
+
+    auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
+    auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
+    at::Tensor dcolscale_part;
+    if (colscale_.has_value()) {
+        dcolscale_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
+    }
+    at::Tensor workspace, barrier;
+
+    layer_norm::BwdParams &params = launch_params.params;
+    params.rows = rows;
+    params.cols = cols;
+    params.x = x.data_ptr();
+    params.x0 = x0_.has_value() ? x0_.value().data_ptr() : nullptr;
+    params.dmask = dropout_p > 0.f ? dmask_.value().data_ptr() : nullptr;
+    params.mu = mu.data_ptr();
+    params.rs = rsigma.data_ptr();
+    params.gamma = gamma.data_ptr();
+    params.dz = dz.data_ptr();
+    params.dx = dx_.has_value() ? dx_.value().data_ptr() : nullptr;
+    params.dx0 = dx0.data_ptr();
+    params.dbeta = dbeta.data_ptr();
+    params.dgamma = dgamma.data_ptr();
+    params.dcolscale = colscale_.has_value() ? dcolscale.data_ptr() : nullptr;
+    params.dbeta_part = dbeta_part.data_ptr();
+    params.dgamma_part = dgamma_part.data_ptr();
+    params.dcolscale_part = colscale_.has_value() ? dcolscale_part.data_ptr() : nullptr;
+    params.dropout_scale = 1.f / (1.f - dropout_p);
+    params.inverse_cols = 1.f / float(params.cols);
+    params.rowscale_const = rowscale_const;
+    params.is_rms_norm = is_rms_norm;
+
+    if( launch_params.barrier_size > 0 ) {
+        // TODO Any way to avoid this?
+        barrier = torch::zeros(launch_params.barrier_size, opts.dtype(torch::kInt32));
+        workspace = torch::empty(launch_params.workspace_bytes, opts.dtype(torch::kChar));
+        params.workspace = workspace.data_ptr();
+        params.barrier = barrier.data_ptr<int>();
+    }
+
+    launcher(launch_params, false);
+
+    std::vector<at::Tensor> result = { dx0, dresidual, dgamma, dbeta, dgamma_part, dbeta_part };
+    if (colscale_.has_value()) {
+        result.push_back(dcolscale);
+        result.push_back(dcolscale_part);
+    }
+    return result;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
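Every entry point in this file drives its launcher twice: the first call (`launcher(launch_params, true)`) only fills in resource requirements such as `ctas_per_col`, `barrier_size`, and `workspace_bytes`; the host then allocates those buffers and the second call actually launches the kernel. A schematic of that protocol (the field names are borrowed from the calls above, but the `Launcher` signature here is an assumption, not the extension's real type):

```cpp
#include <cstddef>

// Stand-in for layer_norm::LaunchParams: only the fields the host reads back.
struct LaunchParamsSketch {
    int ctas_per_col = 0;          // grid height, set during the configure pass
    size_t workspace_bytes = 0;    // inter-CTA reduction buffer, 0 if unused
    size_t barrier_size = 0;       // inter-CTA sync counters, 0 if unused
};

template <typename F>
void run_configure_then_launch(F &&launcher, LaunchParamsSketch &lp) {
    launcher(lp, /*configure_params=*/true);   // pass 1: query requirements only
    if (lp.barrier_size > 0) {
        // Allocate a zero-initialized barrier and an uninitialized workspace
        // here, exactly as the torch::zeros/torch::empty calls above do.
    }
    launcher(lp, /*configure_params=*/false);  // pass 2: actually run the kernel
}
```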
+std::vector<at::Tensor> dropout_add_ln_parallel_residual_fwd(
+    const at::Tensor &x0,                        // Input: BxSxhidden_size
+    c10::optional<const at::Tensor> &x1_,        // Input: BxSxhidden_size
+    c10::optional<const at::Tensor> &residual_,  // Residual: BxSxhidden_size
+    const at::Tensor &gamma0,                    // hidden_size
+    c10::optional<const at::Tensor> &beta0_,     // hidden_size
+    c10::optional<const at::Tensor> &gamma1_,    // hidden_size
+    c10::optional<const at::Tensor> &beta1_,     // hidden_size
+    const float dropout_p,
+    const float epsilon,
+    c10::optional<at::Generator> gen_,
+    bool residual_in_fp32=false,
+    bool is_rms_norm=false
+) {
+    auto itype = x0.scalar_type();
+    auto rtype = residual_.has_value()
+        ? residual_.value().scalar_type()
+        : (residual_in_fp32 ? torch::kFloat32 : x0.scalar_type());
+    auto wtype = gamma0.scalar_type();
+    auto otype = itype;
+    auto ctype = torch::kFloat32;
+    auto mtype = torch::kUInt8;
+
+    TORCH_CHECK(x0.is_cuda());
+    TORCH_CHECK(gamma0.is_cuda());
+
+    TORCH_CHECK(x0.is_contiguous());
+    const auto sizes = x0.sizes();
+    TORCH_CHECK(x0.dim() == 2);
+
+    const int rows = sizes[0];
+    const int cols = sizes[1];
+    auto hidden_size = gamma0.numel();
+    TORCH_CHECK(hidden_size == cols);
+
+    if (x1_.has_value()) {
+        auto x1 = x1_.value();
+        TORCH_CHECK(x1.is_cuda());
+        TORCH_CHECK(x1.is_contiguous());
+        TORCH_CHECK(x1.sizes() == sizes);
+    }
+
+    if (residual_.has_value()) {
+        auto residual = residual_.value();
+        TORCH_CHECK(residual.is_cuda());
+        TORCH_CHECK(residual.is_contiguous());
+        TORCH_CHECK(residual.sizes() == sizes);
+    }
+
+    if (beta0_.has_value()) {
+        auto beta0 = beta0_.value();
+        TORCH_CHECK(beta0.dtype() == wtype);
+        TORCH_CHECK(beta0.is_cuda());
+        TORCH_CHECK(beta0.is_contiguous());
+        TORCH_CHECK(beta0.sizes() == gamma0.sizes());
+    }
+
+    if (gamma1_.has_value()) {
+        auto gamma1 = gamma1_.value();
+        TORCH_CHECK(gamma1.dtype() == wtype);
+        TORCH_CHECK(gamma1.is_cuda());
+        TORCH_CHECK(gamma1.is_contiguous());
+        TORCH_CHECK(gamma1.sizes() == gamma0.sizes());
+    }
+
+    if (beta1_.has_value()) {
+        auto beta1 = beta1_.value();
+        TORCH_CHECK(beta1.dtype() == wtype);
+        TORCH_CHECK(beta1.is_cuda());
+        TORCH_CHECK(beta1.is_contiguous());
+        TORCH_CHECK(beta1.sizes() == gamma0.sizes());
+    }
+
+    TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192));
+    TORCH_CHECK(epsilon >= 0.f);
+
+    // Otherwise the kernel will be launched from cuda:0 device
+    // Cast to char to avoid compiler warning about narrowing
+    at::cuda::CUDAGuard device_guard{(char)x0.get_device()};
+
+    auto opts = x0.options();
+
+    bool save_x = residual_.has_value() || x1_.has_value() || (dropout_p > 0.f) || (itype != rtype);
+    at::Tensor x;
+    if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); }
+    at::Tensor dmask0, dmask1;
+    if (dropout_p > 0.f) {
+        dmask0 = torch::empty(x0.sizes(), opts.dtype(mtype));
+        if (x1_.has_value()) { dmask1 = torch::empty(x0.sizes(), opts.dtype(mtype)); }
+    };
+    auto z0 = torch::empty(sizes, opts.dtype(otype));
+    at::Tensor z1;
+    if (gamma1_.has_value()) { z1 = torch::empty(sizes, opts.dtype(otype)); }
+
+    auto mu = torch::empty({ rows }, opts.dtype(ctype));
+    auto rsigma = torch::empty({ rows }, opts.dtype(ctype));
+
+    layer_norm::LaunchParams<layer_norm::FwdParams> launch_params;
+
+    launch_params.props = at::cuda::getCurrentDeviceProperties();
+    launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
+    TORCH_CHECK(dropout_p < 1.f);
+    launch_params.params.dropout_keep_p = 1.f - dropout_p;
+    launch_params.params.residual = residual_.has_value() ? residual_.value().data_ptr() : nullptr;
+
+    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
+        gen_, at::cuda::detail::getDefaultCUDAGenerator());
+
+    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+    const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
+    // Request the kernel launcher.
+    auto launcher = get_parallel_fwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
+
+    // Set the kernel runtime parameters.
+    layer_norm::FwdParams &params = launch_params.params;
+    params.rows = rows;
+    params.cols = cols;
+    params.x0 = x0.data_ptr();
+    params.x1 = x1_.has_value() ? x1_.value().data_ptr() : nullptr;
+    params.x = save_x ? x.data_ptr() : nullptr;
+    params.dmask = dropout_p > 0.f ? dmask0.data_ptr() : nullptr;
+    params.dmask1 = (dropout_p > 0.f && x1_.has_value()) ? dmask1.data_ptr() : nullptr;
+    params.mu = mu.data_ptr();
+    params.rs = rsigma.data_ptr();
+    params.gamma = gamma0.data_ptr();
+    params.gamma1 = gamma1_.has_value() ? gamma1_.value().data_ptr() : nullptr;
+    params.beta = beta0_.has_value() ? beta0_.value().data_ptr() : nullptr;
+    params.beta1 = beta1_.has_value() ? beta1_.value().data_ptr() : nullptr;
+    params.z = z0.data_ptr();
+    params.z1 = gamma1_.has_value() ? z1.data_ptr() : nullptr;
+    params.epsilon = epsilon;
+    params.dropout_scale = 1.f / (1.f - dropout_p);
+    params.inverse_cols = 1.f / float(params.cols);
+    params.is_rms_norm = is_rms_norm;
+
+    // Query the kernel-specific launch parameters.
+    launcher(launch_params, true);
+
+    at::Tensor workspace, barrier;
+
+    if (dropout_p > 0.f) {
+        // number of times random will be generated per thread, to offset philox counter in thc random
+        // state
+        int64_t counter_offset = 2 * launch_params.elts_per_thread;
+
+        // See Note [Acquire lock when using random generators]
+        {
+            std::lock_guard<std::mutex> lock(gen->mutex_);
+            params.philox_args = gen->philox_cuda_state(counter_offset);
+        }
+    }
+
+    if( launch_params.barrier_size > 0 ) {
+        auto options = x0.options();
+        barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32));
+        workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar));
+        params.workspace = workspace.data_ptr();
+        params.barrier = barrier.data_ptr<int>();
+    }
+
+    // Launch the kernel.
+    launcher(launch_params, false);
+
+    return { z0, z1, x, dmask0, dmask1, mu, rsigma };
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
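In the parallel-residual path a single `mu`/`rsigma` pair is saved per row, so the two outputs share one set of normalization statistics and differ only in their affine transforms (`gamma0`/`beta0` for `z0`, `gamma1`/`beta1` for `z1`), while dropout is applied to `x0` and `x1` separately (hence `dmask0` and `dmask1`). A scalar reference of the forward math as we read these parameters, with dropout and the optional betas omitted for brevity (a sketch of the semantics, not the fused kernel):

```cpp
#include <cmath>
#include <vector>

// One row of parallel-residual LayerNorm: x = x0 + x1 + residual, shared
// mean/rsigma, two independent scales producing z0 and z1.
void parallel_residual_ln_row(const std::vector<float> &x0,
                              const std::vector<float> &x1,
                              const std::vector<float> &residual,
                              const std::vector<float> &gamma0,
                              const std::vector<float> &gamma1,
                              float epsilon,
                              std::vector<float> &z0, std::vector<float> &z1) {
    const size_t n = x0.size();
    std::vector<float> x(n);
    float mu = 0.f;
    for (size_t i = 0; i < n; ++i) { x[i] = x0[i] + x1[i] + residual[i]; mu += x[i]; }
    mu /= n;
    float var = 0.f;
    for (size_t i = 0; i < n; ++i) { var += (x[i] - mu) * (x[i] - mu); }
    const float rsigma = 1.f / std::sqrt(var / n + epsilon);  // what the kernel saves as "rs"
    for (size_t i = 0; i < n; ++i) {
        const float y = (x[i] - mu) * rsigma;
        z0[i] = gamma0[i] * y;  // first readout
        z1[i] = gamma1[i] * y;  // second readout, same statistics
    }
}
```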
+std::vector<at::Tensor> dropout_add_ln_parallel_residual_bwd(
+    const at::Tensor &dz0,                      // BxSxhidden_size
+    c10::optional<const at::Tensor> &dz1_,      // BxSxhidden_size
+    c10::optional<const at::Tensor> &dx_,       // BxSxhidden_size
+    const at::Tensor &x,                        // BxSxhidden_size
+    c10::optional<const at::Tensor> &dmask0_,   // BxSxhidden_size
+    c10::optional<const at::Tensor> &dmask1_,   // BxSxhidden_size
+    const at::Tensor &mu,                       // BxS, FP32!
+    const at::Tensor &rsigma,                   // BxS, FP32!
+    const at::Tensor &gamma0,                   // hidden_size
+    c10::optional<const at::Tensor> &gamma1_,   // hidden_size
+    const float dropout_p,
+    const bool has_x1,
+    const bool has_residual,
+    bool is_rms_norm=false
+) {
+
+    auto itype = dz0.scalar_type();
+    auto rtype = x.scalar_type();
+    auto wtype = gamma0.scalar_type();
+    auto otype = itype;
+    auto ctype = torch::kFloat32;
+    auto mtype = torch::kUInt8;
+
+    if (dropout_p > 0.f) { TORCH_CHECK(dmask0_.has_value()); }
+
+    TORCH_CHECK(dz0.dtype() == otype);
+    TORCH_CHECK(mu.dtype() == ctype);
+    TORCH_CHECK(rsigma.dtype() == ctype);
+
+    TORCH_CHECK(x.is_cuda());
+    TORCH_CHECK(dz0.is_cuda());
+    TORCH_CHECK(mu.is_cuda());
+    TORCH_CHECK(rsigma.is_cuda());
+    TORCH_CHECK(gamma0.is_cuda());
+
+    TORCH_CHECK(x.is_contiguous());
+    TORCH_CHECK(dz0.is_contiguous());
+
+    auto sizes = x.sizes();
+    TORCH_CHECK(sizes.size() == 2);
+    auto rows = sizes[0];
+    auto cols = sizes[1];
+    TORCH_CHECK(dz0.dim() == 2);
+    TORCH_CHECK(dz0.size(1) == cols);
+    auto hidden_size = gamma0.numel();
+    TORCH_CHECK(hidden_size == cols);
+
+    if (dz1_.has_value()) {
+        auto dz1 = dz1_.value();
+        TORCH_CHECK(dz1.dtype() == otype);
+        TORCH_CHECK(dz1.is_cuda());
+        TORCH_CHECK(dz1.is_contiguous());
+        TORCH_CHECK(dz1.sizes() == sizes);
+
+        TORCH_CHECK(gamma1_.has_value());
+        auto gamma1 = gamma1_.value();
+        TORCH_CHECK(gamma1.dtype() == wtype);
+        TORCH_CHECK(gamma1.is_cuda());
+        TORCH_CHECK(gamma1.is_contiguous());
+        TORCH_CHECK(gamma1.sizes() == gamma0.sizes());
+    }
+
+    if (dx_.has_value()) {
+        auto dx = dx_.value();
+        TORCH_CHECK(dx.dtype() == rtype);
+        TORCH_CHECK(dx.is_cuda());
+        TORCH_CHECK(dx.is_contiguous());
+        TORCH_CHECK(dx.sizes() == sizes);
+    }
+
+    if (dmask0_.has_value()) {
+        auto dmask0 = dmask0_.value();
+        TORCH_CHECK(dmask0.dtype() == mtype);
+        TORCH_CHECK(dmask0.is_cuda());
+        TORCH_CHECK(dmask0.is_contiguous());
+        TORCH_CHECK(dmask0.sizes() == sizes);
+
+        if (has_x1) {
+            TORCH_CHECK(dmask1_.has_value());
+            auto dmask1 = dmask1_.value();
+            TORCH_CHECK(dmask1.dtype() == mtype);
+            TORCH_CHECK(dmask1.is_cuda());
+            TORCH_CHECK(dmask1.is_contiguous());
+            TORCH_CHECK(dmask1.sizes() == sizes);
+        }
+    }
+
+    TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192));
+
+    TORCH_CHECK(mu.numel() == rows);
+    TORCH_CHECK(mu.sizes() == rsigma.sizes());
+
+    // Otherwise the kernel will be launched from cuda:0 device
+    // Cast to char to avoid compiler warning about narrowing
+    at::cuda::CUDAGuard device_guard{(char)dz0.get_device()};
+
+    auto opts = x.options();
+
+    auto dx0 = torch::empty(sizes, opts.dtype(itype));
+    at::Tensor dx1;
+    if (has_x1) { dx1 = torch::empty(sizes, opts.dtype(itype)); }
+    at::Tensor dresidual;
+    if (has_residual) { dresidual = torch::empty_like(x, opts.dtype(rtype)); }
+    auto dgamma0 = torch::empty_like(gamma0);
+    auto dbeta0 = torch::empty_like(gamma0);
+    at::Tensor dgamma1, dbeta1;
+    if (gamma1_.has_value()) {
+        dgamma1 = torch::empty_like(gamma0);
+        dbeta1 = torch::empty_like(gamma0);
+    }
+
+    layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
+    launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
+    launch_params.props = at::cuda::getCurrentDeviceProperties();
+    TORCH_CHECK(dropout_p < 1.f);
+    launch_params.params.dropout_keep_p = 1.f - dropout_p;
+    launch_params.params.dresidual = has_residual ? dresidual.data_ptr() : nullptr;
+
+    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+    const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
+    auto launcher = get_parallel_bwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
+
+    launcher(launch_params, true);
+
+    auto dgamma0_part = torch::zeros({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
+    auto dbeta0_part = torch::zeros({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
+    at::Tensor dgamma1_part, dbeta1_part;
+    if (gamma1_.has_value()) {
+        dgamma1_part = torch::zeros_like(dgamma0_part);
+        dbeta1_part = torch::zeros_like(dbeta0_part);
+    }
+    at::Tensor workspace, barrier;
+
+    layer_norm::BwdParams &params = launch_params.params;
+    params.rows = rows;
+    params.cols = cols;
+    params.x = x.data_ptr();
+    params.dmask = dropout_p > 0.f ? dmask0_.value().data_ptr() : nullptr;
+    params.dmask1 = (dropout_p > 0.f && has_x1) ? dmask1_.value().data_ptr() : nullptr;
+    params.mu = mu.data_ptr();
+    params.rs = rsigma.data_ptr();
+    params.gamma = gamma0.data_ptr();
+    params.gamma1 = gamma1_.has_value() ? gamma1_.value().data_ptr() : nullptr;
+    params.dz = dz0.data_ptr();
+    params.dz1 = dz1_.has_value() ? dz1_.value().data_ptr() : nullptr;
+    params.dx = dx_.has_value() ? dx_.value().data_ptr() : nullptr;
+    params.dx0 = dx0.data_ptr();
+    params.dx1 = has_x1 ? dx1.data_ptr() : nullptr;
+    params.dbeta = dbeta0.data_ptr();
+    params.dgamma = dgamma0.data_ptr();
+    params.dbeta1 = gamma1_.has_value() ? dbeta1.data_ptr() : nullptr;
+    params.dgamma1 = gamma1_.has_value() ? dgamma1.data_ptr() : nullptr;
+    params.dbeta_part = dbeta0_part.data_ptr();
+    params.dgamma_part = dgamma0_part.data_ptr();
+    params.dbeta1_part = gamma1_.has_value() ? dbeta1_part.data_ptr() : nullptr;
+    params.dgamma1_part = gamma1_.has_value() ? dgamma1_part.data_ptr() : nullptr;
+    params.dropout_scale = 1.f / (1.f - dropout_p);
+    params.inverse_cols = 1.f / float(params.cols);
+    params.is_rms_norm = is_rms_norm;
+
+    if( launch_params.barrier_size > 0 ) {
+        // TODO Any way to avoid this?
+        barrier = torch::zeros(launch_params.barrier_size, opts.dtype(torch::kInt32));
+        workspace = torch::empty(launch_params.workspace_bytes, opts.dtype(torch::kChar));
+        params.workspace = workspace.data_ptr();
+        params.barrier = barrier.data_ptr<int>();
+    }
+
+    launcher(launch_params, false);
+
+    std::vector<at::Tensor> result = { dx0, dx1, dresidual, dgamma0, dbeta0, dgamma1, dbeta1, dgamma0_part, dbeta0_part, dgamma1_part, dbeta1_part };
+    return result;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.doc() = "CUDA DropoutAddLayerNorm";
+    m.def("dropout_add_ln_fwd", &dropout_add_ln_fwd, "Run Dropout + Add + LayerNorm forward kernel",
+          py::arg("x0"), py::arg("residual"), py::arg("gamma"), py::arg("beta_"),
+          py::arg("rowscale_"), py::arg("colscale_"), py::arg("x0_subset_"), py::arg("z_subset_"),
+          py::arg("dropout_p"), py::arg("epsilon"), py::arg("rowscale_const"), py::arg("z_numrows"),
+          py::arg("gen_"), py::arg("residual_in_fp32")=false, py::arg("is_rms_norm")=false);
+    m.def("dropout_add_ln_bwd", &dropout_add_ln_bwd, "Run Dropout + Add + LayerNorm backward kernel",
+          py::arg("dz"), py::arg("dx_"), py::arg("x"), py::arg("x0_"), py::arg("dmask_"), py::arg("mu"),
+          py::arg("rsigma"), py::arg("gamma"), py::arg("rowscale_"), py::arg("colscale_"),
+          py::arg("x0_subset_"), py::arg("z_subset_"), py::arg("dropout_p"), py::arg("rowscale_const"),
+          py::arg("x0_numrows"), py::arg("has_residual"), py::arg("is_rms_norm")=false);
+    m.def("dropout_add_ln_parallel_residual_fwd", &dropout_add_ln_parallel_residual_fwd, "Run Dropout + Add + LayerNorm parallel residual forward kernel",
+          py::arg("x0"), py::arg("x1_"), py::arg("residual"), py::arg("gamma0"), py::arg("beta0_"),
+          py::arg("gamma1_"), py::arg("beta1_"), py::arg("dropout_p"), py::arg("epsilon"),
+          py::arg("gen_"), py::arg("residual_in_fp32")=false, py::arg("is_rms_norm")=false);
+    m.def("dropout_add_ln_parallel_residual_bwd", &dropout_add_ln_parallel_residual_bwd, "Run Dropout + Add + LayerNorm parallel residual backward kernel",
+          py::arg("dz0"), py::arg("dz1_"), py::arg("dx_"), py::arg("x"), py::arg("dmask0_"),
+          py::arg("dmask1_"), py::arg("mu"), py::arg("rsigma"), py::arg("gamma0"), py::arg("gamma1_"),
+          py::arg("dropout_p"), py::arg("has_x1"), py::arg("has_residual"), py::arg("is_rms_norm")=false);
+}
ln_bwd_1024.cu
ADDED
@@ -0,0 +1,15 @@
+#include "ln_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_BWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
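Each of these per-hidden-size translation units only instantiates launch functions and registers them; compiling one file per size keeps template instantiation times manageable. A hypothetical expansion of such a registration macro, reusing the `make_key`/`Registrar` stand-ins from the dispatch sketch earlier (the extension's real `REGISTER_BWD_LAUNCHER` differs in detail and is not shown in this diff):

```cpp
#include <cstdint>
#include <functional>
#include <unordered_map>

enum DType { fp32, fp16, bf16 };  // stand-ins for the dtype tags used above
struct Params {};
using Launcher = std::function<void(Params &, bool)>;

inline std::unordered_map<uint64_t, Launcher> BWD_FUNCS_SKETCH;
inline uint64_t make_key(uint64_t w, uint64_t i, uint64_t r, uint64_t o,
                         uint64_t c, uint64_t hidden) {
    return (w << 60) | (i << 56) | (r << 52) | (o << 48) | (c << 44) | hidden;
}
struct Registrar {
    Registrar(uint64_t k, Launcher f) { BWD_FUNCS_SKETCH.emplace(k, std::move(f)); }
};

// One launch function per (hidden_size, dtypes, tiling), inserted into the
// lookup map during static initialization.
#define REGISTER_BWD_LAUNCHER_SKETCH(HIDDEN, W, I, R, O, C, CTAS, WM, WN, LDG, LDGF) \
    static void launch_##HIDDEN##_##W##_##I##_##R##_##O(Params &p, bool configure) { \
        /* pick kernel traits from the macro arguments; fill grid/workspace    */   \
        /* requirements when configure == true, otherwise launch the kernel.   */   \
    }                                                                                \
    static Registrar reg_##HIDDEN##_##W##_##I##_##R##_##O(                           \
        make_key(W, I, R, O, C, HIDDEN), launch_##HIDDEN##_##W##_##I##_##R##_##O);

REGISTER_BWD_LAUNCHER_SKETCH(1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4)
```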
ln_bwd_1280.cu
ADDED
@@ -0,0 +1,15 @@
+#include "ln_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_BWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
ln_bwd_1536.cu
ADDED
@@ -0,0 +1,15 @@
+#include "ln_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
+REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
+REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
+REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
+REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
+REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
+REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
+REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
ln_bwd_2048.cu
ADDED
@@ -0,0 +1,15 @@
+#include "ln_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_BWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
ln_bwd_256.cu
ADDED
@@ -0,0 +1,15 @@
+#include "ln_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_BWD_LAUNCHER(  256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
ln_bwd_2560.cu
ADDED
@@ -0,0 +1,15 @@
+#include "ln_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_BWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
+REGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
+REGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
+REGISTER_BWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
+REGISTER_BWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
+REGISTER_BWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
+REGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
+REGISTER_BWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
ln_bwd_3072.cu
ADDED
@@ -0,0 +1,15 @@
+#include "ln_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_BWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
ln_bwd_4096.cu
ADDED
@@ -0,0 +1,15 @@
+#include "ln_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_BWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
ln_bwd_512.cu
ADDED
@@ -0,0 +1,15 @@
+#include "ln_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_BWD_LAUNCHER(  512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
ln_bwd_5120.cu
ADDED
@@ -0,0 +1,15 @@
+#include "ln_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_BWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
+REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
ln_bwd_6144.cu
ADDED
@@ -0,0 +1,15 @@
+#include "ln_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
ln_bwd_7168.cu
ADDED
@@ -0,0 +1,15 @@
+#include "ln_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_BWD_LAUNCHER( 7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4);
+REGISTER_BWD_LAUNCHER( 7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4);
+REGISTER_BWD_LAUNCHER( 7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 8, 4);
+REGISTER_BWD_LAUNCHER( 7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4);
+REGISTER_BWD_LAUNCHER( 7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4);
+REGISTER_BWD_LAUNCHER( 7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 8, 4);
+REGISTER_BWD_LAUNCHER( 7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 8, 4);
+REGISTER_BWD_LAUNCHER( 7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 8, 4);
ln_bwd_768.cu
ADDED
@@ -0,0 +1,15 @@
+#include "ln_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_BWD_LAUNCHER(  768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
+REGISTER_BWD_LAUNCHER(  768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
ln_bwd_8192.cu
ADDED
@@ -0,0 +1,15 @@
+#include "ln_bwd_kernels.cuh"
+
+// Create backward launch function and register. Macro signature:
+// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+
+REGISTER_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 8192, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 8192, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 8192, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 8192, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 8192, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
+REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
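Reading the tiling columns across these files: sizes up to 1280 keep one warp per row slice and stack four rows per CTA (WARPS_M=4, WARPS_N=1); from 1536 on, four warps cooperate on a single row, rising to eight at 6144 and above; and the 16-bit configurations at 1536, 2560, and 7168 drop from 16-byte to 8-byte loads to keep the per-warp load count aligned. A small sanity check of the implied CTA sizes (the 32-thread warp is the usual CUDA assumption; the traits header is not shown in this part of the diff):

```cpp
// Threads per CTA implied by a (WARPS_M, WARPS_N) pair.
constexpr int threads_per_cta(int warps_m, int warps_n) {
    return 32 * warps_m * warps_n;
}
static_assert(threads_per_cta(4, 1) == 128, "small-hidden config: 4 rows x 1 warp");
static_assert(threads_per_cta(1, 4) == 128, "mid-hidden config: 1 row x 4 warps");
static_assert(threads_per_cta(1, 8) == 256, "large-hidden config: 1 row x 8 warps");
```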
ln_bwd_kernels.cuh
ADDED
@@ -0,0 +1,534 @@
+#pragma once
+
+#include "ln.h"
+#include "ln_utils.cuh"
+#include "ln_kernel_traits.h"
+#include "static_switch.h"
+
+namespace layer_norm {
+
+template<typename Ktraits, bool Is_dropout, bool Has_colscale, bool Has_subset, bool Is_even_cols>
+__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
+void ln_bwd_kernel(layer_norm::BwdParams params) {
+
+    enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
+    enum { WARPS_M = Ktraits::WARPS_M };
+    enum { WARPS_N = Ktraits::WARPS_N };
+    enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
+    enum { COLS = Ktraits::COLS };
+    enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
+    enum { LDGS = Ktraits::LDGS };
+    enum { NUM_ELTS = Ktraits::ELTS_PER_LDG };
+    enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP };
+    enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
+
+    using input_t = typename Ktraits::input_t;
+    using compute_t = typename Ktraits::compute_t;
+    using index_t = typename Ktraits::index_t;
+    using mask_t = typename Ktraits::mask_t;
+    using Ivec = typename Ktraits::Ivec;
+    using Rvec = typename Ktraits::Rvec;
+    using Ovec = typename Ktraits::Ovec;
+    using Wvec = typename Ktraits::Wvec;
+    using Cvec = typename Ktraits::Cvec;
+    using Mvec = typename Ktraits::Mvec;
+    using Reducer = typename Ktraits::Reducer;
+    using reduce_t = typename Reducer::Type;
+
+    extern __shared__ char smem_[];
+
+    const bool has_residual = params.dresidual != nullptr;
+    const bool prenorm = params.dx != nullptr;
+
| 43 |
+
const index_t tidx = threadIdx.x;
|
| 44 |
+
const index_t bidn = blockIdx.x % CTAS_PER_ROW;
|
| 45 |
+
const index_t bidm = blockIdx.x / CTAS_PER_ROW;
|
| 46 |
+
const index_t lane = tidx % THREADS_PER_WARP;
|
| 47 |
+
const index_t warp = tidx / THREADS_PER_WARP;
|
| 48 |
+
const index_t warp_m = warp / Ktraits::WARPS_N;
|
| 49 |
+
const index_t warp_n = warp % Ktraits::WARPS_N;
|
| 50 |
+
const index_t tid_r = warp_n * THREADS_PER_WARP + lane;
|
| 51 |
+
|
| 52 |
+
const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m;
|
| 53 |
+
const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;
|
| 54 |
+
|
| 55 |
+
static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW);
|
| 56 |
+
|
| 57 |
+
const input_t *rowscale = static_cast<input_t *>(params.rowscale);
|
| 58 |
+
const index_t *x0_subset = static_cast<index_t *>(params.x0_subset);
|
| 59 |
+
const index_t *z_subset = static_cast<index_t *>(params.z_subset);
|
| 60 |
+
|
| 61 |
+
Cvec dzy_sum[LDGS];
|
| 62 |
+
Cvec dz_sum[LDGS];
|
| 63 |
+
Cvec dcolscale_sum[LDGS];
|
| 64 |
+
|
| 65 |
+
memset(dzy_sum, 0, sizeof(dzy_sum));
|
| 66 |
+
memset(dz_sum, 0, sizeof(dz_sum));
|
| 67 |
+
if (Has_colscale) { memset(dcolscale_sum, 0, sizeof(dcolscale_sum)); }
|
| 68 |
+
|
| 69 |
+
compute_t * smem_wgrad = reinterpret_cast<compute_t*>(smem_);
|
| 70 |
+
char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD;
|
| 71 |
+
|
| 72 |
+
Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad);
|
| 73 |
+
|
| 74 |
+
Sum<reduce_t> sum;
|
| 75 |
+
|
| 76 |
+
const index_t num_valid_ldgs =
|
| 77 |
+
((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + Ktraits::VEC_COLS_PER_LDG) / Ktraits::VEC_COLS_PER_LDG;
|
| 78 |
+
|
| 79 |
+
Wvec gamma[LDGS];
|
| 80 |
+
Wvec colscale[LDGS];
|
| 81 |
+
index_t idx = c;
|
| 82 |
+
#pragma unroll
|
| 83 |
+
for( int it = 0; it < LDGS; it++ ) {
|
| 84 |
+
if (Is_even_cols || (it < num_valid_ldgs)) {
|
| 85 |
+
gamma[it].load_from(params.gamma, idx);
|
| 86 |
+
if (Has_colscale) { colscale[it].load_from(params.colscale, idx); }
|
| 87 |
+
idx += Ktraits::VEC_COLS_PER_LDG;
|
| 88 |
+
}
|
| 89 |
+
}
|
| 90 |
+
// TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the
|
| 91 |
+
// last blocks with syncthreads!
|
| 92 |
+
// grid stride over rows
|
| 93 |
+
#pragma unroll 1
|
| 94 |
+
for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
|
| 95 |
+
const compute_t mu_r = static_cast<const compute_t *>(params.mu)[row];
|
| 96 |
+
const compute_t rs_r = static_cast<const compute_t *>(params.rs)[row];
|
| 97 |
+
const compute_t rowscale_val = !Has_subset ? (params.rowscale == nullptr ? 1.0f : compute_t(rowscale[row])) : params.rowscale_const;
|
| 98 |
+
const int row_z = !Has_subset ? row + 1 : z_subset[row];
|
| 99 |
+
const int row_x0 = !Has_subset ? row + 1 : x0_subset[row];
|
| 100 |
+
const bool load_dz = !Has_subset || row_z > 0;
|
| 101 |
+
const bool save_dx0 = !Has_subset || row_x0 > 0;
|
| 102 |
+
Mvec dmask[LDGS];
|
| 103 |
+
Rvec dx[LDGS];
|
| 104 |
+
compute_t dy[LDGS * NUM_ELTS];
|
| 105 |
+
compute_t y[LDGS * NUM_ELTS];
|
| 106 |
+
compute_t mdy_local = 0.f;
|
| 107 |
+
compute_t mdyy_local = 0.f;
|
| 108 |
+
// If dz is not loaded, then dy should be 0 and we don't care about the value of y.
|
| 109 |
+
if (load_dz) {
|
| 110 |
+
index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c;
|
| 111 |
+
index_t idx_z = !Has_subset ? idx_x : (load_dz ? (row_z - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
|
| 112 |
+
index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
|
| 113 |
+
#pragma unroll
|
| 114 |
+
for( int it = 0; it < LDGS; it++ ) {
|
| 115 |
+
if (Is_even_cols || (it < num_valid_ldgs)) {
|
| 116 |
+
Rvec x;
|
| 117 |
+
Ovec dz;
|
| 118 |
+
dz.load_from(params.dz, !Has_subset ? idx_x : idx_z);
|
| 119 |
+
if (prenorm) { dx[it].load_from(params.dx, idx_x); }
|
| 120 |
+
x.load_from(params.x, idx_x);
|
| 121 |
+
if (Is_dropout) { dmask[it].load_from(params.dmask, !Has_subset ? idx_x : idx_x0); }
|
| 122 |
+
idx_x += Ktraits::VEC_COLS_PER_LDG;
|
| 123 |
+
idx_z += Ktraits::VEC_COLS_PER_LDG;
|
| 124 |
+
idx_x0 += Ktraits::VEC_COLS_PER_LDG;
|
| 125 |
+
#pragma unroll
|
| 126 |
+
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
|
| 127 |
+
compute_t x_tmp = x.data.elt[jt];
|
| 128 |
+
compute_t y_tmp = rs_r * (x_tmp - (!params.is_rms_norm ? mu_r : 0.f));
|
| 129 |
+
compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]) * compute_t(dz.data.elt[jt]);
|
| 130 |
+
compute_t dz_tmp = dz.data.elt[jt];
|
| 131 |
+
|
| 132 |
+
mdy_local += dy_tmp;
|
| 133 |
+
mdyy_local += dy_tmp * y_tmp;
|
| 134 |
+
|
| 135 |
+
dy[it * NUM_ELTS + jt] = dy_tmp;
|
| 136 |
+
y[it * NUM_ELTS + jt] = y_tmp;
|
| 137 |
+
|
| 138 |
+
dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp;
|
| 139 |
+
dz_sum[it].data.elt[jt] += dz_tmp;
|
| 140 |
+
}
|
| 141 |
+
}
|
| 142 |
+
}
|
| 143 |
+
} else {
|
| 144 |
+
index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c;
|
| 145 |
+
index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
|
| 146 |
+
#pragma unroll
|
| 147 |
+
for( int it = 0; it < LDGS; it++ ) {
|
| 148 |
+
if (Is_even_cols || (it < num_valid_ldgs)) {
|
| 149 |
+
if (prenorm) { dx[it].load_from(params.dx, idx_x); }
|
| 150 |
+
if (Is_dropout) { dmask[it].load_from(params.dmask, !Has_subset ? idx_x : idx_x0); }
|
| 151 |
+
idx_x += Ktraits::VEC_COLS_PER_LDG;
|
| 152 |
+
idx_x0 += Ktraits::VEC_COLS_PER_LDG;
|
| 153 |
+
}
|
| 154 |
+
}
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum);
|
| 158 |
+
mdy_local = layer_norm::Get<0>::of<reduce_t, compute_t>(result) * params.inverse_cols;
|
| 159 |
+
mdyy_local = layer_norm::Get<1>::of<reduce_t, compute_t>(result) * params.inverse_cols;
|
| 160 |
+
|
| 161 |
+
index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c;
|
| 162 |
+
index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
|
| 163 |
+
#pragma unroll
|
| 164 |
+
for( int it = 0; it < LDGS; it++ ) {
|
| 165 |
+
if (Is_even_cols || (it < num_valid_ldgs)) {
|
| 166 |
+
Ivec dx0;
|
| 167 |
+
Rvec dresidual;
|
| 168 |
+
Ivec x0;
|
| 169 |
+
if (Has_colscale && save_dx0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); }
|
| 170 |
+
#pragma unroll
|
| 171 |
+
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
|
| 172 |
+
compute_t dx_tmp_res;
|
| 173 |
+
if (load_dz) {
|
| 174 |
+
compute_t dy_tmp = dy[it * NUM_ELTS + jt];
|
| 175 |
+
compute_t y_tmp = y[it * NUM_ELTS + jt];
|
| 176 |
+
compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + (!params.is_rms_norm ? mdy_local : 0.f)));
|
| 177 |
+
dx_tmp_res = prenorm ? dx_tmp + compute_t(dx[it].data.elt[jt]) : dx_tmp;
|
| 178 |
+
} else {
|
| 179 |
+
dx_tmp_res = prenorm ? compute_t(dx[it].data.elt[jt]) : 0.f;
|
| 180 |
+
}
|
| 181 |
+
if (has_residual) { dresidual.data.elt[jt] = dx_tmp_res; }
|
| 182 |
+
if (save_dx0) {
|
| 183 |
+
compute_t dx0_tmp_res = dx_tmp_res * rowscale_val;
|
| 184 |
+
if (Is_dropout) {
|
| 185 |
+
dx0_tmp_res *= params.dropout_scale;
|
| 186 |
+
if (Has_colscale) {
|
| 187 |
+
dcolscale_sum[it].data.elt[jt] += dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(x0.data.elt[jt]) : 0.f;
|
| 188 |
+
dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(colscale[it].data.elt[jt]) : 0.f;
|
| 189 |
+
} else {
|
| 190 |
+
dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res : 0.f;
|
| 191 |
+
}
|
| 192 |
+
} else {
|
| 193 |
+
if (Has_colscale) {
|
| 194 |
+
dcolscale_sum[it].data.elt[jt] += dx0_tmp_res * compute_t(x0.data.elt[jt]);
|
| 195 |
+
dx0.data.elt[jt] = dx0_tmp_res * compute_t(colscale[it].data.elt[jt]);
|
| 196 |
+
} else {
|
| 197 |
+
dx0.data.elt[jt] = dx0_tmp_res;
|
| 198 |
+
}
|
| 199 |
+
}
|
| 200 |
+
}
|
| 201 |
+
}
|
| 202 |
+
if (has_residual) { dresidual.store_to(params.dresidual, idx_x); }
|
| 203 |
+
if (save_dx0) { dx0.store_to(params.dx0, !Has_subset ? idx_x : idx_x0); }
|
| 204 |
+
idx_x += Ktraits::VEC_COLS_PER_LDG;
|
| 205 |
+
idx_x0 += Ktraits::VEC_COLS_PER_LDG;
|
| 206 |
+
}
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
} // end: grid stride loop
|
| 210 |
+
|
| 211 |
+
if( WARPS_M == 1 ) {
|
| 212 |
+
idx = r * params.cols / Ktraits::ELTS_PER_LDG + c;
|
| 213 |
+
#pragma unroll
|
| 214 |
+
for( int it = 0; it < LDGS; it++ ) {
|
| 215 |
+
if (Is_even_cols || (it < num_valid_ldgs)) {
|
| 216 |
+
dz_sum[it].store_to(params.dbeta_part, idx);
|
| 217 |
+
dzy_sum[it].store_to(params.dgamma_part, idx);
|
| 218 |
+
if (Has_colscale) { dcolscale_sum[it].store_to(params.dcolscale_part, idx); }
|
| 219 |
+
idx += Ktraits::VEC_COLS_PER_LDG;
|
| 220 |
+
}
|
| 221 |
+
}
|
| 222 |
+
} else {
|
| 223 |
+
static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, "Multiple rows per CTA not supported for Multi-CTA.");
|
| 224 |
+
// Finalize reduction of part dgamma and dbeta for this CTA
|
| 225 |
+
// by reducing over the rows held across the WARPS_M warps
|
| 226 |
+
|
| 227 |
+
// Assumption: blockSize divides hidden size.
|
| 228 |
+
enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA };
|
| 229 |
+
static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, "");
|
| 230 |
+
|
| 231 |
+
idx = warp_m * Ktraits::VEC_COLS + tid_r;
|
| 232 |
+
#pragma unroll
|
| 233 |
+
for( int it = 0; it < LDGS; it++ ) {
|
| 234 |
+
dz_sum[it].store_to(smem_wgrad, idx);
|
| 235 |
+
idx += THREADS_PER_ROW;
|
| 236 |
+
}
|
| 237 |
+
__syncthreads();
|
| 238 |
+
compute_t cta_dz_sum[NUM_RES];
|
| 239 |
+
memset(cta_dz_sum, 0, sizeof(compute_t) * NUM_RES);
|
| 240 |
+
for( int it = 0; it < ROWS_PER_CTA; it++ ) {
|
| 241 |
+
for( int jt = 0; jt < NUM_RES; jt++ ) {
|
| 242 |
+
cta_dz_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
|
| 243 |
+
}
|
| 244 |
+
}
|
| 245 |
+
__syncthreads();
|
| 246 |
+
|
| 247 |
+
idx = warp_m * Ktraits::VEC_COLS + tid_r;
|
| 248 |
+
#pragma unroll
|
| 249 |
+
for( int it = 0; it < LDGS; it++ ) {
|
| 250 |
+
dzy_sum[it].store_to(smem_wgrad, idx);
|
| 251 |
+
idx += THREADS_PER_ROW;
|
| 252 |
+
}
|
| 253 |
+
__syncthreads();
|
| 254 |
+
compute_t cta_dzy_sum[NUM_RES];
|
| 255 |
+
memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES);
|
| 256 |
+
for( int it = 0; it < ROWS_PER_CTA; it++ ) {
|
| 257 |
+
for( int jt = 0; jt < NUM_RES; jt++ ) {
|
| 258 |
+
cta_dzy_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
|
| 259 |
+
}
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
compute_t cta_dcolscale_sum[NUM_RES];
|
| 263 |
+
if (Has_colscale) {
|
| 264 |
+
__syncthreads();
|
| 265 |
+
idx = warp_m * Ktraits::VEC_COLS + tid_r;
|
| 266 |
+
#pragma unroll
|
| 267 |
+
for( int it = 0; it < LDGS; it++ ) {
|
| 268 |
+
dcolscale_sum[it].store_to(smem_wgrad, idx);
|
| 269 |
+
idx += THREADS_PER_ROW;
|
| 270 |
+
}
|
| 271 |
+
__syncthreads();
|
| 272 |
+
memset(cta_dcolscale_sum, 0, sizeof(compute_t) * NUM_RES);
|
| 273 |
+
for( int it = 0; it < ROWS_PER_CTA; it++ ) {
|
| 274 |
+
for( int jt = 0; jt < NUM_RES; jt++ ) {
|
| 275 |
+
cta_dcolscale_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
|
| 276 |
+
}
|
| 277 |
+
}
|
| 278 |
+
}
|
| 279 |
+
|
| 280 |
+
const index_t num_valid_writes
|
| 281 |
+
= (params.cols - 1 - tidx + Ktraits::THREADS_PER_CTA) / Ktraits::THREADS_PER_CTA;
|
| 282 |
+
compute_t *dgamma_part = static_cast<compute_t *>(params.dgamma_part) + bidm * params.cols + tidx;
|
| 283 |
+
compute_t *dbeta_part = static_cast<compute_t *>(params.dbeta_part) + bidm * params.cols + tidx;
|
| 284 |
+
compute_t *dcolscale_part = Has_colscale ? static_cast<compute_t *>(params.dcolscale_part) + bidm * params.cols + tidx : nullptr;
|
| 285 |
+
for( int jt = 0; jt < NUM_RES; jt++ ) {
|
| 286 |
+
if (Is_even_cols || (jt < num_valid_writes)) {
|
| 287 |
+
*dgamma_part = cta_dzy_sum[jt];
|
| 288 |
+
dgamma_part += Ktraits::THREADS_PER_CTA;
|
| 289 |
+
*dbeta_part = cta_dz_sum[jt];
|
| 290 |
+
dbeta_part += Ktraits::THREADS_PER_CTA;
|
| 291 |
+
if (Has_colscale) {
|
| 292 |
+
*dcolscale_part = cta_dcolscale_sum[jt];
|
| 293 |
+
dcolscale_part += Ktraits::THREADS_PER_CTA;
|
| 294 |
+
}
|
| 295 |
+
}
|
| 296 |
+
}
|
| 297 |
+
|
| 298 |
+
}
|
| 299 |
+
}
|
| 300 |
+
|
| 301 |
+
template<typename Kernel_traits, bool Has_colscale, bool Is_even_cols>
|
| 302 |
+
__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA)
|
| 303 |
+
void ln_bwd_finalize_kernel(BwdParams params)
|
| 304 |
+
{
|
| 305 |
+
|
| 306 |
+
using compute_t = typename Kernel_traits::compute_t;
|
| 307 |
+
using weight_t = typename Kernel_traits::weight_t;
|
| 308 |
+
using index_t = typename Kernel_traits::index_t;
|
| 309 |
+
using Reducer = typename Kernel_traits::Reducer;
|
| 310 |
+
using reduce_t = typename Reducer::Type;
|
| 311 |
+
|
| 312 |
+
Sum<reduce_t> sum;
|
| 313 |
+
enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG };
|
| 314 |
+
enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP };
|
| 315 |
+
|
| 316 |
+
__shared__ char smem_[Kernel_traits::SMEM_BYTES_PER_CTA];
|
| 317 |
+
|
| 318 |
+
constexpr uint32_t bidm = 0;
|
| 319 |
+
|
| 320 |
+
const uint32_t bidn = blockIdx.x;
|
| 321 |
+
const uint32_t tidx = threadIdx.x;
|
| 322 |
+
const uint32_t warp = tidx / THREADS_PER_WARP;
|
| 323 |
+
const uint32_t lane = tidx % THREADS_PER_WARP;
|
| 324 |
+
|
| 325 |
+
Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_);
|
| 326 |
+
|
| 327 |
+
const uint32_t c = bidn * THREADS_PER_WARP + lane;
|
| 328 |
+
const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane;
|
| 329 |
+
constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP;
|
| 330 |
+
for( uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) {
|
| 331 |
+
// Each thread sums over NUM_ELT columns.
|
| 332 |
+
Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local, dcolscale_local;
|
| 333 |
+
memset(&dgamma_local, 0, sizeof(dgamma_local));
|
| 334 |
+
memset(&dbeta_local, 0, sizeof(dbeta_local));
|
| 335 |
+
if (Has_colscale) { memset(&dcolscale_local, 0, sizeof(dcolscale_local)); }
|
| 336 |
+
if (Is_even_cols || col < params.cols) {
|
| 337 |
+
for( uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA ) {
|
| 338 |
+
index_t idx = row * params.cols + col;
|
| 339 |
+
|
| 340 |
+
Vec<compute_t, NUM_ELT> dbeta_part, dgamma_part, dcolscale_part;
|
| 341 |
+
dbeta_part.load_from(params.dbeta_part, idx);
|
| 342 |
+
dgamma_part.load_from(params.dgamma_part, idx);
|
| 343 |
+
if (Has_colscale) { dcolscale_part.load_from(params.dcolscale_part, idx); }
|
| 344 |
+
#pragma unroll
|
| 345 |
+
for( int it = 0; it < NUM_ELT; it++ ) {
|
| 346 |
+
dgamma_local.data.elt[it] += dgamma_part.data.elt[it];
|
| 347 |
+
dbeta_local.data.elt[it] += dbeta_part.data.elt[it];
|
| 348 |
+
if (Has_colscale) { dcolscale_local.data.elt[it] += dcolscale_part.data.elt[it]; }
|
| 349 |
+
}
|
| 350 |
+
}
|
| 351 |
+
}
|
| 352 |
+
void * smem_gamma = smem_;
|
| 353 |
+
void * smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE];
|
| 354 |
+
void * smem_colscale = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];
|
| 355 |
+
|
| 356 |
+
const int write_row = warp;
|
| 357 |
+
const int write_col = lane ^ write_row;
|
| 358 |
+
const int write_idx = write_row * THREADS_PER_WARP + write_col;
|
| 359 |
+
|
| 360 |
+
dgamma_local.store_to(smem_gamma, write_idx);
|
| 361 |
+
dbeta_local.store_to(smem_beta, write_idx);
|
| 362 |
+
if (Has_colscale) { dcolscale_local.store_to(smem_colscale, write_idx); }
|
| 363 |
+
|
| 364 |
+
__syncthreads();
|
| 365 |
+
|
| 366 |
+
// It would be probably safe to reuse the first row of smem_beta and smem_gamma
|
| 367 |
+
void * smem_gamma_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE];
|
| 368 |
+
void * smem_beta_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT];
|
| 369 |
+
void * smem_colscale_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE + 2 * Kernel_traits::SMEM_BYTES_OUTPUT];
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
// More than one iter iff ROWS_PER_CTA < 32.
|
| 373 |
+
for( int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA ) {
|
| 374 |
+
const int read_row = lane;
|
| 375 |
+
const int read_col = w ^ read_row;
|
| 376 |
+
const int read_idx = read_row * THREADS_PER_WARP + read_col;
|
| 377 |
+
|
| 378 |
+
memset(&dbeta_local, 0, sizeof(dbeta_local));
|
| 379 |
+
memset(&dgamma_local, 0, sizeof(dgamma_local));
|
| 380 |
+
if (Has_colscale) { memset(&dcolscale_local, 0, sizeof(dcolscale_local)); }
|
| 381 |
+
|
| 382 |
+
// Load beta and gamma transposed
|
| 383 |
+
if(read_row < Kernel_traits::ROWS_PER_CTA){
|
| 384 |
+
dbeta_local.load_from(smem_beta, read_idx);
|
| 385 |
+
dgamma_local.load_from(smem_gamma, read_idx);
|
| 386 |
+
if (Has_colscale) { dcolscale_local.load_from(smem_colscale, read_idx); }
|
| 387 |
+
}
|
| 388 |
+
|
| 389 |
+
// Call reducer on the loaded value(s) and convert.
|
| 390 |
+
#pragma unroll
|
| 391 |
+
for( int it = 0; it < NUM_ELT; it++ ) {
|
| 392 |
+
compute_t b_i = dbeta_local.data.elt[it];
|
| 393 |
+
compute_t g_i = dgamma_local.data.elt[it];
|
| 394 |
+
b_i = reducer.allreduce(b_i, sum);
|
| 395 |
+
g_i = reducer.allreduce(g_i, sum);
|
| 396 |
+
|
| 397 |
+
dgamma_local.data.elt[it] = g_i;
|
| 398 |
+
dbeta_local.data.elt[it] = b_i;
|
| 399 |
+
if (Has_colscale) {
|
| 400 |
+
compute_t cs_i = dcolscale_local.data.elt[it];
|
| 401 |
+
cs_i = reducer.allreduce(cs_i, sum);
|
| 402 |
+
dcolscale_local.data.elt[it] = cs_i;
|
| 403 |
+
}
|
| 404 |
+
}
|
| 405 |
+
|
| 406 |
+
// Leader stores the result at the current column.
|
| 407 |
+
if(lane == 0){
|
| 408 |
+
dgamma_local.store_to(smem_gamma_out, w);
|
| 409 |
+
dbeta_local.store_to(smem_beta_out, w);
|
| 410 |
+
if (Has_colscale) { dcolscale_local.store_to(smem_colscale_out, w); }
|
| 411 |
+
}
|
| 412 |
+
|
| 413 |
+
}
|
| 414 |
+
|
| 415 |
+
// All writes done.
|
| 416 |
+
__syncthreads();
|
| 417 |
+
|
| 418 |
+
// Pack and store: 2-wide stores with half the threads.
|
| 419 |
+
if (Is_even_cols || col_out * 2 < params.cols) {
|
| 420 |
+
if( warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2 ) {
|
| 421 |
+
|
| 422 |
+
using src_t = typename TypeToVec2<compute_t>::Type;
|
| 423 |
+
using dst_t = typename TypeToVec2<weight_t>::Type;
|
| 424 |
+
Vec<src_t, NUM_ELT> dbeta_vec2, dgamma_vec2, dcolscale_vec2;
|
| 425 |
+
Vec<dst_t, NUM_ELT> dbeta_out2, dgamma_out2, dcolscale_out2;
|
| 426 |
+
|
| 427 |
+
dgamma_vec2.load_from(smem_gamma_out, lane);
|
| 428 |
+
dbeta_vec2.load_from(smem_beta_out, lane);
|
| 429 |
+
if (Has_colscale) { dcolscale_vec2.load_from(smem_colscale_out, lane); }
|
| 430 |
+
#pragma unroll
|
| 431 |
+
for( int it = 0; it < NUM_ELT; it++ ) {
|
| 432 |
+
dgamma_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dgamma_vec2.data.elt[it]);
|
| 433 |
+
dbeta_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dbeta_vec2.data.elt[it]);
|
| 434 |
+
if (Has_colscale) { dcolscale_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dcolscale_vec2.data.elt[it]); }
|
| 435 |
+
}
|
| 436 |
+
dgamma_out2.store_to(params.dgamma, col_out);
|
| 437 |
+
dbeta_out2.store_to(params.dbeta, col_out);
|
| 438 |
+
if (Has_colscale) { dcolscale_out2.store_to(params.dcolscale, col_out); }
|
| 439 |
+
}
|
| 440 |
+
}
|
| 441 |
+
}
|
| 442 |
+
}
|
| 443 |
+
} // namespace layer_norm
|
| 444 |
+
|
| 445 |
+
using namespace layer_norm;
|
| 446 |
+
|
| 447 |
+
template<
|
| 448 |
+
typename weight_t,
|
| 449 |
+
typename input_t,
|
| 450 |
+
typename residual_t,
|
| 451 |
+
typename output_t,
|
| 452 |
+
typename compute_t,
|
| 453 |
+
typename index_t,
|
| 454 |
+
int HIDDEN_SIZE,
|
| 455 |
+
int CTAS_PER_ROW,
|
| 456 |
+
int WARPS_M,
|
| 457 |
+
int WARPS_N,
|
| 458 |
+
int BYTES_PER_LDG_MAIN,
|
| 459 |
+
int BYTES_PER_LDG_FINAL
|
| 460 |
+
>
|
| 461 |
+
void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params){
|
| 462 |
+
|
| 463 |
+
using Kernel_traits = Kernel_traits<weight_t,
|
| 464 |
+
input_t,
|
| 465 |
+
residual_t,
|
| 466 |
+
output_t,
|
| 467 |
+
compute_t,
|
| 468 |
+
index_t,
|
| 469 |
+
HIDDEN_SIZE,
|
| 470 |
+
CTAS_PER_ROW,
|
| 471 |
+
WARPS_M,
|
| 472 |
+
WARPS_N,
|
| 473 |
+
BYTES_PER_LDG_MAIN
|
| 474 |
+
>;
|
| 475 |
+
bool is_dropout = launch_params.params.dropout_keep_p < 1.f;
|
| 476 |
+
bool has_colscale = launch_params.params.colscale != nullptr;
|
| 477 |
+
bool has_subset = launch_params.params.x0_subset != nullptr;
|
| 478 |
+
bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
|
| 479 |
+
BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
|
| 480 |
+
BOOL_SWITCH(has_colscale, HasColscaleConst, [&] {
|
| 481 |
+
BOOL_SWITCH(has_subset, HasSubsetConst, [&] {
|
| 482 |
+
BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
|
| 483 |
+
auto kernel = &ln_bwd_kernel<Kernel_traits, IsDropoutConst, HasColscaleConst, HasSubsetConst, IsEvenColsConst>;
|
| 484 |
+
if( configure_params ) {
|
| 485 |
+
int ctas_per_sm;
|
| 486 |
+
CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
| 487 |
+
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES));
|
| 488 |
+
launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
|
| 489 |
+
launch_params.barrier_size = 0;
|
| 490 |
+
launch_params.workspace_bytes = 0;
|
| 491 |
+
if(Kernel_traits::CTAS_PER_ROW > 1) {
|
| 492 |
+
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
|
| 493 |
+
launch_params.workspace_bytes = launch_params.params.ctas_per_col
|
| 494 |
+
* Kernel_traits::WARPS_M
|
| 495 |
+
* Kernel_traits::CTAS_PER_ROW
|
| 496 |
+
* sizeof(typename Kernel_traits::reduce_t)
|
| 497 |
+
* 2;
|
| 498 |
+
}
|
| 499 |
+
return;
|
| 500 |
+
}
|
| 501 |
+
|
| 502 |
+
if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) {
|
| 503 |
+
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
|
| 504 |
+
}
|
| 505 |
+
auto stream = launch_params.stream;
|
| 506 |
+
auto ctas_per_col = launch_params.params.ctas_per_col;
|
| 507 |
+
|
| 508 |
+
if( Kernel_traits::CTAS_PER_ROW == 1 ) {
|
| 509 |
+
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(launch_params.params);
|
| 510 |
+
} else {
|
| 511 |
+
dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
|
| 512 |
+
dim3 block(Kernel_traits::THREADS_PER_CTA);
|
| 513 |
+
void *params_ = (void *)&launch_params.params;
|
| 514 |
+
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES, stream);
|
| 515 |
+
}
|
| 516 |
+
|
| 517 |
+
using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
|
| 518 |
+
weight_t,
|
| 519 |
+
input_t,
|
| 520 |
+
residual_t,
|
| 521 |
+
output_t,
|
| 522 |
+
compute_t,
|
| 523 |
+
index_t,
|
| 524 |
+
HasColscaleConst,
|
| 525 |
+
32 * 32, // THREADS_PER_CTA
|
| 526 |
+
BYTES_PER_LDG_FINAL>;
|
| 527 |
+
|
| 528 |
+
auto kernel_f = &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f, HasColscaleConst, IsEvenColsConst>;
|
| 529 |
+
kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);
|
| 530 |
+
});
|
| 531 |
+
});
|
| 532 |
+
});
|
| 533 |
+
});
|
| 534 |
+
}
|
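
For orientation, this is the math `ln_bwd_kernel` implements, read directly off the code above. Per row, with $H$ = `params.cols`, $y = r_s (x - \mu)$, and $dy = \gamma \odot dz$ (for RMSNorm, the $\mu$ term and the mean-of-$dy$ term drop out):

```latex
dx_j = r_s \left( dy_j \;-\; \frac{1}{H}\sum_k dy_k \;-\; y_j \cdot \frac{1}{H}\sum_k dy_k\, y_k \right),
\qquad
d\gamma = \sum_{\text{rows}} dz \odot y,
\qquad
d\beta = \sum_{\text{rows}} dz .
```

Each CTA accumulates the $d\gamma$ / $d\beta$ (and optionally $d\text{colscale}$) sums into `dgamma_part` / `dbeta_part`; `ln_bwd_finalize_kernel` then reduces those partials across row blocks and converts them to `weight_t`.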
ln_fwd_1024.cu
ADDED
@@ -0,0 +1,15 @@
#include "ln_fwd_kernels.cuh"

// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG

REGISTER_FWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
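
As a reading aid: `REGISTER_FWD_LAUNCHER` is defined in `ln.h`, which this view does not show, so the sketch below is a hypothetical reconstruction of the pattern, not the actual macro. The idea is that each per-size `.cu` file instantiates `launch_` for one hidden size and one (WTYPE, ITYPE, RTYPE, OTYPE, CTYPE) combination and registers the resulting function pointer for runtime dispatch.

```cpp
// Hypothetical sketch of a launcher registry (names are illustrative, not from ln.h).
#include <cstdint>
#include <functional>
#include <map>
#include <string>
#include <utility>

struct FwdParamsSketch {};                       // stand-in for layer_norm::FwdParams
struct LaunchParamsSketch { FwdParamsSketch params; };
using FwdLauncher = std::function<void(LaunchParamsSketch&, bool /*configure_params*/)>;

// Keyed on hidden size plus a dtype signature such as "fp16_fp16_fp32_fp16_fp32".
using FwdKey = std::pair<uint32_t, std::string>;

std::map<FwdKey, FwdLauncher>& fwd_registry() {
    static std::map<FwdKey, FwdLauncher> registry;
    return registry;
}

// A static object whose constructor runs at library load time performs the
// registration; this is what a REGISTER_*_LAUNCHER macro expansion would
// typically create, one per line in the files above.
struct FwdRegistrar {
    FwdRegistrar(uint32_t hidden_size, std::string dtypes, FwdLauncher f) {
        fwd_registry()[{hidden_size, std::move(dtypes)}] = std::move(f);
    }
};
```

Under this assumption, the host API looks up the launcher for the requested hidden size and dtype signature, calls it once with `configure_params = true` to size barriers and workspace, then calls it again to launch the kernel.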
ln_fwd_1280.cu
ADDED
@@ -0,0 +1,15 @@
#include "ln_fwd_kernels.cuh"

// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG

REGISTER_FWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
ln_fwd_1536.cu
ADDED
@@ -0,0 +1,15 @@
#include "ln_fwd_kernels.cuh"

// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG

REGISTER_FWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
ln_fwd_2048.cu
ADDED
@@ -0,0 +1,15 @@
#include "ln_fwd_kernels.cuh"

// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG

REGISTER_FWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
ln_fwd_256.cu
ADDED
@@ -0,0 +1,15 @@
#include "ln_fwd_kernels.cuh"

// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG

REGISTER_FWD_LAUNCHER( 256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
ln_fwd_2560.cu
ADDED
@@ -0,0 +1,15 @@
#include "ln_fwd_kernels.cuh"

// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG

REGISTER_FWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
ln_fwd_3072.cu
ADDED
@@ -0,0 +1,15 @@
#include "ln_fwd_kernels.cuh"

// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG

REGISTER_FWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
ln_fwd_4096.cu
ADDED
@@ -0,0 +1,15 @@
#include "ln_fwd_kernels.cuh"

// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG

REGISTER_FWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
ln_fwd_512.cu
ADDED
@@ -0,0 +1,15 @@
#include "ln_fwd_kernels.cuh"

// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG

REGISTER_FWD_LAUNCHER( 512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
ln_fwd_5120.cu
ADDED
@@ -0,0 +1,15 @@
#include "ln_fwd_kernels.cuh"

// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG

REGISTER_FWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
ln_fwd_6144.cu
ADDED
@@ -0,0 +1,15 @@
#include "ln_fwd_kernels.cuh"

// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG

REGISTER_FWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16);
ln_fwd_7168.cu
ADDED
@@ -0,0 +1,15 @@
#include "ln_fwd_kernels.cuh"

// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG

REGISTER_FWD_LAUNCHER( 7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
ln_fwd_768.cu
ADDED
@@ -0,0 +1,15 @@
#include "ln_fwd_kernels.cuh"

// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG

REGISTER_FWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
ln_fwd_8192.cu
ADDED
@@ -0,0 +1,15 @@
#include "ln_fwd_kernels.cuh"

// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG

REGISTER_FWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 8192, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 8192, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 8192, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 8192, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 8192, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 8192, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 8192, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 8192, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 8192, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16);
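
All of the per-size files above register the same forward kernel (defined in `ln_fwd_kernels.cuh`, next), which fuses inverted dropout into the load of `x0`: each element is kept with probability `dropout_keep_p` and the kept elements are multiplied by `params.dropout_scale`. Assuming the standard inverted-dropout convention, where `dropout_scale = 1 / dropout_keep_p` (the value is set on the host side and not shown in this diff), the expectation of the dropped-out activation matches the undropped one:

```latex
\mathbb{E}\!\left[ m \cdot x \cdot \tfrac{1}{p_{\text{keep}}} \right]
= p_{\text{keep}} \cdot x \cdot \tfrac{1}{p_{\text{keep}}} = x,
\qquad m \sim \mathrm{Bernoulli}(p_{\text{keep}}).
```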
ln_fwd_kernels.cuh
ADDED
@@ -0,0 +1,272 @@
#pragma once

#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif

#include <ATen/cuda/detail/UnpackRaw.cuh>  // For at::cuda::philox::unpack
#include <curand_kernel.h>

#include "ln.h"
#include "ln_utils.cuh"
#include "ln_kernel_traits.h"
#include "static_switch.h"

namespace layer_norm {

template<typename Ktraits, bool Is_dropout, bool Has_colscale, bool Has_subset, bool Is_even_cols>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
void ln_fwd_kernel(FwdParams params) {

    enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
    enum { WARPS_N = Ktraits::WARPS_N };
    enum { WARPS_M = Ktraits::WARPS_M };
    enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
    enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG };
    enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
    enum { LDGS = Ktraits::LDGS };
    enum { NUM_ELTS = Ktraits::NUM_ELTS };
    enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };

    using input_t = typename Ktraits::input_t;
    using residual_t = typename Ktraits::residual_t;
    using output_t = typename Ktraits::output_t;
    using index_t = typename Ktraits::index_t;
    using compute_t = typename Ktraits::compute_t;
    using mask_t = typename Ktraits::mask_t;
    using Ivec = typename Ktraits::Ivec;
    using Rvec = typename Ktraits::Rvec;
    using Ovec = typename Ktraits::Ovec;
    using Wvec = typename Ktraits::Wvec;
    using Cvec = typename Ktraits::Cvec;
    using Mvec = typename Ktraits::Mvec;

    using Stats = typename Ktraits::Stats;
    using stats_t = typename Stats::stats_t;

    const bool has_residual = params.residual != nullptr;
    const bool save_x = has_residual || Is_dropout || Has_colscale || (params.rowscale != nullptr) || Has_subset || !(std::is_same<input_t, residual_t>::value);

    extern __shared__ char smem_[];

    const index_t tidx = threadIdx.x;
    const index_t bidn = blockIdx.x % CTAS_PER_ROW;
    const index_t bidm = blockIdx.x / CTAS_PER_ROW;
    const index_t lane = tidx % THREADS_PER_WARP;
    const index_t warp = tidx / THREADS_PER_WARP;
    const index_t warp_m = warp / WARPS_N;
    const index_t warp_n = warp % WARPS_N;

    const index_t r = bidm * ROWS_PER_CTA + warp_m;
    const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;

    Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_);

    compute_t *mu_ptr = static_cast<compute_t *>(params.mu);
    compute_t *rs_ptr = static_cast<compute_t *>(params.rs);

    const input_t *rowscale = static_cast<input_t *>(params.rowscale);
    const index_t *x0_subset = static_cast<index_t *>(params.x0_subset);
    const index_t *z_subset = static_cast<index_t *>(params.z_subset);

    // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Dropout.cu
    curandStatePhilox4_32_10_t state;
    if (Is_dropout) {
        auto seeds = at::cuda::philox::unpack(params.philox_args);
        const index_t tidx_global = blockIdx.x * blockDim.x + threadIdx.x;
        curand_init(std::get<0>(seeds), tidx_global, std::get<1>(seeds), &state);
    }

    const index_t num_valid_ldgs = ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + VEC_COLS_PER_LDG) / VEC_COLS_PER_LDG;

    Wvec gamma[LDGS];
    Wvec beta[LDGS];
    Wvec colscale[LDGS];
    index_t idx = c;
    #pragma unroll
    for( int it = 0; it < LDGS; it++ ) {
        if (Is_even_cols || (it < num_valid_ldgs)) {
            gamma[it].load_from(params.gamma, idx);
            if (params.beta != nullptr) {
                beta[it].load_from(params.beta, idx);
            } else {
                beta[it].zero_();
            }
            if (Has_colscale) { colscale[it].load_from(params.colscale, idx); }
            idx += VEC_COLS_PER_LDG;
        }
    }

    for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
        const compute_t rowscale_val = !Has_subset ? (params.rowscale == nullptr ? 1.0f : compute_t(rowscale[row])) : params.rowscale_const;
        const int row_x0 = !Has_subset ? row + 1 : x0_subset[row];
        const int row_z = !Has_subset ? row + 1 : z_subset[row];
        const bool load_x0 = !Has_subset || row_x0 > 0;
        index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c;
        index_t idx_x0 = !Has_subset ? idx_x : (load_x0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
        compute_t xf[LDGS * NUM_ELTS];
        #pragma unroll
        for( int it = 0; it < LDGS; it++ ) {
            if (Is_even_cols || (it < num_valid_ldgs)) {
                Ivec x0;
                Rvec residual;
                Rvec x;
                Mvec dmask;
                if (load_x0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); }
                if (has_residual) { residual.load_from(params.residual, idx_x); }
                #pragma unroll
                for( int jt = 0; jt < NUM_ELTS; jt++ ) {
                    // TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use
                    // the more efficient curand_uniform4.
                    compute_t x_ij;
                    if (load_x0) {
                        mask_t keep = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p;
                        if (Is_dropout) { dmask.data.elt[jt] = keep; }
                        compute_t x0_ij = compute_t(x0.data.elt[jt]) * rowscale_val;
                        x0_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f;
                        if (Has_colscale) { x0_ij *= compute_t(colscale[it].data.elt[jt]); }
                        x_ij = has_residual ? x0_ij + compute_t(residual.data.elt[jt]) : x0_ij;
                    } else {
                        x_ij = has_residual ? compute_t(residual.data.elt[jt]) : 0.f;
                    }
                    if (save_x) { x.data.elt[jt] = x_ij; }
                    xf[it * NUM_ELTS + jt] = x_ij;
                }
                if (save_x) { x.store_to(params.x, idx_x); }
                if (Is_dropout && load_x0) { dmask.store_to(params.dmask, !Has_subset ? idx_x : idx_x0); }
                idx_x += VEC_COLS_PER_LDG;
                idx_x0 += VEC_COLS_PER_LDG;
            }
        }

        static_assert(CTAS_PER_ROW == 1, "Don't support multiple CTAs per row for now");
        const index_t num_vecs = params.cols / Ktraits::ELTS_PER_LDG;
        const index_t num_full_ldgs = num_vecs / Ktraits::VEC_COLS_PER_LDG;
        const index_t remaining_vecs = num_vecs % Ktraits::VEC_COLS_PER_LDG;
        auto valid_elts_in_warp_fn = [num_full_ldgs, remaining_vecs] (int warp_n) -> int {
            // Need to convert to int, otherwise the subtraction will wrap around.
            const index_t valid_partial_vecs_in_warp =
                std::min(std::max(int(remaining_vecs) - int(warp_n * THREADS_PER_WARP), int(0)),
                         int(THREADS_PER_WARP));
            return (num_full_ldgs * THREADS_PER_WARP + valid_partial_vecs_in_warp) * NUM_ELTS;
        };
        stats_t s = stats.template compute<Is_even_cols>(
            xf, params.inverse_cols, valid_elts_in_warp_fn, num_valid_ldgs * NUM_ELTS
        );

        compute_t mu = layer_norm::Get<0>::of<stats_t, compute_t>(s);
        compute_t m2 = layer_norm::Get<1>::of<stats_t, compute_t>(s);

        if( bidn == 0 && warp_n == 0 && lane == 0 ) {
            mu_ptr[row] = mu;
        }

        compute_t rs = rsqrtf(m2 * params.inverse_cols + params.epsilon + (!params.is_rms_norm ? 0.f : mu * mu));

        if( bidn == 0 && warp_n == 0 && lane == 0 ) {
            rs_ptr[row] = rs;
        }

        const bool save_z = !Has_subset || row_z > 0;
        if (save_z) {
            index_t idx_z = (!Has_subset ? row : (row_z - 1)) * params.cols / Ktraits::ELTS_PER_LDG + c;
            #pragma unroll
            for( int it = 0; it < LDGS; it++ ) {
                if (Is_even_cols || (it < num_valid_ldgs)) {
                    Ovec z;
                    #pragma unroll
                    for( int jt = 0; jt < NUM_ELTS; jt++ ) {
                        compute_t y_ij = compute_t(rs * (xf[it * NUM_ELTS + jt] - (!params.is_rms_norm ? mu : 0.f)));
                        compute_t g_ij = gamma[it].data.elt[jt];
                        compute_t b_ij = beta[it].data.elt[jt];
                        z.data.elt[jt] = output_t(g_ij * y_ij + b_ij);
                    }
                    z.store_to(params.z, idx_z);
                    idx_z += VEC_COLS_PER_LDG;
                }
            }
        }

    }
}

}  // namespace layer_norm

using namespace layer_norm;

template<
    typename weight_t,
    typename input_t,
    typename residual_t,
    typename output_t,
    typename compute_t,
    typename index_t,
    int HIDDEN_SIZE,
    int CTAS_PER_ROW,
    int WARPS_M,
    int WARPS_N,
    int BYTES_PER_LDG
>
void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params){

    using Kernel_traits = Kernel_traits<weight_t,
                                        input_t,
                                        residual_t,
                                        output_t,
                                        compute_t,
                                        index_t,
                                        HIDDEN_SIZE,
                                        CTAS_PER_ROW,
                                        WARPS_M,
                                        WARPS_N,
                                        BYTES_PER_LDG
                                        >;
    bool has_colscale = launch_params.params.colscale != nullptr;
    bool has_subset = launch_params.params.x0_subset != nullptr;
    bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
    BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] {
        BOOL_SWITCH(has_colscale, HasColscaleConst, [&] {
            BOOL_SWITCH(has_subset, HasSubsetConst, [&] {
                BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
                    auto kernel = &ln_fwd_kernel<Kernel_traits, IsDropoutConst, HasColscaleConst, HasSubsetConst, IsEvenColsConst>;
                    if( configure_params ) {
                        int ctas_per_sm;
                        CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                            &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));
                        launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
                        const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA;
                        launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS;
                        launch_params.barrier_size = 0;
                        launch_params.workspace_bytes = 0;
                        if(Kernel_traits::CTAS_PER_ROW > 1) {
                            launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
                            launch_params.workspace_bytes = launch_params.params.ctas_per_col
                                * Kernel_traits::WARPS_M
                                * Kernel_traits::CTAS_PER_ROW
                                * sizeof(typename Kernel_traits::Stats::stats_t)
                                * 2;
                        }
                        return;
                    }

                    if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) {
                        CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));
                    }
                    auto stream = launch_params.stream;
                    auto ctas_per_col = launch_params.params.ctas_per_col;

                    if( Kernel_traits::CTAS_PER_ROW == 1 ) {
                        kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(launch_params.params);
                    } else {
                        dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
                        dim3 block(Kernel_traits::THREADS_PER_CTA);
                        void *params_ = (void *)&launch_params.params;
                        cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES_FWD, stream);
                    }
                });
            });
        });
    });
}
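
Read off `ln_fwd_kernel` above, the fused forward computes, per row of hidden size $H$:

```latex
x = \mathrm{dropout}\!\left(x_0 \odot \text{rowscale}\right) \odot \text{colscale} + \text{residual},
\qquad
\mu = \frac{1}{H}\sum_j x_j,
\qquad
r_s = \frac{1}{\sqrt{m_2 / H + \varepsilon}},
\qquad
z = \gamma \odot (x - \mu)\, r_s + \beta .
```

For RMSNorm, $\mu$ is not subtracted and $r_s = (m_2/H + \varepsilon + \mu^2)^{-1/2}$; since $m_2/H + \mu^2 = \mathbb{E}[x^2]$, this normalizes by the root mean square of $x$ as expected.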
ln_kernel_traits.h
ADDED
@@ -0,0 +1,172 @@
#pragma once

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace layer_norm {
template<
    uint32_t HIDDEN_SIZE_,
    typename weight_t_,
    typename input_t_,
    typename residual_t_,
    typename output_t_,
    typename compute_t_,
    typename index_t_,
    uint32_t THREADS_PER_CTA_
>
struct Kernel_traits_base {

    using weight_t = weight_t_;
    using input_t = input_t_;
    using residual_t = residual_t_;
    using output_t = output_t_;
    using compute_t = compute_t_;
    using index_t = index_t_;

    enum { HIDDEN_SIZE = HIDDEN_SIZE_ };
    enum { THREADS_PER_CTA = THREADS_PER_CTA_ };
    enum { THREADS_PER_WARP = 32 };

};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<
    uint32_t HIDDEN_SIZE_,
    typename weight_t_,
    typename input_t_,
    typename residual_t_,
    typename output_t_,
    typename compute_t_,
    typename index_t_,
    bool Has_colscale,
    uint32_t THREADS_PER_CTA_,
    uint32_t BYTES_PER_LDG_,
    typename Base = Kernel_traits_base<HIDDEN_SIZE_,
                                       weight_t_,
                                       input_t_,
                                       residual_t_,
                                       output_t_,
                                       compute_t_,
                                       index_t_,
                                       THREADS_PER_CTA_>
>
struct Kernel_traits_finalize : public Base {
    enum { ROWS_PER_CTA = Base::THREADS_PER_CTA / Base::THREADS_PER_WARP };
    static_assert((int) ROWS_PER_CTA <= (int) Base::THREADS_PER_WARP);
    // Bytes per global load from the input.
    enum { BYTES_PER_LDG = BYTES_PER_LDG_ };
    // Number of elements fetched by a global load.
    enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(compute_t_) };
    // Bytes per global store of the weights.
    enum { BYTES_PER_STG = ELTS_PER_LDG * sizeof(weight_t_) };
    static_assert(sizeof(compute_t_) == 4, "Conflict-free smem transpose only implemented for 4B compute type!");
    static_assert(Base::THREADS_PER_CTA == ROWS_PER_CTA * Base::THREADS_PER_WARP, "We assume one warp per row!");
    // The total number of BYTES_PER_LDG-wide words in a hidden vector.
    enum { COLS = HIDDEN_SIZE_ * sizeof(compute_t_) / BYTES_PER_LDG };
    static_assert(COLS * BYTES_PER_LDG == HIDDEN_SIZE_ * sizeof(compute_t_));

    // Shared memory size to transpose the CTA result.
    enum { SMEM_BYTES_TRANSPOSE = Base::THREADS_PER_CTA * BYTES_PER_LDG };
    // Shared memory size to coalesce the CTA result.
    enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG };
    // Shared memory requirement per CTA.
    static constexpr int NUM_FACTORS = Has_colscale ? 3 : 2;
    enum { SMEM_BYTES_PER_CTA = NUM_FACTORS * SMEM_BYTES_TRANSPOSE + NUM_FACTORS * SMEM_BYTES_OUTPUT };

    // The type of the reducer.
    using Reducer = layer_norm::Reducer<compute_t_, 1, 1, 1>;

    // Condition for the whole CTA to participate in syncthreads.
    static_assert(COLS % Base::THREADS_PER_WARP == 0);
    enum { CTAS = COLS / Base::THREADS_PER_WARP };
};

////////////////////////////////////////////////////////////////////////////////////////////////////


template<
    typename weight_t_,
    typename input_t_,
    typename residual_t_,
    typename output_t_,
    typename compute_t_,
    typename index_t_,
    uint32_t HIDDEN_SIZE_,
    uint32_t CTAS_PER_ROW_,
    uint32_t WARPS_M_,
    uint32_t WARPS_N_,
    uint32_t BYTES_PER_LDG_ = 16,
    typename Base = Kernel_traits_base<
        HIDDEN_SIZE_,
        weight_t_,
        input_t_,
        residual_t_,
        output_t_,
        compute_t_,
        index_t_,
        WARPS_M_*WARPS_N_*THREADS_PER_WARP
    >
>
struct Kernel_traits : public Base {

    using input_t = typename Base::input_t;
    using residual_t = typename Base::residual_t;
    using weight_t = typename Base::weight_t;
    using compute_t = typename Base::compute_t;
    using output_t = typename Base::output_t;
    using index_t = typename Base::index_t;
    // using mask_t = unsigned char;
    using mask_t = bool;

    enum { CTAS_PER_ROW = CTAS_PER_ROW_ };
    enum { WARPS_M = WARPS_M_ };
    enum { WARPS_N = WARPS_N_ };
    enum { COLS = HIDDEN_SIZE_ };
    enum { HIDDEN_SIZE = HIDDEN_SIZE_ };
    enum { BYTES_PER_LDG = BYTES_PER_LDG_ };
    enum { NUM_ELTS = BYTES_PER_LDG / sizeof(input_t) };

    enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP };
    enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW };
    enum { ROWS_PER_CTA = WARPS_M };

    enum { BYTES_PER_ROW = COLS * sizeof(input_t) };
    enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG };
    // Multi-row per CTA not supported for multi-CTA => no smem for WGRAD needed
    enum { SMEM_BYTES_WGRAD = CTAS_PER_ROW > 1 ? 0 : ROWS_PER_CTA * COLS * sizeof(compute_t) };
    static_assert(WARPS_M == 1 || CTAS_PER_ROW == 1);

    using reduce_t = typename layer_norm::TypeToVec2<compute_t>::Type;
    using Reducer = layer_norm::Reducer<reduce_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;

    enum { SMEM_BYTES_DGRAD = Reducer::SMEM_BYTES };
    enum { SMEM_BYTES = SMEM_BYTES_DGRAD + SMEM_BYTES_WGRAD };

    using Ivec = layer_norm::Vec<input_t, NUM_ELTS>;
    using Rvec = layer_norm::Vec<residual_t, NUM_ELTS>;
    using Ovec = layer_norm::Vec<output_t, NUM_ELTS>;
    using Wvec = layer_norm::Vec<weight_t, NUM_ELTS>;
    using Cvec = layer_norm::Vec<compute_t, NUM_ELTS>;
    using Mvec = layer_norm::Vec<mask_t, NUM_ELTS>;
    enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(input_t) };

    // Assume that each thread can handle the same number of elements in the output and weights as in the input.
    static_assert(sizeof(input_t) == sizeof(output_t));
    static_assert(sizeof(input_t) <= sizeof(residual_t));
    // The number of columns fetched per load from input: one per thread.
    enum { VEC_COLS_PER_LDG = CTAS_PER_ROW * THREADS_PER_ROW };
    // The total number of vectorized loads/stores per hidden vector.
    enum { VEC_COLS = COLS / ELTS_PER_LDG };
    // The number of loads per thread for the input.
    enum { LDGS = VEC_COLS / VEC_COLS_PER_LDG };
    static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS);
    //static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, "");

    using Stats = layer_norm::Stats<compute_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;
    enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES };

};

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace layer_norm

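To make the trait arithmetic concrete, here is a worked instantiation for a 1024-wide fp16 row with one CTA per row, WARPS_M = 4, WARPS_N = 1, and 16-byte loads, the same tile shape the 1024-column launchers below use. This is a sketch: it assumes the extension's headers are already included, and uses CUDA's __half where the extension has its own fp16 typedef.

```cpp
#include <cstdint>
#include <cuda_fp16.h>

// Hypothetical instantiation; the constants checked below follow directly
// from the definitions in Kernel_traits above.
using Traits1024 = layer_norm::Kernel_traits<
    /*weight_t=*/__half, /*input_t=*/__half, /*residual_t=*/__half,
    /*output_t=*/__half, /*compute_t=*/float, /*index_t=*/uint32_t,
    /*HIDDEN_SIZE=*/1024, /*CTAS_PER_ROW=*/1,
    /*WARPS_M=*/4, /*WARPS_N=*/1, /*BYTES_PER_LDG=*/16>;

static_assert(Traits1024::NUM_ELTS == 8, "16 B load / 2 B per fp16 element");
static_assert(Traits1024::THREADS_PER_ROW == 32, "WARPS_N * 32");
static_assert(Traits1024::THREADS_PER_CTA == 128, "4 rows/CTA x 32 threads/row");
static_assert(Traits1024::VEC_COLS == 128, "1024 elements / 8 per load");
static_assert(Traits1024::LDGS == 4, "128 vector columns / 32 loading threads");
```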
ln_parallel_bwd_1024.cu
ADDED
@@ -0,0 +1,15 @@
#include "ln_parallel_residual_bwd_kernels.cuh"

// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL

REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);

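Each hidden size lives in its own translation unit like this one, so the heavy template instantiations can compile in parallel. REGISTER_PARALLEL_BWD_LAUNCHER is defined in ln_parallel_residual_bwd_kernels.cuh; the self-contained sketch below shows the generic static-registry pattern such registration macros typically expand to. Every name in it is a hypothetical stand-in, not one of the extension's actual symbols.

```cpp
#include <cstdint>
#include <functional>
#include <map>
#include <tuple>
#include <utility>

struct BwdLaunchParams;  // stand-in for the extension's launch-params struct
enum class DType { fp16, bf16, fp32 };

using LauncherFn = std::function<void(BwdLaunchParams &, bool configure)>;
// Key: hidden size plus the (weight, input, residual, output, compute) dtypes.
using LauncherKey = std::tuple<uint32_t, DType, DType, DType, DType, DType>;

inline std::map<LauncherKey, LauncherFn> &bwd_registry() {
    static std::map<LauncherKey, LauncherFn> registry;
    return registry;
}

// A registration macro then boils down to a file-scope object whose
// constructor runs at load time and inserts the launcher under its key.
struct BwdRegistrar {
    BwdRegistrar(LauncherKey key, LauncherFn fn) {
        bwd_registry().emplace(std::move(key), std::move(fn));
    }
};
```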
ln_parallel_bwd_1280.cu
ADDED
@@ -0,0 +1,15 @@
#include "ln_parallel_residual_bwd_kernels.cuh"

// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL

REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);

ln_parallel_bwd_1536.cu
ADDED
@@ -0,0 +1,15 @@
#include "ln_parallel_residual_bwd_kernels.cuh"

// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL

REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);

ln_parallel_bwd_2048.cu
ADDED
@@ -0,0 +1,15 @@
#include "ln_parallel_residual_bwd_kernels.cuh"

// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL

REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);

ln_parallel_bwd_256.cu
ADDED
@@ -0,0 +1,15 @@
#include "ln_parallel_residual_bwd_kernels.cuh"

// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL

REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);

ln_parallel_bwd_2560.cu
ADDED
@@ -0,0 +1,15 @@
#include "ln_parallel_residual_bwd_kernels.cuh"

// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL

REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);

ln_parallel_bwd_3072.cu
ADDED
@@ -0,0 +1,15 @@
#include "ln_parallel_residual_bwd_kernels.cuh"

// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL

REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);

ln_parallel_bwd_4096.cu
ADDED
@@ -0,0 +1,17 @@
#include "ln_parallel_residual_bwd_kernels.cuh"

// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL

// Use 8 warps otherwise there's a lot of register spilling

REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);

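The 8-warp comment above reflects register pressure: each thread holds LDGS x NUM_ELTS elements of the row (plus gradient accumulators) in registers at once, so widening the row without adding warps inflates the per-thread working set. A back-of-envelope sketch, assuming fp16 elements and the 16-byte loads used above (the helper below is illustrative, not part of the extension):

```cpp
// Mirrors LDGS = (COLS / NUM_ELTS) / (WARPS_N * 32) from ln_kernel_traits.h.
constexpr int kHidden4096 = 4096;
constexpr int kEltsPerLdg = 16 / 2;  // 8 fp16 elements per 16-byte load

constexpr int ldgs_per_thread(int warps_n) {
    return (kHidden4096 / kEltsPerLdg) / (warps_n * 32);
}

static_assert(ldgs_per_thread(4) == 4, "4 warps/row: 4 x 8 = 32 elements resident");
static_assert(ldgs_per_thread(8) == 2, "8 warps/row: 2 x 8 = 16 elements resident");
```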
ln_parallel_bwd_512.cu
ADDED
@@ -0,0 +1,15 @@
#include "ln_parallel_residual_bwd_kernels.cuh"

// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL

REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);

ln_parallel_bwd_5120.cu
ADDED
@@ -0,0 +1,17 @@
#include "ln_parallel_residual_bwd_kernels.cuh"

// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL

// Use 8 warps otherwise there's a lot of register spilling

REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 8, 4);

ln_parallel_bwd_6144.cu
ADDED
@@ -0,0 +1,15 @@
#include "ln_parallel_residual_bwd_kernels.cuh"

// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL

REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);

ln_parallel_bwd_7168.cu
ADDED
@@ -0,0 +1,15 @@
#include "ln_parallel_residual_bwd_kernels.cuh"

// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL

REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 8, 4);