cranky-coder08 committed
Commit 9207dd1 · verified · 1 Parent(s): 67e9774

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.

Files changed (50)
  1. .gitattributes +3 -0
  2. phivenv/Lib/site-packages/torch/lib/microkernels-prod.lib +3 -0
  3. phivenv/Lib/site-packages/torch/lib/pthreadpool.lib +3 -0
  4. phivenv/Lib/site-packages/torch/lib/sleef.lib +3 -0
  5. phivenv/Lib/site-packages/transformers/integrations/__pycache__/__init__.cpython-39.pyc +0 -0
  6. phivenv/Lib/site-packages/transformers/integrations/__pycache__/flash_attention.cpython-39.pyc +0 -0
  7. phivenv/Lib/site-packages/transformers/integrations/__pycache__/flash_paged.cpython-39.pyc +0 -0
  8. phivenv/Lib/site-packages/transformers/integrations/__pycache__/flex_attention.cpython-39.pyc +0 -0
  9. phivenv/Lib/site-packages/transformers/integrations/__pycache__/fp_quant.cpython-39.pyc +0 -0
  10. phivenv/Lib/site-packages/transformers/integrations/__pycache__/fsdp.cpython-39.pyc +0 -0
  11. phivenv/Lib/site-packages/transformers/integrations/__pycache__/ggml.cpython-39.pyc +0 -0
  12. phivenv/Lib/site-packages/transformers/integrations/__pycache__/higgs.cpython-39.pyc +0 -0
  13. phivenv/Lib/site-packages/transformers/integrations/__pycache__/hqq.cpython-39.pyc +0 -0
  14. phivenv/Lib/site-packages/transformers/integrations/__pycache__/hub_kernels.cpython-39.pyc +0 -0
  15. phivenv/Lib/site-packages/transformers/integrations/__pycache__/integration_utils.cpython-39.pyc +0 -0
  16. phivenv/Lib/site-packages/transformers/integrations/__pycache__/mistral.cpython-39.pyc +0 -0
  17. phivenv/Lib/site-packages/transformers/integrations/__pycache__/mxfp4.cpython-39.pyc +0 -0
  18. phivenv/Lib/site-packages/transformers/integrations/__pycache__/npu_flash_attention.cpython-39.pyc +0 -0
  19. phivenv/Lib/site-packages/transformers/integrations/__pycache__/peft.cpython-39.pyc +0 -0
  20. phivenv/Lib/site-packages/transformers/integrations/__pycache__/quanto.cpython-39.pyc +0 -0
  21. phivenv/Lib/site-packages/transformers/integrations/__pycache__/sdpa_attention.cpython-39.pyc +0 -0
  22. phivenv/Lib/site-packages/transformers/integrations/__pycache__/sdpa_paged.cpython-39.pyc +0 -0
  23. phivenv/Lib/site-packages/transformers/integrations/__pycache__/spqr.cpython-39.pyc +0 -0
  24. phivenv/Lib/site-packages/transformers/integrations/__pycache__/tensor_parallel.cpython-39.pyc +0 -0
  25. phivenv/Lib/site-packages/transformers/integrations/__pycache__/tiktoken.cpython-39.pyc +0 -0
  26. phivenv/Lib/site-packages/transformers/integrations/__pycache__/tpu.cpython-39.pyc +0 -0
  27. phivenv/Lib/site-packages/transformers/integrations/__pycache__/vptq.cpython-39.pyc +0 -0
  28. phivenv/Lib/site-packages/transformers/kernels/__init__.py +0 -0
  29. phivenv/Lib/site-packages/transformers/kernels/__pycache__/__init__.cpython-39.pyc +0 -0
  30. phivenv/Lib/site-packages/transformers/kernels/deta/cpu/ms_deform_attn_cpu.cpp +40 -0
  31. phivenv/Lib/site-packages/transformers/kernels/deta/cpu/ms_deform_attn_cpu.h +32 -0
  32. phivenv/Lib/site-packages/transformers/kernels/deta/cuda/ms_deform_attn_cuda.cu +156 -0
  33. phivenv/Lib/site-packages/transformers/kernels/deta/cuda/ms_deform_attn_cuda.cuh +1467 -0
  34. phivenv/Lib/site-packages/transformers/kernels/deta/cuda/ms_deform_attn_cuda.h +29 -0
  35. phivenv/Lib/site-packages/transformers/kernels/deta/cuda/ms_deform_im2col_cuda.cuh +1327 -0
  36. phivenv/Lib/site-packages/transformers/kernels/deta/ms_deform_attn.h +61 -0
  37. phivenv/Lib/site-packages/transformers/kernels/deta/vision.cpp +16 -0
  38. phivenv/Lib/site-packages/transformers/kernels/falcon_mamba/__init__.py +15 -0
  39. phivenv/Lib/site-packages/transformers/kernels/falcon_mamba/__pycache__/__init__.cpython-39.pyc +0 -0
  40. phivenv/Lib/site-packages/transformers/kernels/falcon_mamba/__pycache__/selective_scan_with_ln_interface.cpython-39.pyc +0 -0
  41. phivenv/Lib/site-packages/transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +525 -0
  42. phivenv/Lib/site-packages/transformers/kernels/mra/cuda_kernel.cu +383 -0
  43. phivenv/Lib/site-packages/transformers/kernels/mra/cuda_kernel.h +59 -0
  44. phivenv/Lib/site-packages/transformers/kernels/mra/cuda_launch.cu +154 -0
  45. phivenv/Lib/site-packages/transformers/kernels/mra/cuda_launch.h +39 -0
  46. phivenv/Lib/site-packages/transformers/kernels/mra/torch_extension.cpp +78 -0
  47. phivenv/Lib/site-packages/transformers/kernels/rwkv/wkv_cuda.cu +187 -0
  48. phivenv/Lib/site-packages/transformers/kernels/rwkv/wkv_cuda_bf16.cu +186 -0
  49. phivenv/Lib/site-packages/transformers/kernels/rwkv/wkv_op.cpp +66 -0
  50. phivenv/Lib/site-packages/transformers/kernels/yoso/common.h +10 -0
.gitattributes CHANGED
@@ -119,3 +119,6 @@ phivenv/Lib/site-packages/torch/lib/libprotobuf-lite.lib filter=lfs diff=lfs merge=lfs -text
 phivenv/Lib/site-packages/torch/lib/kineto.lib filter=lfs diff=lfs merge=lfs -text
 phivenv/Lib/site-packages/torch/lib/libprotobuf.lib filter=lfs diff=lfs merge=lfs -text
 phivenv/Lib/site-packages/torch/lib/libprotoc.lib filter=lfs diff=lfs merge=lfs -text
+phivenv/Lib/site-packages/torch/lib/pthreadpool.lib filter=lfs diff=lfs merge=lfs -text
+phivenv/Lib/site-packages/torch/lib/microkernels-prod.lib filter=lfs diff=lfs merge=lfs -text
+phivenv/Lib/site-packages/torch/lib/sleef.lib filter=lfs diff=lfs merge=lfs -text
phivenv/Lib/site-packages/torch/lib/microkernels-prod.lib ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d40e51bd2bb25f60758fce200d2eea78bba5eb96e9771bec4d803ec6786ec11d
+size 20267954
phivenv/Lib/site-packages/torch/lib/pthreadpool.lib ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:88e2da3e5e79047835a21d009707ea855b72d9dc78e46f179886a41f075f75c8
+size 768704
phivenv/Lib/site-packages/torch/lib/sleef.lib ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1164e5cee1a9bd4621704221b15f8a24ea1536bc5cf267579c3aa0397d539bbc
+size 8776772
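
Each of the three `.lib` additions above is stored not as the binary itself but as a Git LFS pointer: a three-line `key value` stub (`version`, `oid`, `size`) committed in place of the blob, which the `filter=lfs` rules added to `.gitattributes` route through the LFS clean/smudge filters. As a minimal illustrative sketch (not part of this commit; the file name `lfs_pointer.cpp` and the `main` harness are assumptions), such a pointer can be parsed like so:

// lfs_pointer.cpp - minimal sketch: parse a Git LFS pointer file of the
// exact three-line "key value" form shown above (version / oid / size).
#include <fstream>
#include <iostream>
#include <map>
#include <sstream>
#include <string>

int main(int argc, char **argv) {
    if (argc < 2) { std::cerr << "usage: lfs_pointer <file>\n"; return 1; }
    std::ifstream in(argv[1]);
    std::map<std::string, std::string> fields;
    std::string line;
    while (std::getline(in, line)) {
        // Each pointer line is "key value", e.g. "size 20267954".
        std::istringstream ls(line);
        std::string key;
        if (!(ls >> key)) continue;
        std::string value;
        std::getline(ls, value);
        const auto pos = value.find_first_not_of(' ');
        if (pos != std::string::npos) fields[key] = value.substr(pos);
    }
    // "oid" carries "sha256:<digest>" of the real blob; "size" its byte count.
    std::cout << "oid:  " << fields["oid"] << "\n"
              << "size: " << fields["size"] << " bytes\n";
    return 0;
}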
phivenv/Lib/site-packages/transformers/integrations/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (5.11 kB)
phivenv/Lib/site-packages/transformers/integrations/__pycache__/flash_attention.cpython-39.pyc ADDED
Binary file (2.29 kB)
phivenv/Lib/site-packages/transformers/integrations/__pycache__/flash_paged.cpython-39.pyc ADDED
Binary file (2.85 kB)
phivenv/Lib/site-packages/transformers/integrations/__pycache__/flex_attention.cpython-39.pyc ADDED
Binary file (8.57 kB)
phivenv/Lib/site-packages/transformers/integrations/__pycache__/fp_quant.cpython-39.pyc ADDED
Binary file (974 Bytes)
phivenv/Lib/site-packages/transformers/integrations/__pycache__/fsdp.cpython-39.pyc ADDED
Binary file (1.11 kB)
phivenv/Lib/site-packages/transformers/integrations/__pycache__/ggml.cpython-39.pyc ADDED
Binary file (17.1 kB)
phivenv/Lib/site-packages/transformers/integrations/__pycache__/higgs.cpython-39.pyc ADDED
Binary file (17.3 kB)
phivenv/Lib/site-packages/transformers/integrations/__pycache__/hqq.cpython-39.pyc ADDED
Binary file (3.04 kB)
phivenv/Lib/site-packages/transformers/integrations/__pycache__/hub_kernels.cpython-39.pyc ADDED
Binary file (3.9 kB)
phivenv/Lib/site-packages/transformers/integrations/__pycache__/integration_utils.cpython-39.pyc ADDED
Binary file (81.9 kB)
phivenv/Lib/site-packages/transformers/integrations/__pycache__/mistral.cpython-39.pyc ADDED
Binary file (4.4 kB)
phivenv/Lib/site-packages/transformers/integrations/__pycache__/mxfp4.cpython-39.pyc ADDED
Binary file (10.7 kB)
phivenv/Lib/site-packages/transformers/integrations/__pycache__/npu_flash_attention.cpython-39.pyc ADDED
Binary file (2.47 kB)
phivenv/Lib/site-packages/transformers/integrations/__pycache__/peft.cpython-39.pyc ADDED
Binary file (21.2 kB)
phivenv/Lib/site-packages/transformers/integrations/__pycache__/quanto.cpython-39.pyc ADDED
Binary file (2.92 kB)
phivenv/Lib/site-packages/transformers/integrations/__pycache__/sdpa_attention.cpython-39.pyc ADDED
Binary file (2.56 kB)
phivenv/Lib/site-packages/transformers/integrations/__pycache__/sdpa_paged.cpython-39.pyc ADDED
Binary file (1.58 kB)
phivenv/Lib/site-packages/transformers/integrations/__pycache__/spqr.cpython-39.pyc ADDED
Binary file (3.1 kB)
phivenv/Lib/site-packages/transformers/integrations/__pycache__/tensor_parallel.cpython-39.pyc ADDED
Binary file (36.1 kB)
phivenv/Lib/site-packages/transformers/integrations/__pycache__/tiktoken.cpython-39.pyc ADDED
Binary file (1.71 kB)
phivenv/Lib/site-packages/transformers/integrations/__pycache__/tpu.cpython-39.pyc ADDED
Binary file (846 Bytes)
phivenv/Lib/site-packages/transformers/integrations/__pycache__/vptq.cpython-39.pyc ADDED
Binary file (2.72 kB)
phivenv/Lib/site-packages/transformers/kernels/__init__.py ADDED
File without changes
phivenv/Lib/site-packages/transformers/kernels/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (164 Bytes)
phivenv/Lib/site-packages/transformers/kernels/deta/cpu/ms_deform_attn_cpu.cpp ADDED
@@ -0,0 +1,40 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include <vector>
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+
+at::Tensor
+ms_deform_attn_cpu_forward(
+    const at::Tensor &value,
+    const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index,
+    const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight,
+    const int im2col_step)
+{
+    AT_ERROR("Not implement on cpu");
+}
+
+std::vector<at::Tensor>
+ms_deform_attn_cpu_backward(
+    const at::Tensor &value,
+    const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index,
+    const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight,
+    const at::Tensor &grad_output,
+    const int im2col_step)
+{
+    AT_ERROR("Not implement on cpu");
+}
phivenv/Lib/site-packages/transformers/kernels/deta/cpu/ms_deform_attn_cpu.h ADDED
@@ -0,0 +1,32 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#pragma once
+#include <torch/extension.h>
+
+at::Tensor
+ms_deform_attn_cpu_forward(
+    const at::Tensor &value,
+    const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index,
+    const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight,
+    const int im2col_step);
+
+std::vector<at::Tensor>
+ms_deform_attn_cpu_backward(
+    const at::Tensor &value,
+    const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index,
+    const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight,
+    const at::Tensor &grad_output,
+    const int im2col_step);
+
phivenv/Lib/site-packages/transformers/kernels/deta/cuda/ms_deform_attn_cuda.cu ADDED
@@ -0,0 +1,156 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include <vector>
+#include "cuda/ms_deform_im2col_cuda.cuh"
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#pragma once
+#include <torch/extension.h>
+
+
+at::Tensor ms_deform_attn_cuda_forward(
+    const at::Tensor &value,
+    const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index,
+    const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight,
+    const int im2col_step)
+{
+    AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+    AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+    AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+    AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+    AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+
+    AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+    AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
+    AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
+    AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
+    AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+
+    const int batch = value.size(0);
+    const int spatial_size = value.size(1);
+    const int num_heads = value.size(2);
+    const int channels = value.size(3);
+
+    const int num_levels = spatial_shapes.size(0);
+
+    const int num_query = sampling_loc.size(1);
+    const int num_point = sampling_loc.size(4);
+
+    const int im2col_step_ = std::min(batch, im2col_step);
+
+    AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
+
+    auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
+
+    const int batch_n = im2col_step_;
+    auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
+    auto per_value_size = spatial_size * num_heads * channels;
+    auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+    auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+    for (int n = 0; n < batch/im2col_step_; ++n)
+    {
+        auto columns = output_n.select(0, n);
+        AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
+            ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
+                value.data<scalar_t>() + n * im2col_step_ * per_value_size,
+                spatial_shapes.data<int64_t>(),
+                level_start_index.data<int64_t>(),
+                sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+                attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
+                batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
+                columns.data<scalar_t>());
+
+        }));
+    }
+
+    output = output.view({batch, num_query, num_heads*channels});
+
+    return output;
+}
+
+
+std::vector<at::Tensor> ms_deform_attn_cuda_backward(
+    const at::Tensor &value,
+    const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index,
+    const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight,
+    const at::Tensor &grad_output,
+    const int im2col_step)
+{
+
+    AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+    AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+    AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+    AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+    AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+    AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
+
+    AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+    AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
+    AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
+    AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
+    AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+    AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
+
+    const int batch = value.size(0);
+    const int spatial_size = value.size(1);
+    const int num_heads = value.size(2);
+    const int channels = value.size(3);
+
+    const int num_levels = spatial_shapes.size(0);
+
+    const int num_query = sampling_loc.size(1);
+    const int num_point = sampling_loc.size(4);
+
+    const int im2col_step_ = std::min(batch, im2col_step);
+
+    AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
+
+    auto grad_value = at::zeros_like(value);
+    auto grad_sampling_loc = at::zeros_like(sampling_loc);
+    auto grad_attn_weight = at::zeros_like(attn_weight);
+
+    const int batch_n = im2col_step_;
+    auto per_value_size = spatial_size * num_heads * channels;
+    auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+    auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+    auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
+
+    for (int n = 0; n < batch/im2col_step_; ++n)
+    {
+        auto grad_output_g = grad_output_n.select(0, n);
+        AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
+            ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
+                grad_output_g.data<scalar_t>(),
+                value.data<scalar_t>() + n * im2col_step_ * per_value_size,
+                spatial_shapes.data<int64_t>(),
+                level_start_index.data<int64_t>(),
+                sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+                attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
+                batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
+                grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
+                grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+                grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
+
+        }));
+    }
+
+    return {
+        grad_value, grad_sampling_loc, grad_attn_weight
+    };
+}
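
For reference, the shape contract that ms_deform_attn_cuda_forward enforces through its size() and AT_ASSERTM calls above can be made concrete with a small host-side sketch. This is an illustrative usage example, not code from the commit: every concrete dimension (batch=2, heads=8, the per-level sizes, etc.) is invented, and it assumes the forward declaration from ms_deform_attn_cuda.h (file 34 in this commit) is in scope.

// Illustrative sketch (not in this commit): the tensor shapes implied by the
// size() calls in ms_deform_attn_cuda_forward. All concrete numbers are made up.
#include <iostream>
#include <torch/torch.h>
#include "cuda/ms_deform_attn_cuda.h"  // assumed: declares ms_deform_attn_cuda_forward

int main() {
    const int64_t batch = 2, num_heads = 8, channels = 32;
    const int64_t num_levels = 4, num_query = 300, num_point = 4;

    // spatial_shapes: [num_levels, 2] int64 rows (H_l, W_l), one per feature level.
    auto spatial_shapes =
        torch::tensor({64, 64, 32, 32, 16, 16, 8, 8}, torch::kInt64).view({4, 2}).cuda();
    auto areas = spatial_shapes.prod(1);                       // H_l * W_l per level
    const int64_t spatial_size = areas.sum().item<int64_t>();  // value.size(1)

    // level_start_index: [num_levels] int64 offsets of each level in the flat value.
    auto level_start_index = torch::cat(
        {torch::zeros({1}, spatial_shapes.options()),
         areas.cumsum(0).slice(0, 0, num_levels - 1)});

    auto opts = torch::dtype(torch::kFloat32).device(torch::kCUDA);
    auto value = torch::rand({batch, spatial_size, num_heads, channels}, opts);
    auto sampling_loc =
        torch::rand({batch, num_query, num_heads, num_levels, num_point, 2}, opts);
    auto attn_weight =
        torch::rand({batch, num_query, num_heads, num_levels, num_point}, opts);

    // im2col_step must divide batch; the AT_ASSERTM above enforces this.
    auto out = ms_deform_attn_cuda_forward(value, spatial_shapes, level_start_index,
                                           sampling_loc, attn_weight, /*im2col_step=*/2);
    std::cout << out.sizes() << std::endl;  // [batch, num_query, num_heads * channels]
}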
phivenv/Lib/site-packages/transformers/kernels/deta/cuda/ms_deform_attn_cuda.cuh ADDED
@@ -0,0 +1,1467 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include <vector>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include <cstdio>
+#include <algorithm>
+#include <cstring>
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <THC/THCAtomics.cuh>
+
+#define CUDA_KERNEL_LOOP(i, n)                          \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x;   \
+      i < (n);                                          \
+      i += blockDim.x * gridDim.x)
+
+
+at::Tensor ms_deform_attn_cuda_forward(
+    const at::Tensor &value,
+    const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index,
+    const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight,
+    const int im2col_step)
+{
+    AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+    AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+    AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+    AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+    AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+
+    AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+    AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
+    AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
+    AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
+    AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+
+    const int batch = value.size(0);
+    const int spatial_size = value.size(1);
+    const int num_heads = value.size(2);
+    const int channels = value.size(3);
+
+    const int num_levels = spatial_shapes.size(0);
+
+    const int num_query = sampling_loc.size(1);
+    const int num_point = sampling_loc.size(4);
+
+    const int im2col_step_ = std::min(batch, im2col_step);
+
+    AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
+
+    auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
+
+    const int batch_n = im2col_step_;
+    auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
+    auto per_value_size = spatial_size * num_heads * channels;
+    auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+    auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+    for (int n = 0; n < batch/im2col_step_; ++n)
+    {
+        auto columns = output_n.select(0, n);
+        AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
+            ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
+                value.data<scalar_t>() + n * im2col_step_ * per_value_size,
+                spatial_shapes.data<int64_t>(),
+                level_start_index.data<int64_t>(),
+                sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+                attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
+                batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
+                columns.data<scalar_t>());
+
+        }));
+    }
+
+    output = output.view({batch, num_query, num_heads*channels});
+
+    return output;
+}
+
+
+std::vector<at::Tensor> ms_deform_attn_cuda_backward(
+    const at::Tensor &value,
+    const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index,
+    const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight,
+    const at::Tensor &grad_output,
+    const int im2col_step)
+{
+
+    AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+    AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+    AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+    AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+    AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+    AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
+
+    AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+    AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
+    AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
+    AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
+    AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+    AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
+
+    const int batch = value.size(0);
+    const int spatial_size = value.size(1);
+    const int num_heads = value.size(2);
+    const int channels = value.size(3);
+
+    const int num_levels = spatial_shapes.size(0);
+
+    const int num_query = sampling_loc.size(1);
+    const int num_point = sampling_loc.size(4);
+
+    const int im2col_step_ = std::min(batch, im2col_step);
+
+    AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
+
+    auto grad_value = at::zeros_like(value);
+    auto grad_sampling_loc = at::zeros_like(sampling_loc);
+    auto grad_attn_weight = at::zeros_like(attn_weight);
+
+    const int batch_n = im2col_step_;
+    auto per_value_size = spatial_size * num_heads * channels;
+    auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+    auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+    auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
+
+    for (int n = 0; n < batch/im2col_step_; ++n)
+    {
+        auto grad_output_g = grad_output_n.select(0, n);
+        AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
+            ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
+                grad_output_g.data<scalar_t>(),
+                value.data<scalar_t>() + n * im2col_step_ * per_value_size,
+                spatial_shapes.data<int64_t>(),
+                level_start_index.data<int64_t>(),
+                sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+                attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
+                batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
+                grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
+                grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+                grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
+
+        }));
+    }
+
+    return {
+        grad_value, grad_sampling_loc, grad_attn_weight
+    };
+}
+
+const int CUDA_NUM_THREADS = 1024;
+inline int GET_BLOCKS(const int N, const int num_threads)
+{
+    return (N + num_threads - 1) / num_threads;
+}
+
+
+template <typename scalar_t>
+__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
+                                                   const int &height, const int &width, const int &nheads, const int &channels,
+                                                   const scalar_t &h, const scalar_t &w, const int &m, const int &c)
+{
+    const int h_low = floor(h);
+    const int w_low = floor(w);
+    const int h_high = h_low + 1;
+    const int w_high = w_low + 1;
+
+    const scalar_t lh = h - h_low;
+    const scalar_t lw = w - w_low;
+    const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+    const int w_stride = nheads * channels;
+    const int h_stride = width * w_stride;
+    const int h_low_ptr_offset = h_low * h_stride;
+    const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+    const int w_low_ptr_offset = w_low * w_stride;
+    const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+    const int base_ptr = m * channels + c;
+
+    scalar_t v1 = 0;
+    if (h_low >= 0 && w_low >= 0)
+    {
+        const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+        v1 = bottom_data[ptr1];
+    }
+    scalar_t v2 = 0;
+    if (h_low >= 0 && w_high <= width - 1)
+    {
+        const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+        v2 = bottom_data[ptr2];
+    }
+    scalar_t v3 = 0;
+    if (h_high <= height - 1 && w_low >= 0)
+    {
+        const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+        v3 = bottom_data[ptr3];
+    }
+    scalar_t v4 = 0;
+    if (h_high <= height - 1 && w_high <= width - 1)
+    {
+        const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+        v4 = bottom_data[ptr4];
+    }
+
+    const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+    const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+    return val;
+}
+
+
+template <typename scalar_t>
+__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
+                                               const int &height, const int &width, const int &nheads, const int &channels,
+                                               const scalar_t &h, const scalar_t &w, const int &m, const int &c,
+                                               const scalar_t &top_grad,
+                                               const scalar_t &attn_weight,
+                                               scalar_t* &grad_value,
+                                               scalar_t* grad_sampling_loc,
+                                               scalar_t* grad_attn_weight)
+{
+    const int h_low = floor(h);
+    const int w_low = floor(w);
+    const int h_high = h_low + 1;
+    const int w_high = w_low + 1;
+
+    const scalar_t lh = h - h_low;
+    const scalar_t lw = w - w_low;
+    const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+    const int w_stride = nheads * channels;
+    const int h_stride = width * w_stride;
+    const int h_low_ptr_offset = h_low * h_stride;
+    const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+    const int w_low_ptr_offset = w_low * w_stride;
+    const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+    const int base_ptr = m * channels + c;
+
+    const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+    const scalar_t top_grad_value = top_grad * attn_weight;
+    scalar_t grad_h_weight = 0, grad_w_weight = 0;
+
+    scalar_t v1 = 0;
+    if (h_low >= 0 && w_low >= 0)
+    {
+        const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+        v1 = bottom_data[ptr1];
+        grad_h_weight -= hw * v1;
+        grad_w_weight -= hh * v1;
+        atomicAdd(grad_value+ptr1, w1*top_grad_value);
+    }
+    scalar_t v2 = 0;
+    if (h_low >= 0 && w_high <= width - 1)
+    {
+        const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+        v2 = bottom_data[ptr2];
+        grad_h_weight -= lw * v2;
+        grad_w_weight += hh * v2;
+        atomicAdd(grad_value+ptr2, w2*top_grad_value);
+    }
+    scalar_t v3 = 0;
+    if (h_high <= height - 1 && w_low >= 0)
+    {
+        const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+        v3 = bottom_data[ptr3];
+        grad_h_weight += hw * v3;
+        grad_w_weight -= lh * v3;
+        atomicAdd(grad_value+ptr3, w3*top_grad_value);
+    }
+    scalar_t v4 = 0;
+    if (h_high <= height - 1 && w_high <= width - 1)
+    {
+        const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+        v4 = bottom_data[ptr4];
+        grad_h_weight += lw * v4;
+        grad_w_weight += lh * v4;
+        atomicAdd(grad_value+ptr4, w4*top_grad_value);
+    }
+
+    const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+    *grad_attn_weight = top_grad * val;
+    *grad_sampling_loc = width * grad_w_weight * top_grad_value;
+    *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
+}
+
+
+template <typename scalar_t>
+__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
+                                                  const int &height, const int &width, const int &nheads, const int &channels,
+                                                  const scalar_t &h, const scalar_t &w, const int &m, const int &c,
+                                                  const scalar_t &top_grad,
+                                                  const scalar_t &attn_weight,
+                                                  scalar_t* &grad_value,
+                                                  scalar_t* grad_sampling_loc,
+                                                  scalar_t* grad_attn_weight)
+{
+    const int h_low = floor(h);
+    const int w_low = floor(w);
+    const int h_high = h_low + 1;
+    const int w_high = w_low + 1;
+
+    const scalar_t lh = h - h_low;
+    const scalar_t lw = w - w_low;
+    const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+    const int w_stride = nheads * channels;
+    const int h_stride = width * w_stride;
+    const int h_low_ptr_offset = h_low * h_stride;
+    const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+    const int w_low_ptr_offset = w_low * w_stride;
+    const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+    const int base_ptr = m * channels + c;
+
+    const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+    const scalar_t top_grad_value = top_grad * attn_weight;
+    scalar_t grad_h_weight = 0, grad_w_weight = 0;
+
+    scalar_t v1 = 0;
+    if (h_low >= 0 && w_low >= 0)
+    {
+        const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+        v1 = bottom_data[ptr1];
+        grad_h_weight -= hw * v1;
+        grad_w_weight -= hh * v1;
+        atomicAdd(grad_value+ptr1, w1*top_grad_value);
+    }
+    scalar_t v2 = 0;
+    if (h_low >= 0 && w_high <= width - 1)
+    {
+        const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+        v2 = bottom_data[ptr2];
+        grad_h_weight -= lw * v2;
+        grad_w_weight += hh * v2;
+        atomicAdd(grad_value+ptr2, w2*top_grad_value);
+    }
+    scalar_t v3 = 0;
+    if (h_high <= height - 1 && w_low >= 0)
+    {
+        const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+        v3 = bottom_data[ptr3];
+        grad_h_weight += hw * v3;
+        grad_w_weight -= lh * v3;
+        atomicAdd(grad_value+ptr3, w3*top_grad_value);
+    }
+    scalar_t v4 = 0;
+    if (h_high <= height - 1 && w_high <= width - 1)
+    {
+        const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+        v4 = bottom_data[ptr4];
+        grad_h_weight += lw * v4;
+        grad_w_weight += lh * v4;
+        atomicAdd(grad_value+ptr4, w4*top_grad_value);
+    }
+
+    const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+    atomicAdd(grad_attn_weight, top_grad * val);
+    atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
+    atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
+}
+
+
+template <typename scalar_t>
+__global__ void ms_deformable_im2col_gpu_kernel(const int n,
+                                                const scalar_t *data_value,
+                                                const int64_t *data_spatial_shapes,
+                                                const int64_t *data_level_start_index,
+                                                const scalar_t *data_sampling_loc,
+                                                const scalar_t *data_attn_weight,
+                                                const int batch_size,
+                                                const int spatial_size,
+                                                const int num_heads,
+                                                const int channels,
+                                                const int num_levels,
+                                                const int num_query,
+                                                const int num_point,
+                                                scalar_t *data_col)
+{
+    CUDA_KERNEL_LOOP(index, n)
+    {
+        int _temp = index;
+        const int c_col = _temp % channels;
+        _temp /= channels;
+        const int sampling_index = _temp;
+        const int m_col = _temp % num_heads;
+        _temp /= num_heads;
+        const int q_col = _temp % num_query;
+        _temp /= num_query;
+        const int b_col = _temp;
+
+        scalar_t *data_col_ptr = data_col + index;
+        int data_weight_ptr = sampling_index * num_levels * num_point;
+        int data_loc_w_ptr = data_weight_ptr << 1;
+        const int qid_stride = num_heads * channels;
+        const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+        scalar_t col = 0;
+
+        for (int l_col=0; l_col < num_levels; ++l_col)
+        {
+            const int level_start_id = data_level_start_index[l_col];
+            const int spatial_h_ptr = l_col << 1;
+            const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+            const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+            const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
+            for (int p_col=0; p_col < num_point; ++p_col)
+            {
+                const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+                const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+                const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+                const scalar_t h_im = loc_h * spatial_h - 0.5;
+                const scalar_t w_im = loc_w * spatial_w - 0.5;
+
+                if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+                {
+                    col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
+                }
+
+                data_weight_ptr += 1;
+                data_loc_w_ptr += 2;
+            }
+        }
+        *data_col_ptr = col;
+    }
+}
+
+template <typename scalar_t, unsigned int blockSize>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
+                                                                              const scalar_t *grad_col,
+                                                                              const scalar_t *data_value,
+                                                                              const int64_t *data_spatial_shapes,
+                                                                              const int64_t *data_level_start_index,
+                                                                              const scalar_t *data_sampling_loc,
+                                                                              const scalar_t *data_attn_weight,
+                                                                              const int batch_size,
+                                                                              const int spatial_size,
+                                                                              const int num_heads,
+                                                                              const int channels,
+                                                                              const int num_levels,
+                                                                              const int num_query,
+                                                                              const int num_point,
+                                                                              scalar_t *grad_value,
+                                                                              scalar_t *grad_sampling_loc,
+                                                                              scalar_t *grad_attn_weight)
+{
+    CUDA_KERNEL_LOOP(index, n)
+    {
+        __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+        __shared__ scalar_t cache_grad_attn_weight[blockSize];
+        unsigned int tid = threadIdx.x;
+        int _temp = index;
+        const int c_col = _temp % channels;
+        _temp /= channels;
+        const int sampling_index = _temp;
+        const int m_col = _temp % num_heads;
+        _temp /= num_heads;
+        const int q_col = _temp % num_query;
+        _temp /= num_query;
+        const int b_col = _temp;
+
+        const scalar_t top_grad = grad_col[index];
+
+        int data_weight_ptr = sampling_index * num_levels * num_point;
+        int data_loc_w_ptr = data_weight_ptr << 1;
+        const int grad_sampling_ptr = data_weight_ptr;
+        grad_sampling_loc += grad_sampling_ptr << 1;
+        grad_attn_weight += grad_sampling_ptr;
+        const int grad_weight_stride = 1;
+        const int grad_loc_stride = 2;
+        const int qid_stride = num_heads * channels;
+        const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+        for (int l_col=0; l_col < num_levels; ++l_col)
+        {
+            const int level_start_id = data_level_start_index[l_col];
+            const int spatial_h_ptr = l_col << 1;
+            const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+            const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+            const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+            const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+            scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+            for (int p_col=0; p_col < num_point; ++p_col)
+            {
+                const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+                const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+                const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+                const scalar_t h_im = loc_h * spatial_h - 0.5;
+                const scalar_t w_im = loc_w * spatial_w - 0.5;
+                *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+                *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+                *(cache_grad_attn_weight+threadIdx.x)=0;
+                if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+                {
+                    ms_deform_attn_col2im_bilinear(
+                        data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+                        top_grad, weight, grad_value_ptr,
+                        cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+                }
+
+                __syncthreads();
+                if (tid == 0)
+                {
+                    scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
+                    int sid=2;
+                    for (unsigned int tid = 1; tid < blockSize; ++tid)
+                    {
+                        _grad_w += cache_grad_sampling_loc[sid];
+                        _grad_h += cache_grad_sampling_loc[sid + 1];
+                        _grad_a += cache_grad_attn_weight[tid];
+                        sid += 2;
+                    }
+
+
+                    *grad_sampling_loc = _grad_w;
+                    *(grad_sampling_loc + 1) = _grad_h;
+                    *grad_attn_weight = _grad_a;
+                }
+                __syncthreads();
+
+                data_weight_ptr += 1;
+                data_loc_w_ptr += 2;
+                grad_attn_weight += grad_weight_stride;
+                grad_sampling_loc += grad_loc_stride;
+            }
+        }
+    }
+}
+
+
+template <typename scalar_t, unsigned int blockSize>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
+                                                                              const scalar_t *grad_col,
+                                                                              const scalar_t *data_value,
+                                                                              const int64_t *data_spatial_shapes,
+                                                                              const int64_t *data_level_start_index,
+                                                                              const scalar_t *data_sampling_loc,
+                                                                              const scalar_t *data_attn_weight,
+                                                                              const int batch_size,
+                                                                              const int spatial_size,
+                                                                              const int num_heads,
+                                                                              const int channels,
+                                                                              const int num_levels,
+                                                                              const int num_query,
+                                                                              const int num_point,
+                                                                              scalar_t *grad_value,
+                                                                              scalar_t *grad_sampling_loc,
+                                                                              scalar_t *grad_attn_weight)
+{
+    CUDA_KERNEL_LOOP(index, n)
+    {
+        __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+        __shared__ scalar_t cache_grad_attn_weight[blockSize];
+        unsigned int tid = threadIdx.x;
+        int _temp = index;
+        const int c_col = _temp % channels;
+        _temp /= channels;
+        const int sampling_index = _temp;
+        const int m_col = _temp % num_heads;
+        _temp /= num_heads;
+        const int q_col = _temp % num_query;
+        _temp /= num_query;
+        const int b_col = _temp;
+
+        const scalar_t top_grad = grad_col[index];
+
+        int data_weight_ptr = sampling_index * num_levels * num_point;
+        int data_loc_w_ptr = data_weight_ptr << 1;
+        const int grad_sampling_ptr = data_weight_ptr;
+        grad_sampling_loc += grad_sampling_ptr << 1;
+        grad_attn_weight += grad_sampling_ptr;
+        const int grad_weight_stride = 1;
+        const int grad_loc_stride = 2;
+        const int qid_stride = num_heads * channels;
+        const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+        for (int l_col=0; l_col < num_levels; ++l_col)
+        {
+            const int level_start_id = data_level_start_index[l_col];
+            const int spatial_h_ptr = l_col << 1;
+            const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+            const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+            const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+            const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+            scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+            for (int p_col=0; p_col < num_point; ++p_col)
+            {
+                const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+                const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+                const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+                const scalar_t h_im = loc_h * spatial_h - 0.5;
+                const scalar_t w_im = loc_w * spatial_w - 0.5;
+                *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+                *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+                *(cache_grad_attn_weight+threadIdx.x)=0;
+                if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+                {
+                    ms_deform_attn_col2im_bilinear(
+                        data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+                        top_grad, weight, grad_value_ptr,
+                        cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+                }
+
+                __syncthreads();
+
+                for (unsigned int s=blockSize/2; s>0; s>>=1)
+                {
+                    if (tid < s) {
+                        const unsigned int xid1 = tid << 1;
+                        const unsigned int xid2 = (tid + s) << 1;
+                        cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+                        cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+                        cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+                    }
+                    __syncthreads();
+                }
+
+                if (tid == 0)
+                {
+                    *grad_sampling_loc = cache_grad_sampling_loc[0];
+                    *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
+                    *grad_attn_weight = cache_grad_attn_weight[0];
+                }
+                __syncthreads();
+
+                data_weight_ptr += 1;
+                data_loc_w_ptr += 2;
+                grad_attn_weight += grad_weight_stride;
+                grad_sampling_loc += grad_loc_stride;
+            }
+        }
+    }
+}
+
+
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
+                                                              const scalar_t *grad_col,
+                                                              const scalar_t *data_value,
+                                                              const int64_t *data_spatial_shapes,
+                                                              const int64_t *data_level_start_index,
+                                                              const scalar_t *data_sampling_loc,
+                                                              const scalar_t *data_attn_weight,
+                                                              const int batch_size,
+                                                              const int spatial_size,
+                                                              const int num_heads,
+                                                              const int channels,
+                                                              const int num_levels,
+                                                              const int num_query,
+                                                              const int num_point,
+                                                              scalar_t *grad_value,
+                                                              scalar_t *grad_sampling_loc,
+                                                              scalar_t *grad_attn_weight)
+{
+    CUDA_KERNEL_LOOP(index, n)
+    {
+        extern __shared__ int _s[];
+        scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+        scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+        unsigned int tid = threadIdx.x;
+        int _temp = index;
+        const int c_col = _temp % channels;
+        _temp /= channels;
+        const int sampling_index = _temp;
+        const int m_col = _temp % num_heads;
+        _temp /= num_heads;
+        const int q_col = _temp % num_query;
+        _temp /= num_query;
+        const int b_col = _temp;
+
+        const scalar_t top_grad = grad_col[index];
+
+        int data_weight_ptr = sampling_index * num_levels * num_point;
+        int data_loc_w_ptr = data_weight_ptr << 1;
+        const int grad_sampling_ptr = data_weight_ptr;
+        grad_sampling_loc += grad_sampling_ptr << 1;
+        grad_attn_weight += grad_sampling_ptr;
+        const int grad_weight_stride = 1;
+        const int grad_loc_stride = 2;
+        const int qid_stride = num_heads * channels;
+        const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+        for (int l_col=0; l_col < num_levels; ++l_col)
+        {
+            const int level_start_id = data_level_start_index[l_col];
+            const int spatial_h_ptr = l_col << 1;
+            const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+            const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+            const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+            const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+            scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+            for (int p_col=0; p_col < num_point; ++p_col)
+            {
+                const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+                const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+                const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+                const scalar_t h_im = loc_h * spatial_h - 0.5;
+                const scalar_t w_im = loc_w * spatial_w - 0.5;
+                *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+                *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+                *(cache_grad_attn_weight+threadIdx.x)=0;
+                if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+                {
+                    ms_deform_attn_col2im_bilinear(
+                        data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+                        top_grad, weight, grad_value_ptr,
+                        cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+                }
+
+                __syncthreads();
+                if (tid == 0)
+                {
+                    scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
+                    int sid=2;
+                    for (unsigned int tid = 1; tid < blockDim.x; ++tid)
+                    {
+                        _grad_w += cache_grad_sampling_loc[sid];
+                        _grad_h += cache_grad_sampling_loc[sid + 1];
+                        _grad_a += cache_grad_attn_weight[tid];
+                        sid += 2;
+                    }
+
+
+                    *grad_sampling_loc = _grad_w;
+                    *(grad_sampling_loc + 1) = _grad_h;
+                    *grad_attn_weight = _grad_a;
+                }
+                __syncthreads();
+
+                data_weight_ptr += 1;
+                data_loc_w_ptr += 2;
+                grad_attn_weight += grad_weight_stride;
+                grad_sampling_loc += grad_loc_stride;
+            }
+        }
+    }
+}
+
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
+                                                              const scalar_t *grad_col,
+                                                              const scalar_t *data_value,
+                                                              const int64_t *data_spatial_shapes,
+                                                              const int64_t *data_level_start_index,
+                                                              const scalar_t *data_sampling_loc,
+                                                              const scalar_t *data_attn_weight,
+                                                              const int batch_size,
+                                                              const int spatial_size,
+                                                              const int num_heads,
+                                                              const int channels,
+                                                              const int num_levels,
+                                                              const int num_query,
+                                                              const int num_point,
+                                                              scalar_t *grad_value,
+                                                              scalar_t *grad_sampling_loc,
+                                                              scalar_t *grad_attn_weight)
+{
+    CUDA_KERNEL_LOOP(index, n)
+    {
+        extern __shared__ int _s[];
+        scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+        scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+        unsigned int tid = threadIdx.x;
+        int _temp = index;
+        const int c_col = _temp % channels;
+        _temp /= channels;
+        const int sampling_index = _temp;
+        const int m_col = _temp % num_heads;
+        _temp /= num_heads;
+        const int q_col = _temp % num_query;
+        _temp /= num_query;
+        const int b_col = _temp;
+
+        const scalar_t top_grad = grad_col[index];
+
+        int data_weight_ptr = sampling_index * num_levels * num_point;
+        int data_loc_w_ptr = data_weight_ptr << 1;
+        const int grad_sampling_ptr = data_weight_ptr;
+        grad_sampling_loc += grad_sampling_ptr << 1;
+        grad_attn_weight += grad_sampling_ptr;
+        const int grad_weight_stride = 1;
+        const int grad_loc_stride = 2;
+        const int qid_stride = num_heads * channels;
+        const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+        for (int l_col=0; l_col < num_levels; ++l_col)
+        {
+            const int level_start_id = data_level_start_index[l_col];
+            const int spatial_h_ptr = l_col << 1;
+            const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+            const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+            const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+            const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+            scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+            for (int p_col=0; p_col < num_point; ++p_col)
+            {
+                const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+                const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+                const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+                const scalar_t h_im = loc_h * spatial_h - 0.5;
+                const scalar_t w_im = loc_w * spatial_w - 0.5;
+                *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+                *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+                *(cache_grad_attn_weight+threadIdx.x)=0;
+                if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+                {
+                    ms_deform_attn_col2im_bilinear(
+                        data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+                        top_grad, weight, grad_value_ptr,
+                        cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+                }
+
+                __syncthreads();
+
+                for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
+                {
+                    if (tid < s) {
+                        const unsigned int xid1 = tid << 1;
+                        const unsigned int xid2 = (tid + s) << 1;
+                        cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+                        cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+                        cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+                        if (tid + (s << 1) < spre)
+                        {
+                            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
+                            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
+                            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+                        }
+                    }
+                    __syncthreads();
+                }
+
+                if (tid == 0)
+                {
+                    *grad_sampling_loc = cache_grad_sampling_loc[0];
+                    *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
+                    *grad_attn_weight = cache_grad_attn_weight[0];
+                }
+                __syncthreads();
+
+                data_weight_ptr += 1;
+                data_loc_w_ptr += 2;
+                grad_attn_weight += grad_weight_stride;
+                grad_sampling_loc += grad_loc_stride;
+            }
+        }
+    }
+}
+
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
+                                                                           const scalar_t *grad_col,
+                                                                           const scalar_t *data_value,
+                                                                           const int64_t *data_spatial_shapes,
+                                                                           const int64_t *data_level_start_index,
+                                                                           const scalar_t *data_sampling_loc,
+                                                                           const scalar_t *data_attn_weight,
+                                                                           const int batch_size,
+                                                                           const int spatial_size,
+                                                                           const int num_heads,
+                                                                           const int channels,
+                                                                           const int num_levels,
+                                                                           const int num_query,
+                                                                           const int num_point,
+                                                                           scalar_t *grad_value,
+                                                                           scalar_t *grad_sampling_loc,
+                                                                           scalar_t *grad_attn_weight)
+{
+    CUDA_KERNEL_LOOP(index, n)
+    {
+        extern __shared__ int _s[];
+        scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+        scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+        unsigned int tid = threadIdx.x;
+        int _temp = index;
+        const int c_col = _temp % channels;
+        _temp /= channels;
+        const int sampling_index = _temp;
+        const int m_col = _temp % num_heads;
+        _temp /= num_heads;
+        const int q_col = _temp % num_query;
+        _temp /= num_query;
+        const int b_col = _temp;
+
+        const scalar_t top_grad = grad_col[index];
+
+        int data_weight_ptr = sampling_index * num_levels * num_point;
+        int data_loc_w_ptr = data_weight_ptr << 1;
+        const int grad_sampling_ptr = data_weight_ptr;
+        grad_sampling_loc += grad_sampling_ptr << 1;
+        grad_attn_weight += grad_sampling_ptr;
+        const int grad_weight_stride = 1;
+        const int grad_loc_stride = 2;
+        const int qid_stride = num_heads * channels;
+        const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+        for (int l_col=0; l_col < num_levels; ++l_col)
+        {
+            const int level_start_id = data_level_start_index[l_col];
+            const int spatial_h_ptr = l_col << 1;
+            const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+            const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+            const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+            const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+            scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+            for (int p_col=0; p_col < num_point; ++p_col)
+            {
+                const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+                const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+                const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+                const scalar_t h_im = loc_h * spatial_h - 0.5;
+                const scalar_t w_im = loc_w * spatial_w - 0.5;
+                *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+                *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+                *(cache_grad_attn_weight+threadIdx.x)=0;
+                if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+                {
+                    ms_deform_attn_col2im_bilinear(
+                        data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+                        top_grad, weight, grad_value_ptr,
+                        cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+                }
+
+                __syncthreads();
+
+                for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
+                {
+                    if (tid < s) {
+                        const unsigned int xid1 = tid << 1;
953
+ const unsigned int xid2 = (tid + s) << 1;
954
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
955
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
956
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
957
+ if (tid + (s << 1) < spre)
958
+ {
959
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
960
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
961
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
962
+ }
963
+ }
964
+ __syncthreads();
965
+ }
966
+
967
+ if (tid == 0)
968
+ {
969
+ atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
970
+ atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
971
+ atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
972
+ }
973
+ __syncthreads();
974
+
975
+ data_weight_ptr += 1;
976
+ data_loc_w_ptr += 2;
977
+ grad_attn_weight += grad_weight_stride;
978
+ grad_sampling_loc += grad_loc_stride;
979
+ }
980
+ }
981
+ }
982
+ }
983
+
984
+
985
+ template <typename scalar_t>
986
+ __global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
987
+ const scalar_t *grad_col,
988
+ const scalar_t *data_value,
989
+ const int64_t *data_spatial_shapes,
990
+ const int64_t *data_level_start_index,
991
+ const scalar_t *data_sampling_loc,
992
+ const scalar_t *data_attn_weight,
993
+ const int batch_size,
994
+ const int spatial_size,
995
+ const int num_heads,
996
+ const int channels,
997
+ const int num_levels,
998
+ const int num_query,
999
+ const int num_point,
1000
+ scalar_t *grad_value,
1001
+ scalar_t *grad_sampling_loc,
1002
+ scalar_t *grad_attn_weight)
1003
+ {
1004
+ CUDA_KERNEL_LOOP(index, n)
1005
+ {
1006
+ int _temp = index;
1007
+ const int c_col = _temp % channels;
1008
+ _temp /= channels;
1009
+ const int sampling_index = _temp;
1010
+ const int m_col = _temp % num_heads;
1011
+ _temp /= num_heads;
1012
+ const int q_col = _temp % num_query;
1013
+ _temp /= num_query;
1014
+ const int b_col = _temp;
1015
+
1016
+ const scalar_t top_grad = grad_col[index];
1017
+
1018
+ int data_weight_ptr = sampling_index * num_levels * num_point;
1019
+ int data_loc_w_ptr = data_weight_ptr << 1;
1020
+ const int grad_sampling_ptr = data_weight_ptr;
1021
+ grad_sampling_loc += grad_sampling_ptr << 1;
1022
+ grad_attn_weight += grad_sampling_ptr;
1023
+ const int grad_weight_stride = 1;
1024
+ const int grad_loc_stride = 2;
1025
+ const int qid_stride = num_heads * channels;
1026
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
1027
+
1028
+ for (int l_col=0; l_col < num_levels; ++l_col)
1029
+ {
1030
+ const int level_start_id = data_level_start_index[l_col];
1031
+ const int spatial_h_ptr = l_col << 1;
1032
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
1033
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
1034
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
1035
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
1036
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
1037
+
1038
+ for (int p_col=0; p_col < num_point; ++p_col)
1039
+ {
1040
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
1041
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
1042
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
1043
+
1044
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
1045
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
1046
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
1047
+ {
1048
+ ms_deform_attn_col2im_bilinear_gm(
1049
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
1050
+ top_grad, weight, grad_value_ptr,
1051
+ grad_sampling_loc, grad_attn_weight);
1052
+ }
1053
+ data_weight_ptr += 1;
1054
+ data_loc_w_ptr += 2;
1055
+ grad_attn_weight += grad_weight_stride;
1056
+ grad_sampling_loc += grad_loc_stride;
1057
+ }
1058
+ }
1059
+ }
1060
+ }
1061
+
1062
+
1063
+ template <typename scalar_t>
+ void ms_deformable_im2col_cuda(cudaStream_t stream,
+     const scalar_t* data_value,
+     const int64_t* data_spatial_shapes,
+     const int64_t* data_level_start_index,
+     const scalar_t* data_sampling_loc,
+     const scalar_t* data_attn_weight,
+     const int batch_size,
+     const int spatial_size,
+     const int num_heads,
+     const int channels,
+     const int num_levels,
+     const int num_query,
+     const int num_point,
+     scalar_t* data_col)
+ {
+     const int num_kernels = batch_size * num_query * num_heads * channels;
+     const int num_actual_kernels = batch_size * num_query * num_heads * channels;
+     const int num_threads = CUDA_NUM_THREADS;
+     ms_deformable_im2col_gpu_kernel<scalar_t>
+         <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+             0, stream>>>(
+         num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
+         batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
+
+     cudaError_t err = cudaGetLastError();
+     if (err != cudaSuccess)
+     {
+         printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
+     }
+
+ }
1095
+
1096
+ template <typename scalar_t>
1097
+ void ms_deformable_col2im_cuda(cudaStream_t stream,
1098
+ const scalar_t* grad_col,
1099
+ const scalar_t* data_value,
1100
+ const int64_t * data_spatial_shapes,
1101
+ const int64_t * data_level_start_index,
1102
+ const scalar_t * data_sampling_loc,
1103
+ const scalar_t * data_attn_weight,
1104
+ const int batch_size,
1105
+ const int spatial_size,
1106
+ const int num_heads,
1107
+ const int channels,
1108
+ const int num_levels,
1109
+ const int num_query,
1110
+ const int num_point,
1111
+ scalar_t* grad_value,
1112
+ scalar_t* grad_sampling_loc,
1113
+ scalar_t* grad_attn_weight)
1114
+ {
1115
+ const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
1116
+ const int num_kernels = batch_size * num_query * num_heads * channels;
1117
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
1118
+ if (channels > 1024)
1119
+ {
1120
+ if ((channels & 1023) == 0)
1121
+ {
1122
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
1123
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1124
+ num_threads*3*sizeof(scalar_t), stream>>>(
1125
+ num_kernels,
1126
+ grad_col,
1127
+ data_value,
1128
+ data_spatial_shapes,
1129
+ data_level_start_index,
1130
+ data_sampling_loc,
1131
+ data_attn_weight,
1132
+ batch_size,
1133
+ spatial_size,
1134
+ num_heads,
1135
+ channels,
1136
+ num_levels,
1137
+ num_query,
1138
+ num_point,
1139
+ grad_value,
1140
+ grad_sampling_loc,
1141
+ grad_attn_weight);
1142
+ }
1143
+ else
1144
+ {
1145
+ ms_deformable_col2im_gpu_kernel_gm<scalar_t>
1146
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1147
+ 0, stream>>>(
1148
+ num_kernels,
1149
+ grad_col,
1150
+ data_value,
1151
+ data_spatial_shapes,
1152
+ data_level_start_index,
1153
+ data_sampling_loc,
1154
+ data_attn_weight,
1155
+ batch_size,
1156
+ spatial_size,
1157
+ num_heads,
1158
+ channels,
1159
+ num_levels,
1160
+ num_query,
1161
+ num_point,
1162
+ grad_value,
1163
+ grad_sampling_loc,
1164
+ grad_attn_weight);
1165
+ }
1166
+ }
1167
+ else{
1168
+ switch(channels)
1169
+ {
1170
+ case 1:
1171
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
1172
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1173
+ 0, stream>>>(
1174
+ num_kernels,
1175
+ grad_col,
1176
+ data_value,
1177
+ data_spatial_shapes,
1178
+ data_level_start_index,
1179
+ data_sampling_loc,
1180
+ data_attn_weight,
1181
+ batch_size,
1182
+ spatial_size,
1183
+ num_heads,
1184
+ channels,
1185
+ num_levels,
1186
+ num_query,
1187
+ num_point,
1188
+ grad_value,
1189
+ grad_sampling_loc,
1190
+ grad_attn_weight);
1191
+ break;
1192
+ case 2:
1193
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
1194
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1195
+ 0, stream>>>(
1196
+ num_kernels,
1197
+ grad_col,
1198
+ data_value,
1199
+ data_spatial_shapes,
1200
+ data_level_start_index,
1201
+ data_sampling_loc,
1202
+ data_attn_weight,
1203
+ batch_size,
1204
+ spatial_size,
1205
+ num_heads,
1206
+ channels,
1207
+ num_levels,
1208
+ num_query,
1209
+ num_point,
1210
+ grad_value,
1211
+ grad_sampling_loc,
1212
+ grad_attn_weight);
1213
+ break;
1214
+ case 4:
1215
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
1216
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1217
+ 0, stream>>>(
1218
+ num_kernels,
1219
+ grad_col,
1220
+ data_value,
1221
+ data_spatial_shapes,
1222
+ data_level_start_index,
1223
+ data_sampling_loc,
1224
+ data_attn_weight,
1225
+ batch_size,
1226
+ spatial_size,
1227
+ num_heads,
1228
+ channels,
1229
+ num_levels,
1230
+ num_query,
1231
+ num_point,
1232
+ grad_value,
1233
+ grad_sampling_loc,
1234
+ grad_attn_weight);
1235
+ break;
1236
+ case 8:
1237
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
1238
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1239
+ 0, stream>>>(
1240
+ num_kernels,
1241
+ grad_col,
1242
+ data_value,
1243
+ data_spatial_shapes,
1244
+ data_level_start_index,
1245
+ data_sampling_loc,
1246
+ data_attn_weight,
1247
+ batch_size,
1248
+ spatial_size,
1249
+ num_heads,
1250
+ channels,
1251
+ num_levels,
1252
+ num_query,
1253
+ num_point,
1254
+ grad_value,
1255
+ grad_sampling_loc,
1256
+ grad_attn_weight);
1257
+ break;
1258
+ case 16:
1259
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
1260
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1261
+ 0, stream>>>(
1262
+ num_kernels,
1263
+ grad_col,
1264
+ data_value,
1265
+ data_spatial_shapes,
1266
+ data_level_start_index,
1267
+ data_sampling_loc,
1268
+ data_attn_weight,
1269
+ batch_size,
1270
+ spatial_size,
1271
+ num_heads,
1272
+ channels,
1273
+ num_levels,
1274
+ num_query,
1275
+ num_point,
1276
+ grad_value,
1277
+ grad_sampling_loc,
1278
+ grad_attn_weight);
1279
+ break;
1280
+ case 32:
1281
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
1282
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1283
+ 0, stream>>>(
1284
+ num_kernels,
1285
+ grad_col,
1286
+ data_value,
1287
+ data_spatial_shapes,
1288
+ data_level_start_index,
1289
+ data_sampling_loc,
1290
+ data_attn_weight,
1291
+ batch_size,
1292
+ spatial_size,
1293
+ num_heads,
1294
+ channels,
1295
+ num_levels,
1296
+ num_query,
1297
+ num_point,
1298
+ grad_value,
1299
+ grad_sampling_loc,
1300
+ grad_attn_weight);
1301
+ break;
1302
+ case 64:
1303
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
1304
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1305
+ 0, stream>>>(
1306
+ num_kernels,
1307
+ grad_col,
1308
+ data_value,
1309
+ data_spatial_shapes,
1310
+ data_level_start_index,
1311
+ data_sampling_loc,
1312
+ data_attn_weight,
1313
+ batch_size,
1314
+ spatial_size,
1315
+ num_heads,
1316
+ channels,
1317
+ num_levels,
1318
+ num_query,
1319
+ num_point,
1320
+ grad_value,
1321
+ grad_sampling_loc,
1322
+ grad_attn_weight);
1323
+ break;
1324
+ case 128:
1325
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
1326
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1327
+ 0, stream>>>(
1328
+ num_kernels,
1329
+ grad_col,
1330
+ data_value,
1331
+ data_spatial_shapes,
1332
+ data_level_start_index,
1333
+ data_sampling_loc,
1334
+ data_attn_weight,
1335
+ batch_size,
1336
+ spatial_size,
1337
+ num_heads,
1338
+ channels,
1339
+ num_levels,
1340
+ num_query,
1341
+ num_point,
1342
+ grad_value,
1343
+ grad_sampling_loc,
1344
+ grad_attn_weight);
1345
+ break;
1346
+ case 256:
1347
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
1348
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1349
+ 0, stream>>>(
1350
+ num_kernels,
1351
+ grad_col,
1352
+ data_value,
1353
+ data_spatial_shapes,
1354
+ data_level_start_index,
1355
+ data_sampling_loc,
1356
+ data_attn_weight,
1357
+ batch_size,
1358
+ spatial_size,
1359
+ num_heads,
1360
+ channels,
1361
+ num_levels,
1362
+ num_query,
1363
+ num_point,
1364
+ grad_value,
1365
+ grad_sampling_loc,
1366
+ grad_attn_weight);
1367
+ break;
1368
+ case 512:
1369
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
1370
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1371
+ 0, stream>>>(
1372
+ num_kernels,
1373
+ grad_col,
1374
+ data_value,
1375
+ data_spatial_shapes,
1376
+ data_level_start_index,
1377
+ data_sampling_loc,
1378
+ data_attn_weight,
1379
+ batch_size,
1380
+ spatial_size,
1381
+ num_heads,
1382
+ channels,
1383
+ num_levels,
1384
+ num_query,
1385
+ num_point,
1386
+ grad_value,
1387
+ grad_sampling_loc,
1388
+ grad_attn_weight);
1389
+ break;
1390
+ case 1024:
1391
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
1392
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1393
+ 0, stream>>>(
1394
+ num_kernels,
1395
+ grad_col,
1396
+ data_value,
1397
+ data_spatial_shapes,
1398
+ data_level_start_index,
1399
+ data_sampling_loc,
1400
+ data_attn_weight,
1401
+ batch_size,
1402
+ spatial_size,
1403
+ num_heads,
1404
+ channels,
1405
+ num_levels,
1406
+ num_query,
1407
+ num_point,
1408
+ grad_value,
1409
+ grad_sampling_loc,
1410
+ grad_attn_weight);
1411
+ break;
1412
+ default:
1413
+ if (channels < 64)
1414
+ {
1415
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
1416
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1417
+ num_threads*3*sizeof(scalar_t), stream>>>(
1418
+ num_kernels,
1419
+ grad_col,
1420
+ data_value,
1421
+ data_spatial_shapes,
1422
+ data_level_start_index,
1423
+ data_sampling_loc,
1424
+ data_attn_weight,
1425
+ batch_size,
1426
+ spatial_size,
1427
+ num_heads,
1428
+ channels,
1429
+ num_levels,
1430
+ num_query,
1431
+ num_point,
1432
+ grad_value,
1433
+ grad_sampling_loc,
1434
+ grad_attn_weight);
1435
+ }
1436
+ else
1437
+ {
1438
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
1439
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1440
+ num_threads*3*sizeof(scalar_t), stream>>>(
1441
+ num_kernels,
1442
+ grad_col,
1443
+ data_value,
1444
+ data_spatial_shapes,
1445
+ data_level_start_index,
1446
+ data_sampling_loc,
1447
+ data_attn_weight,
1448
+ batch_size,
1449
+ spatial_size,
1450
+ num_heads,
1451
+ channels,
1452
+ num_levels,
1453
+ num_query,
1454
+ num_point,
1455
+ grad_value,
1456
+ grad_sampling_loc,
1457
+ grad_attn_weight);
1458
+ }
1459
+ }
1460
+ }
1461
+ cudaError_t err = cudaGetLastError();
1462
+ if (err != cudaSuccess)
1463
+ {
1464
+ printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
1465
+ }
1466
+
1467
+ }
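
Note: the backward kernels above all follow the same scheme: each thread accumulates a partial (grad_w, grad_h, grad_attn) triple into shared memory, the block reduces those triples, and thread 0 writes (or atomically adds, in the multi-block variant) the result. The standalone sketch below illustrates only that interleaved tree reduction, including the `spre` guard that folds in the trailing element when the active width is odd; the file and kernel names here are invented for the example, and this code is not part of the committed sources.

// reduce_sketch.cu -- illustrative sketch only; not part of the committed files.
#include <cstdio>
#include <cuda_runtime.h>

// One block reduces blockDim.x partial (grad_w, grad_h, grad_attn) triples,
// laid out as in the shm_reduce_v2 kernels: the two sampling-location
// components are interleaved at the front of shared memory, the attention
// weights follow.
__global__ void reduce_triples(const float *in, float *out)
{
    extern __shared__ float smem[];
    float *loc  = smem;                    // 2 * blockDim.x floats
    float *attn = smem + 2 * blockDim.x;   // blockDim.x floats
    const unsigned int tid = threadIdx.x;

    loc[(tid << 1)]     = in[tid * 3];
    loc[(tid << 1) + 1] = in[tid * 3 + 1];
    attn[tid]           = in[tid * 3 + 2];
    __syncthreads();

    // Same loop shape as the kernels above: halve the active width each step;
    // the spre test folds in the leftover element when the width was odd, so
    // non-power-of-two block sizes still reduce correctly.
    for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0; s >>= 1, spre >>= 1)
    {
        if (tid < s)
        {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            attn[tid]     += attn[tid + s];
            loc[xid1]     += loc[xid2];
            loc[xid1 + 1] += loc[xid2 + 1];
            if (tid + (s << 1) < spre)
            {
                attn[tid]     += attn[tid + (s << 1)];
                loc[xid1]     += loc[xid2 + (s << 1)];
                loc[xid1 + 1] += loc[xid2 + 1 + (s << 1)];
            }
        }
        __syncthreads();
    }

    if (tid == 0)   // thread 0 publishes the block result
    {
        out[0] = loc[0];
        out[1] = loc[1];
        out[2] = attn[0];
    }
}

int main()
{
    const int n = 96;   // deliberately not a power of two
    float h_in[n * 3], h_out[3];
    for (int i = 0; i < n; ++i)
    {
        h_in[i * 3] = 1.f; h_in[i * 3 + 1] = 2.f; h_in[i * 3 + 2] = 3.f;
    }
    float *d_in, *d_out;
    cudaMalloc(&d_in, sizeof(h_in));
    cudaMalloc(&d_out, 3 * sizeof(float));
    cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
    // Dynamic shared memory sized like the launchers above: threads * 3 floats.
    reduce_triples<<<1, n, n * 3 * sizeof(float)>>>(d_in, d_out);
    cudaMemcpy(h_out, d_out, 3 * sizeof(float), cudaMemcpyDeviceToHost);
    printf("%.0f %.0f %.0f\n", h_out[0], h_out[1], h_out[2]);  // expect: 96 192 288
    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}

Compiled with nvcc, the sketch should print 96 192 288, i.e. 96 threads each contributing (1, 2, 3); the dynamic shared-memory size mirrors the num_threads*3*sizeof(scalar_t) passed by the launch code above.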
phivenv/Lib/site-packages/transformers/kernels/deta/cuda/ms_deform_attn_cuda.h ADDED
@@ -0,0 +1,29 @@
+ /*!
+ **************************************************************************************************
+ * Deformable DETR
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+ **************************************************************************************************
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+ **************************************************************************************************
+ */
+
+ #pragma once
+ #include <torch/extension.h>
+
+ at::Tensor ms_deform_attn_cuda_forward(
+     const at::Tensor &value,
+     const at::Tensor &spatial_shapes,
+     const at::Tensor &level_start_index,
+     const at::Tensor &sampling_loc,
+     const at::Tensor &attn_weight,
+     const int im2col_step);
+
+ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
+     const at::Tensor &value,
+     const at::Tensor &spatial_shapes,
+     const at::Tensor &level_start_index,
+     const at::Tensor &sampling_loc,
+     const at::Tensor &attn_weight,
+     const at::Tensor &grad_output,
+     const int im2col_step);
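
Note: these two declarations are the entire CUDA-side surface of the op. A common way to expose such functions to Python is a small torch-extension binding; the sketch below shows that generic pattern only, under the assumption that the header above is linked against the .cu implementation. The module and exported names are invented here, and the package's real glue code lives in ms_deform_attn.h and the accompanying sources, not in this file.

// bind_sketch.cpp -- hypothetical binding sketch, not part of this commit.
#include <torch/extension.h>
#include "ms_deform_attn_cuda.h"  // the header added above

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    // Forward: returns the attention output for the given sampling locations.
    m.def("ms_deform_attn_forward", &ms_deform_attn_cuda_forward,
          "multi-scale deformable attention forward (CUDA)");
    // Backward: returns {grad_value, grad_sampling_loc, grad_attn_weight}.
    m.def("ms_deform_attn_backward", &ms_deform_attn_cuda_backward,
          "multi-scale deformable attention backward (CUDA)");
}

Built with torch.utils.cpp_extension.load (sources listing this file plus the .cu kernels), such a module would then back a custom torch.autograd.Function on the Python side.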
phivenv/Lib/site-packages/transformers/kernels/deta/cuda/ms_deform_im2col_cuda.cuh ADDED
@@ -0,0 +1,1327 @@
1
+ /*!
2
+ **************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************
7
+ * Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
8
+ * Copyright (c) 2018 Microsoft
9
+ **************************************************************************
10
+ */
11
+
12
+ #include <cstdio>
13
+ #include <algorithm>
14
+ #include <cstring>
15
+
16
+ #include <ATen/ATen.h>
17
+ #include <ATen/cuda/CUDAContext.h>
18
+
19
+ #include <THC/THCAtomics.cuh>
20
+
21
+ #define CUDA_KERNEL_LOOP(i, n) \
22
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
23
+ i < (n); \
24
+ i += blockDim.x * gridDim.x)
25
+
26
+ const int CUDA_NUM_THREADS = 1024;
27
+ inline int GET_BLOCKS(const int N, const int num_threads)
28
+ {
29
+ return (N + num_threads - 1) / num_threads;
30
+ }
31
+
32
+
33
+ template <typename scalar_t>
34
+ __device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
35
+ const int &height, const int &width, const int &nheads, const int &channels,
36
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c)
37
+ {
38
+ const int h_low = floor(h);
39
+ const int w_low = floor(w);
40
+ const int h_high = h_low + 1;
41
+ const int w_high = w_low + 1;
42
+
43
+ const scalar_t lh = h - h_low;
44
+ const scalar_t lw = w - w_low;
45
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
46
+
47
+ const int w_stride = nheads * channels;
48
+ const int h_stride = width * w_stride;
49
+ const int h_low_ptr_offset = h_low * h_stride;
50
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
51
+ const int w_low_ptr_offset = w_low * w_stride;
52
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
53
+ const int base_ptr = m * channels + c;
54
+
55
+ scalar_t v1 = 0;
56
+ if (h_low >= 0 && w_low >= 0)
57
+ {
58
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
59
+ v1 = bottom_data[ptr1];
60
+ }
61
+ scalar_t v2 = 0;
62
+ if (h_low >= 0 && w_high <= width - 1)
63
+ {
64
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
65
+ v2 = bottom_data[ptr2];
66
+ }
67
+ scalar_t v3 = 0;
68
+ if (h_high <= height - 1 && w_low >= 0)
69
+ {
70
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
71
+ v3 = bottom_data[ptr3];
72
+ }
73
+ scalar_t v4 = 0;
74
+ if (h_high <= height - 1 && w_high <= width - 1)
75
+ {
76
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
77
+ v4 = bottom_data[ptr4];
78
+ }
79
+
80
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
81
+
82
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
83
+ return val;
84
+ }
85
+
86
+
87
+ template <typename scalar_t>
88
+ __device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
89
+ const int &height, const int &width, const int &nheads, const int &channels,
90
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
91
+ const scalar_t &top_grad,
92
+ const scalar_t &attn_weight,
93
+ scalar_t* &grad_value,
94
+ scalar_t* grad_sampling_loc,
95
+ scalar_t* grad_attn_weight)
96
+ {
97
+ const int h_low = floor(h);
98
+ const int w_low = floor(w);
99
+ const int h_high = h_low + 1;
100
+ const int w_high = w_low + 1;
101
+
102
+ const scalar_t lh = h - h_low;
103
+ const scalar_t lw = w - w_low;
104
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
105
+
106
+ const int w_stride = nheads * channels;
107
+ const int h_stride = width * w_stride;
108
+ const int h_low_ptr_offset = h_low * h_stride;
109
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
110
+ const int w_low_ptr_offset = w_low * w_stride;
111
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
112
+ const int base_ptr = m * channels + c;
113
+
114
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
115
+ const scalar_t top_grad_value = top_grad * attn_weight;
116
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
117
+
118
+ scalar_t v1 = 0;
119
+ if (h_low >= 0 && w_low >= 0)
120
+ {
121
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
122
+ v1 = bottom_data[ptr1];
123
+ grad_h_weight -= hw * v1;
124
+ grad_w_weight -= hh * v1;
125
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
126
+ }
127
+ scalar_t v2 = 0;
128
+ if (h_low >= 0 && w_high <= width - 1)
129
+ {
130
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
131
+ v2 = bottom_data[ptr2];
132
+ grad_h_weight -= lw * v2;
133
+ grad_w_weight += hh * v2;
134
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
135
+ }
136
+ scalar_t v3 = 0;
137
+ if (h_high <= height - 1 && w_low >= 0)
138
+ {
139
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
140
+ v3 = bottom_data[ptr3];
141
+ grad_h_weight += hw * v3;
142
+ grad_w_weight -= lh * v3;
143
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
144
+ }
145
+ scalar_t v4 = 0;
146
+ if (h_high <= height - 1 && w_high <= width - 1)
147
+ {
148
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
149
+ v4 = bottom_data[ptr4];
150
+ grad_h_weight += lw * v4;
151
+ grad_w_weight += lh * v4;
152
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
153
+ }
154
+
155
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
156
+ *grad_attn_weight = top_grad * val;
157
+ *grad_sampling_loc = width * grad_w_weight * top_grad_value;
158
+ *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
159
+ }
160
+
161
+
162
+ template <typename scalar_t>
163
+ __device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
164
+ const int &height, const int &width, const int &nheads, const int &channels,
165
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
166
+ const scalar_t &top_grad,
167
+ const scalar_t &attn_weight,
168
+ scalar_t* &grad_value,
169
+ scalar_t* grad_sampling_loc,
170
+ scalar_t* grad_attn_weight)
171
+ {
172
+ const int h_low = floor(h);
173
+ const int w_low = floor(w);
174
+ const int h_high = h_low + 1;
175
+ const int w_high = w_low + 1;
176
+
177
+ const scalar_t lh = h - h_low;
178
+ const scalar_t lw = w - w_low;
179
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
180
+
181
+ const int w_stride = nheads * channels;
182
+ const int h_stride = width * w_stride;
183
+ const int h_low_ptr_offset = h_low * h_stride;
184
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
185
+ const int w_low_ptr_offset = w_low * w_stride;
186
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
187
+ const int base_ptr = m * channels + c;
188
+
189
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
190
+ const scalar_t top_grad_value = top_grad * attn_weight;
191
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
192
+
193
+ scalar_t v1 = 0;
194
+ if (h_low >= 0 && w_low >= 0)
195
+ {
196
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
197
+ v1 = bottom_data[ptr1];
198
+ grad_h_weight -= hw * v1;
199
+ grad_w_weight -= hh * v1;
200
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
201
+ }
202
+ scalar_t v2 = 0;
203
+ if (h_low >= 0 && w_high <= width - 1)
204
+ {
205
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
206
+ v2 = bottom_data[ptr2];
207
+ grad_h_weight -= lw * v2;
208
+ grad_w_weight += hh * v2;
209
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
210
+ }
211
+ scalar_t v3 = 0;
212
+ if (h_high <= height - 1 && w_low >= 0)
213
+ {
214
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
215
+ v3 = bottom_data[ptr3];
216
+ grad_h_weight += hw * v3;
217
+ grad_w_weight -= lh * v3;
218
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
219
+ }
220
+ scalar_t v4 = 0;
221
+ if (h_high <= height - 1 && w_high <= width - 1)
222
+ {
223
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
224
+ v4 = bottom_data[ptr4];
225
+ grad_h_weight += lw * v4;
226
+ grad_w_weight += lh * v4;
227
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
228
+ }
229
+
230
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
231
+ atomicAdd(grad_attn_weight, top_grad * val);
232
+ atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
233
+ atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
234
+ }
235
+
236
+
237
+ template <typename scalar_t>
238
+ __global__ void ms_deformable_im2col_gpu_kernel(const int n,
239
+ const scalar_t *data_value,
240
+ const int64_t *data_spatial_shapes,
241
+ const int64_t *data_level_start_index,
242
+ const scalar_t *data_sampling_loc,
243
+ const scalar_t *data_attn_weight,
244
+ const int batch_size,
245
+ const int spatial_size,
246
+ const int num_heads,
247
+ const int channels,
248
+ const int num_levels,
249
+ const int num_query,
250
+ const int num_point,
251
+ scalar_t *data_col)
252
+ {
253
+ CUDA_KERNEL_LOOP(index, n)
254
+ {
255
+ int _temp = index;
256
+ const int c_col = _temp % channels;
257
+ _temp /= channels;
258
+ const int sampling_index = _temp;
259
+ const int m_col = _temp % num_heads;
260
+ _temp /= num_heads;
261
+ const int q_col = _temp % num_query;
262
+ _temp /= num_query;
263
+ const int b_col = _temp;
264
+
265
+ scalar_t *data_col_ptr = data_col + index;
266
+ int data_weight_ptr = sampling_index * num_levels * num_point;
267
+ int data_loc_w_ptr = data_weight_ptr << 1;
268
+ const int qid_stride = num_heads * channels;
269
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
270
+ scalar_t col = 0;
271
+
272
+ for (int l_col=0; l_col < num_levels; ++l_col)
273
+ {
274
+ const int level_start_id = data_level_start_index[l_col];
275
+ const int spatial_h_ptr = l_col << 1;
276
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
277
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
278
+ const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
279
+ for (int p_col=0; p_col < num_point; ++p_col)
280
+ {
281
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
282
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
283
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
284
+
285
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
286
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
287
+
288
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
289
+ {
290
+ col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
291
+ }
292
+
293
+ data_weight_ptr += 1;
294
+ data_loc_w_ptr += 2;
295
+ }
296
+ }
297
+ *data_col_ptr = col;
298
+ }
299
+ }
300
+
301
+ template <typename scalar_t, unsigned int blockSize>
302
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
303
+ const scalar_t *grad_col,
304
+ const scalar_t *data_value,
305
+ const int64_t *data_spatial_shapes,
306
+ const int64_t *data_level_start_index,
307
+ const scalar_t *data_sampling_loc,
308
+ const scalar_t *data_attn_weight,
309
+ const int batch_size,
310
+ const int spatial_size,
311
+ const int num_heads,
312
+ const int channels,
313
+ const int num_levels,
314
+ const int num_query,
315
+ const int num_point,
316
+ scalar_t *grad_value,
317
+ scalar_t *grad_sampling_loc,
318
+ scalar_t *grad_attn_weight)
319
+ {
320
+ CUDA_KERNEL_LOOP(index, n)
321
+ {
322
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
323
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
324
+ unsigned int tid = threadIdx.x;
325
+ int _temp = index;
326
+ const int c_col = _temp % channels;
327
+ _temp /= channels;
328
+ const int sampling_index = _temp;
329
+ const int m_col = _temp % num_heads;
330
+ _temp /= num_heads;
331
+ const int q_col = _temp % num_query;
332
+ _temp /= num_query;
333
+ const int b_col = _temp;
334
+
335
+ const scalar_t top_grad = grad_col[index];
336
+
337
+ int data_weight_ptr = sampling_index * num_levels * num_point;
338
+ int data_loc_w_ptr = data_weight_ptr << 1;
339
+ const int grad_sampling_ptr = data_weight_ptr;
340
+ grad_sampling_loc += grad_sampling_ptr << 1;
341
+ grad_attn_weight += grad_sampling_ptr;
342
+ const int grad_weight_stride = 1;
343
+ const int grad_loc_stride = 2;
344
+ const int qid_stride = num_heads * channels;
345
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
346
+
347
+ for (int l_col=0; l_col < num_levels; ++l_col)
348
+ {
349
+ const int level_start_id = data_level_start_index[l_col];
350
+ const int spatial_h_ptr = l_col << 1;
351
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
352
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
353
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
354
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
355
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
356
+
357
+ for (int p_col=0; p_col < num_point; ++p_col)
358
+ {
359
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
360
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
361
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
362
+
363
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
364
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
365
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
366
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
367
+ *(cache_grad_attn_weight+threadIdx.x)=0;
368
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
369
+ {
370
+ ms_deform_attn_col2im_bilinear(
371
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
372
+ top_grad, weight, grad_value_ptr,
373
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
374
+ }
375
+
376
+ __syncthreads();
377
+ if (tid == 0)
378
+ {
379
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
380
+ int sid=2;
381
+ for (unsigned int tid = 1; tid < blockSize; ++tid)
382
+ {
383
+ _grad_w += cache_grad_sampling_loc[sid];
384
+ _grad_h += cache_grad_sampling_loc[sid + 1];
385
+ _grad_a += cache_grad_attn_weight[tid];
386
+ sid += 2;
387
+ }
388
+
389
+
390
+ *grad_sampling_loc = _grad_w;
391
+ *(grad_sampling_loc + 1) = _grad_h;
392
+ *grad_attn_weight = _grad_a;
393
+ }
394
+ __syncthreads();
395
+
396
+ data_weight_ptr += 1;
397
+ data_loc_w_ptr += 2;
398
+ grad_attn_weight += grad_weight_stride;
399
+ grad_sampling_loc += grad_loc_stride;
400
+ }
401
+ }
402
+ }
403
+ }
404
+
405
+
406
+ template <typename scalar_t, unsigned int blockSize>
407
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
408
+ const scalar_t *grad_col,
409
+ const scalar_t *data_value,
410
+ const int64_t *data_spatial_shapes,
411
+ const int64_t *data_level_start_index,
412
+ const scalar_t *data_sampling_loc,
413
+ const scalar_t *data_attn_weight,
414
+ const int batch_size,
415
+ const int spatial_size,
416
+ const int num_heads,
417
+ const int channels,
418
+ const int num_levels,
419
+ const int num_query,
420
+ const int num_point,
421
+ scalar_t *grad_value,
422
+ scalar_t *grad_sampling_loc,
423
+ scalar_t *grad_attn_weight)
424
+ {
425
+ CUDA_KERNEL_LOOP(index, n)
426
+ {
427
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
428
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
429
+ unsigned int tid = threadIdx.x;
430
+ int _temp = index;
431
+ const int c_col = _temp % channels;
432
+ _temp /= channels;
433
+ const int sampling_index = _temp;
434
+ const int m_col = _temp % num_heads;
435
+ _temp /= num_heads;
436
+ const int q_col = _temp % num_query;
437
+ _temp /= num_query;
438
+ const int b_col = _temp;
439
+
440
+ const scalar_t top_grad = grad_col[index];
441
+
442
+ int data_weight_ptr = sampling_index * num_levels * num_point;
443
+ int data_loc_w_ptr = data_weight_ptr << 1;
444
+ const int grad_sampling_ptr = data_weight_ptr;
445
+ grad_sampling_loc += grad_sampling_ptr << 1;
446
+ grad_attn_weight += grad_sampling_ptr;
447
+ const int grad_weight_stride = 1;
448
+ const int grad_loc_stride = 2;
449
+ const int qid_stride = num_heads * channels;
450
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
451
+
452
+ for (int l_col=0; l_col < num_levels; ++l_col)
453
+ {
454
+ const int level_start_id = data_level_start_index[l_col];
455
+ const int spatial_h_ptr = l_col << 1;
456
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
457
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
458
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
459
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
460
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
461
+
462
+ for (int p_col=0; p_col < num_point; ++p_col)
463
+ {
464
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
465
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
466
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
467
+
468
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
469
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
470
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
471
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
472
+ *(cache_grad_attn_weight+threadIdx.x)=0;
473
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
474
+ {
475
+ ms_deform_attn_col2im_bilinear(
476
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
477
+ top_grad, weight, grad_value_ptr,
478
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
479
+ }
480
+
481
+ __syncthreads();
482
+
483
+ for (unsigned int s=blockSize/2; s>0; s>>=1)
484
+ {
485
+ if (tid < s) {
486
+ const unsigned int xid1 = tid << 1;
487
+ const unsigned int xid2 = (tid + s) << 1;
488
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
489
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
490
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
491
+ }
492
+ __syncthreads();
493
+ }
494
+
495
+ if (tid == 0)
496
+ {
497
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
498
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
499
+ *grad_attn_weight = cache_grad_attn_weight[0];
500
+ }
501
+ __syncthreads();
502
+
503
+ data_weight_ptr += 1;
504
+ data_loc_w_ptr += 2;
505
+ grad_attn_weight += grad_weight_stride;
506
+ grad_sampling_loc += grad_loc_stride;
507
+ }
508
+ }
509
+ }
510
+ }
511
+
512
+
513
+ template <typename scalar_t>
514
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
515
+ const scalar_t *grad_col,
516
+ const scalar_t *data_value,
517
+ const int64_t *data_spatial_shapes,
518
+ const int64_t *data_level_start_index,
519
+ const scalar_t *data_sampling_loc,
520
+ const scalar_t *data_attn_weight,
521
+ const int batch_size,
522
+ const int spatial_size,
523
+ const int num_heads,
524
+ const int channels,
525
+ const int num_levels,
526
+ const int num_query,
527
+ const int num_point,
528
+ scalar_t *grad_value,
529
+ scalar_t *grad_sampling_loc,
530
+ scalar_t *grad_attn_weight)
531
+ {
532
+ CUDA_KERNEL_LOOP(index, n)
533
+ {
534
+ extern __shared__ int _s[];
535
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
536
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
537
+ unsigned int tid = threadIdx.x;
538
+ int _temp = index;
539
+ const int c_col = _temp % channels;
540
+ _temp /= channels;
541
+ const int sampling_index = _temp;
542
+ const int m_col = _temp % num_heads;
543
+ _temp /= num_heads;
544
+ const int q_col = _temp % num_query;
545
+ _temp /= num_query;
546
+ const int b_col = _temp;
547
+
548
+ const scalar_t top_grad = grad_col[index];
549
+
550
+ int data_weight_ptr = sampling_index * num_levels * num_point;
551
+ int data_loc_w_ptr = data_weight_ptr << 1;
552
+ const int grad_sampling_ptr = data_weight_ptr;
553
+ grad_sampling_loc += grad_sampling_ptr << 1;
554
+ grad_attn_weight += grad_sampling_ptr;
555
+ const int grad_weight_stride = 1;
556
+ const int grad_loc_stride = 2;
557
+ const int qid_stride = num_heads * channels;
558
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
559
+
560
+ for (int l_col=0; l_col < num_levels; ++l_col)
561
+ {
562
+ const int level_start_id = data_level_start_index[l_col];
563
+ const int spatial_h_ptr = l_col << 1;
564
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
565
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
566
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
567
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
568
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
569
+
570
+ for (int p_col=0; p_col < num_point; ++p_col)
571
+ {
572
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
573
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
574
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
575
+
576
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
577
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
578
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
579
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
580
+ *(cache_grad_attn_weight+threadIdx.x)=0;
581
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
582
+ {
583
+ ms_deform_attn_col2im_bilinear(
584
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
585
+ top_grad, weight, grad_value_ptr,
586
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
587
+ }
588
+
589
+ __syncthreads();
590
+ if (tid == 0)
591
+ {
592
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
593
+ int sid=2;
594
+ for (unsigned int tid = 1; tid < blockDim.x; ++tid)
595
+ {
596
+ _grad_w += cache_grad_sampling_loc[sid];
597
+ _grad_h += cache_grad_sampling_loc[sid + 1];
598
+ _grad_a += cache_grad_attn_weight[tid];
599
+ sid += 2;
600
+ }
601
+
602
+
603
+ *grad_sampling_loc = _grad_w;
604
+ *(grad_sampling_loc + 1) = _grad_h;
605
+ *grad_attn_weight = _grad_a;
606
+ }
607
+ __syncthreads();
608
+
609
+ data_weight_ptr += 1;
610
+ data_loc_w_ptr += 2;
611
+ grad_attn_weight += grad_weight_stride;
612
+ grad_sampling_loc += grad_loc_stride;
613
+ }
614
+ }
615
+ }
616
+ }
617
+
618
+ template <typename scalar_t>
619
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
620
+ const scalar_t *grad_col,
621
+ const scalar_t *data_value,
622
+ const int64_t *data_spatial_shapes,
623
+ const int64_t *data_level_start_index,
624
+ const scalar_t *data_sampling_loc,
625
+ const scalar_t *data_attn_weight,
626
+ const int batch_size,
627
+ const int spatial_size,
628
+ const int num_heads,
629
+ const int channels,
630
+ const int num_levels,
631
+ const int num_query,
632
+ const int num_point,
633
+ scalar_t *grad_value,
634
+ scalar_t *grad_sampling_loc,
635
+ scalar_t *grad_attn_weight)
636
+ {
637
+ CUDA_KERNEL_LOOP(index, n)
638
+ {
639
+ extern __shared__ int _s[];
640
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
641
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
642
+ unsigned int tid = threadIdx.x;
643
+ int _temp = index;
644
+ const int c_col = _temp % channels;
645
+ _temp /= channels;
646
+ const int sampling_index = _temp;
647
+ const int m_col = _temp % num_heads;
648
+ _temp /= num_heads;
649
+ const int q_col = _temp % num_query;
650
+ _temp /= num_query;
651
+ const int b_col = _temp;
652
+
653
+ const scalar_t top_grad = grad_col[index];
654
+
655
+ int data_weight_ptr = sampling_index * num_levels * num_point;
656
+ int data_loc_w_ptr = data_weight_ptr << 1;
657
+ const int grad_sampling_ptr = data_weight_ptr;
658
+ grad_sampling_loc += grad_sampling_ptr << 1;
659
+ grad_attn_weight += grad_sampling_ptr;
660
+ const int grad_weight_stride = 1;
661
+ const int grad_loc_stride = 2;
662
+ const int qid_stride = num_heads * channels;
663
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
664
+
665
+ for (int l_col=0; l_col < num_levels; ++l_col)
666
+ {
667
+ const int level_start_id = data_level_start_index[l_col];
668
+ const int spatial_h_ptr = l_col << 1;
669
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
670
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
671
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
672
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
673
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
674
+
675
+ for (int p_col=0; p_col < num_point; ++p_col)
676
+ {
677
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
678
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
679
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
680
+
681
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
682
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
683
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
684
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
685
+ *(cache_grad_attn_weight+threadIdx.x)=0;
686
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
687
+ {
688
+ ms_deform_attn_col2im_bilinear(
689
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
690
+ top_grad, weight, grad_value_ptr,
691
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
692
+ }
693
+
694
+ __syncthreads();
695
+
696
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
697
+ {
698
+ if (tid < s) {
699
+ const unsigned int xid1 = tid << 1;
700
+ const unsigned int xid2 = (tid + s) << 1;
701
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
702
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
703
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
704
+ if (tid + (s << 1) < spre)
705
+ {
706
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
707
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
708
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
709
+ }
710
+ }
711
+ __syncthreads();
712
+ }
713
+
714
+ if (tid == 0)
715
+ {
716
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
717
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
718
+ *grad_attn_weight = cache_grad_attn_weight[0];
719
+ }
720
+ __syncthreads();
721
+
722
+ data_weight_ptr += 1;
723
+ data_loc_w_ptr += 2;
724
+ grad_attn_weight += grad_weight_stride;
725
+ grad_sampling_loc += grad_loc_stride;
726
+ }
727
+ }
728
+ }
729
+ }
730
+
731
+ template <typename scalar_t>
732
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
733
+ const scalar_t *grad_col,
734
+ const scalar_t *data_value,
735
+ const int64_t *data_spatial_shapes,
736
+ const int64_t *data_level_start_index,
737
+ const scalar_t *data_sampling_loc,
738
+ const scalar_t *data_attn_weight,
739
+ const int batch_size,
740
+ const int spatial_size,
741
+ const int num_heads,
742
+ const int channels,
743
+ const int num_levels,
744
+ const int num_query,
745
+ const int num_point,
746
+ scalar_t *grad_value,
747
+ scalar_t *grad_sampling_loc,
748
+ scalar_t *grad_attn_weight)
749
+ {
750
+ CUDA_KERNEL_LOOP(index, n)
751
+ {
752
+ extern __shared__ int _s[];
753
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
754
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
755
+ unsigned int tid = threadIdx.x;
756
+ int _temp = index;
757
+ const int c_col = _temp % channels;
758
+ _temp /= channels;
759
+ const int sampling_index = _temp;
760
+ const int m_col = _temp % num_heads;
761
+ _temp /= num_heads;
762
+ const int q_col = _temp % num_query;
763
+ _temp /= num_query;
764
+ const int b_col = _temp;
765
+
766
+ const scalar_t top_grad = grad_col[index];
767
+
768
+ int data_weight_ptr = sampling_index * num_levels * num_point;
769
+ int data_loc_w_ptr = data_weight_ptr << 1;
770
+ const int grad_sampling_ptr = data_weight_ptr;
771
+ grad_sampling_loc += grad_sampling_ptr << 1;
772
+ grad_attn_weight += grad_sampling_ptr;
773
+ const int grad_weight_stride = 1;
774
+ const int grad_loc_stride = 2;
775
+ const int qid_stride = num_heads * channels;
776
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
777
+
778
+ for (int l_col=0; l_col < num_levels; ++l_col)
779
+ {
780
+ const int level_start_id = data_level_start_index[l_col];
781
+ const int spatial_h_ptr = l_col << 1;
782
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
783
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
784
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
785
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
786
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
787
+
788
+ for (int p_col=0; p_col < num_point; ++p_col)
789
+ {
790
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
791
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
792
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
793
+
794
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
795
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
796
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
797
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
798
+ *(cache_grad_attn_weight+threadIdx.x)=0;
799
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
800
+ {
801
+ ms_deform_attn_col2im_bilinear(
802
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
803
+ top_grad, weight, grad_value_ptr,
804
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
805
+ }
806
+
807
+ __syncthreads();
808
+
809
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
810
+ {
811
+ if (tid < s) {
812
+ const unsigned int xid1 = tid << 1;
813
+             const unsigned int xid2 = (tid + s) << 1;
+             cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+             cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+             cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+             if (tid + (s << 1) < spre)
+             {
+               cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
+               cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
+               cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+             }
+           }
+           __syncthreads();
+         }
+
+         if (tid == 0)
+         {
+           atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
+           atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
+           atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
+         }
+         __syncthreads();
+
+         data_weight_ptr += 1;
+         data_loc_w_ptr += 2;
+         grad_attn_weight += grad_weight_stride;
+         grad_sampling_loc += grad_loc_stride;
+       }
+     }
+   }
+ }
+
+
+ template <typename scalar_t>
+ __global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
+                                                    const scalar_t *grad_col,
+                                                    const scalar_t *data_value,
+                                                    const int64_t *data_spatial_shapes,
+                                                    const int64_t *data_level_start_index,
+                                                    const scalar_t *data_sampling_loc,
+                                                    const scalar_t *data_attn_weight,
+                                                    const int batch_size,
+                                                    const int spatial_size,
+                                                    const int num_heads,
+                                                    const int channels,
+                                                    const int num_levels,
+                                                    const int num_query,
+                                                    const int num_point,
+                                                    scalar_t *grad_value,
+                                                    scalar_t *grad_sampling_loc,
+                                                    scalar_t *grad_attn_weight)
+ {
+   CUDA_KERNEL_LOOP(index, n)
+   {
+     int _temp = index;
+     const int c_col = _temp % channels;
+     _temp /= channels;
+     const int sampling_index = _temp;
+     const int m_col = _temp % num_heads;
+     _temp /= num_heads;
+     const int q_col = _temp % num_query;
+     _temp /= num_query;
+     const int b_col = _temp;
+
+     const scalar_t top_grad = grad_col[index];
+
+     int data_weight_ptr = sampling_index * num_levels * num_point;
+     int data_loc_w_ptr = data_weight_ptr << 1;
+     const int grad_sampling_ptr = data_weight_ptr;
+     grad_sampling_loc += grad_sampling_ptr << 1;
+     grad_attn_weight += grad_sampling_ptr;
+     const int grad_weight_stride = 1;
+     const int grad_loc_stride = 2;
+     const int qid_stride = num_heads * channels;
+     const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+     for (int l_col=0; l_col < num_levels; ++l_col)
+     {
+       const int level_start_id = data_level_start_index[l_col];
+       const int spatial_h_ptr = l_col << 1;
+       const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+       const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+       const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+       const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+       scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+       for (int p_col=0; p_col < num_point; ++p_col)
+       {
+         const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+         const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+         const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+         const scalar_t h_im = loc_h * spatial_h - 0.5;
+         const scalar_t w_im = loc_w * spatial_w - 0.5;
+         if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+         {
+           ms_deform_attn_col2im_bilinear_gm(
+             data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+             top_grad, weight, grad_value_ptr,
+             grad_sampling_loc, grad_attn_weight);
+         }
+         data_weight_ptr += 1;
+         data_loc_w_ptr += 2;
+         grad_attn_weight += grad_weight_stride;
+         grad_sampling_loc += grad_loc_stride;
+       }
+     }
+   }
+ }
+
+
+ template <typename scalar_t>
+ void ms_deformable_im2col_cuda(cudaStream_t stream,
+                                const scalar_t* data_value,
+                                const int64_t* data_spatial_shapes,
+                                const int64_t* data_level_start_index,
+                                const scalar_t* data_sampling_loc,
+                                const scalar_t* data_attn_weight,
+                                const int batch_size,
+                                const int spatial_size,
+                                const int num_heads,
+                                const int channels,
+                                const int num_levels,
+                                const int num_query,
+                                const int num_point,
+                                scalar_t* data_col)
+ {
+   const int num_kernels = batch_size * num_query * num_heads * channels;
+   const int num_actual_kernels = batch_size * num_query * num_heads * channels;
+   const int num_threads = CUDA_NUM_THREADS;
+   ms_deformable_im2col_gpu_kernel<scalar_t>
+       <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+          0, stream>>>(
+       num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
+       batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
+
+   cudaError_t err = cudaGetLastError();
+   if (err != cudaSuccess)
+   {
+     printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
+   }
+ }
+
+ template <typename scalar_t>
+ void ms_deformable_col2im_cuda(cudaStream_t stream,
+                                const scalar_t* grad_col,
+                                const scalar_t* data_value,
+                                const int64_t* data_spatial_shapes,
+                                const int64_t* data_level_start_index,
+                                const scalar_t* data_sampling_loc,
+                                const scalar_t* data_attn_weight,
+                                const int batch_size,
+                                const int spatial_size,
+                                const int num_heads,
+                                const int channels,
+                                const int num_levels,
+                                const int num_query,
+                                const int num_point,
+                                scalar_t* grad_value,
+                                scalar_t* grad_sampling_loc,
+                                scalar_t* grad_attn_weight)
+ {
+   const int num_threads = (channels > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : channels;
+   const int num_kernels = batch_size * num_query * num_heads * channels;
+   const int num_actual_kernels = batch_size * num_query * num_heads * channels;
+   if (channels > 1024)
+   {
+     if ((channels & 1023) == 0)
+     {
+       ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
+           <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+              num_threads * 3 * sizeof(scalar_t), stream>>>(
+           num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+           data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
+           channels, num_levels, num_query, num_point,
+           grad_value, grad_sampling_loc, grad_attn_weight);
+     }
+     else
+     {
+       ms_deformable_col2im_gpu_kernel_gm<scalar_t>
+           <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+              0, stream>>>(
+           num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+           data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
+           channels, num_levels, num_query, num_point,
+           grad_value, grad_sampling_loc, grad_attn_weight);
+     }
+   }
+   else
+   {
+     switch(channels)
+     {
+       case 1:
+         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
+             <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+                0, stream>>>(
+             num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+             data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
+             channels, num_levels, num_query, num_point,
+             grad_value, grad_sampling_loc, grad_attn_weight);
+         break;
+       case 2:
+         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
+             <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+                0, stream>>>(
+             num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+             data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
+             channels, num_levels, num_query, num_point,
+             grad_value, grad_sampling_loc, grad_attn_weight);
+         break;
+       case 4:
+         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
+             <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+                0, stream>>>(
+             num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+             data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
+             channels, num_levels, num_query, num_point,
+             grad_value, grad_sampling_loc, grad_attn_weight);
+         break;
+       case 8:
+         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
+             <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+                0, stream>>>(
+             num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+             data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
+             channels, num_levels, num_query, num_point,
+             grad_value, grad_sampling_loc, grad_attn_weight);
+         break;
+       case 16:
+         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
+             <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+                0, stream>>>(
+             num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+             data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
+             channels, num_levels, num_query, num_point,
+             grad_value, grad_sampling_loc, grad_attn_weight);
+         break;
+       case 32:
+         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
+             <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+                0, stream>>>(
+             num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+             data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
+             channels, num_levels, num_query, num_point,
+             grad_value, grad_sampling_loc, grad_attn_weight);
+         break;
+       case 64:
+         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
+             <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+                0, stream>>>(
+             num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+             data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
+             channels, num_levels, num_query, num_point,
+             grad_value, grad_sampling_loc, grad_attn_weight);
+         break;
+       case 128:
+         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
+             <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+                0, stream>>>(
+             num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+             data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
+             channels, num_levels, num_query, num_point,
+             grad_value, grad_sampling_loc, grad_attn_weight);
+         break;
+       case 256:
+         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
+             <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+                0, stream>>>(
+             num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+             data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
+             channels, num_levels, num_query, num_point,
+             grad_value, grad_sampling_loc, grad_attn_weight);
+         break;
+       case 512:
+         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
+             <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+                0, stream>>>(
+             num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+             data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
+             channels, num_levels, num_query, num_point,
+             grad_value, grad_sampling_loc, grad_attn_weight);
+         break;
+       case 1024:
+         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
+             <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+                0, stream>>>(
+             num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+             data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
+             channels, num_levels, num_query, num_point,
+             grad_value, grad_sampling_loc, grad_attn_weight);
+         break;
+       default:
+         if (channels < 64)
+         {
+           ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
+               <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+                  num_threads * 3 * sizeof(scalar_t), stream>>>(
+               num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+               data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
+               channels, num_levels, num_query, num_point,
+               grad_value, grad_sampling_loc, grad_attn_weight);
+         }
+         else
+         {
+           ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
+               <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+                  num_threads * 3 * sizeof(scalar_t), stream>>>(
+               num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+               data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
+               channels, num_levels, num_query, num_point,
+               grad_value, grad_sampling_loc, grad_attn_weight);
+         }
+     }
+   }
+   cudaError_t err = cudaGetLastError();
+   if (err != cudaSuccess)
+   {
+     printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
+   }
+ }
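
Note: the backward launcher above selects a kernel variant purely from the channel count. A hedged Python sketch of the same selection policy (the helper name `pick_col2im_backward_kernel` is hypothetical, for illustration only):

def pick_col2im_backward_kernel(channels: int) -> str:
    # Mirrors the branch structure of ms_deformable_col2im_cuda above.
    if channels > 1024:
        # A full channel block no longer fits the static shared-memory kernels;
        # use the multi-block shared-memory reduction when channels is a
        # multiple of 1024, otherwise fall back to global-memory atomics.
        return "shm_reduce_v2_multi_blocks" if channels % 1024 == 0 else "gm"
    if channels in (1, 2, 4, 8, 16, 32):
        return f"shm_blocksize_aware_reduce_v1<{channels}>"
    if channels in (64, 128, 256, 512, 1024):
        return f"shm_blocksize_aware_reduce_v2<{channels}>"
    # Non-power-of-two channel counts use the generic dynamic-shared-memory kernels.
    return "shm_reduce_v1" if channels < 64 else "shm_reduce_v2"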
phivenv/Lib/site-packages/transformers/kernels/deta/ms_deform_attn.h ADDED
@@ -0,0 +1,61 @@
+ /*!
+ **************************************************************************************************
+ * Deformable DETR
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+ **************************************************************************************************
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+ **************************************************************************************************
+ */
+
+ #pragma once
+
+ #include "cpu/ms_deform_attn_cpu.h"
+
+ #ifdef WITH_CUDA
+ #include "cuda/ms_deform_attn_cuda.h"
+ #endif
+
+
+ at::Tensor
+ ms_deform_attn_forward(
+     const at::Tensor &value,
+     const at::Tensor &spatial_shapes,
+     const at::Tensor &level_start_index,
+     const at::Tensor &sampling_loc,
+     const at::Tensor &attn_weight,
+     const int im2col_step)
+ {
+     if (value.type().is_cuda())
+     {
+ #ifdef WITH_CUDA
+         return ms_deform_attn_cuda_forward(
+             value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
+ #else
+         AT_ERROR("Not compiled with GPU support");
+ #endif
+     }
+     AT_ERROR("Not implemented on the CPU");
+ }
+
+ std::vector<at::Tensor>
+ ms_deform_attn_backward(
+     const at::Tensor &value,
+     const at::Tensor &spatial_shapes,
+     const at::Tensor &level_start_index,
+     const at::Tensor &sampling_loc,
+     const at::Tensor &attn_weight,
+     const at::Tensor &grad_output,
+     const int im2col_step)
+ {
+     if (value.type().is_cuda())
+     {
+ #ifdef WITH_CUDA
+         return ms_deform_attn_cuda_backward(
+             value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
+ #else
+         AT_ERROR("Not compiled with GPU support");
+ #endif
+     }
+     AT_ERROR("Not implemented on the CPU");
+ }
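
On the Python side, forward/backward pairs like these are usually wrapped in a torch.autograd.Function. A hedged sketch (illustrative only: `MSDA` stands for the compiled extension module and the class name is an assumption, not the actual transformers wrapper):

import torch

class MSDeformAttnFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step):
        ctx.im2col_step = im2col_step
        output = MSDA.ms_deform_attn_forward(
            value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step)
        ctx.save_for_backward(value, spatial_shapes, level_start_index, sampling_loc, attn_weight)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        value, spatial_shapes, level_start_index, sampling_loc, attn_weight = ctx.saved_tensors
        grad_value, grad_sampling_loc, grad_attn_weight = MSDA.ms_deform_attn_backward(
            value, spatial_shapes, level_start_index, sampling_loc, attn_weight,
            grad_output.contiguous(), ctx.im2col_step)
        # Only value, sampling_loc and attn_weight receive gradients.
        return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None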
phivenv/Lib/site-packages/transformers/kernels/deta/vision.cpp ADDED
@@ -0,0 +1,16 @@
+ /*!
+ **************************************************************************************************
+ * Deformable DETR
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+ **************************************************************************************************
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+ **************************************************************************************************
+ */
+
+ #include "ms_deform_attn.h"
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+     m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
+     m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
+ }
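
A hedged sketch of how a module like this can be JIT-built from Python with torch.utils.cpp_extension (the extension name and source paths are assumptions based on the includes above):

from torch.utils.cpp_extension import load

msda = load(
    name="MultiScaleDeformableAttention",
    sources=["vision.cpp", "cpu/ms_deform_attn_cpu.cpp", "cuda/ms_deform_attn_cuda.cu"],
    extra_cflags=["-DWITH_CUDA"],
    extra_cuda_cflags=["-DWITH_CUDA"],
)
# msda.ms_deform_attn_forward(...) and msda.ms_deform_attn_backward(...) are then callable.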
phivenv/Lib/site-packages/transformers/kernels/falcon_mamba/__init__.py ADDED
@@ -0,0 +1,15 @@
+ # coding=utf-8
+ # Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from .selective_scan_with_ln_interface import mamba_inner_fn
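
The package re-exports a single entry point; a typical import (sketch):

from transformers.kernels.falcon_mamba import mamba_inner_fn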
phivenv/Lib/site-packages/transformers/kernels/falcon_mamba/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (248 Bytes).
phivenv/Lib/site-packages/transformers/kernels/falcon_mamba/__pycache__/selective_scan_with_ln_interface.cpython-39.pyc ADDED
Binary file (9.83 kB).
phivenv/Lib/site-packages/transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py ADDED
@@ -0,0 +1,525 @@
+ # coding=utf-8
+ # Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # Original code from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py
+
+ import torch
+ import torch.nn.functional as F
+ from einops import rearrange, repeat
+ from torch.cuda.amp import custom_bwd, custom_fwd
+
+
+ try:
+     import causal_conv1d_cuda
+ except ImportError:
+     causal_conv1d_cuda = None
+
+ import mamba_ssm
+ import selective_scan_cuda
+
+
+ # For BC for old mamba-ssm versions: https://github.com/huggingface/transformers/pull/33195#discussion_r1736401127
+ if hasattr(mamba_ssm.ops.triton, "layernorm"):
+     from mamba_ssm.ops.triton.layernorm import _layer_norm_fwd
+ else:
+     from mamba_ssm.ops.triton.layer_norm import _layer_norm_fwd
+
+
+ class SelectiveScanFn(torch.autograd.Function):
+     @staticmethod
+     def forward(
+         ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False
+     ):
+         if u.stride(-1) != 1:
+             u = u.contiguous()
+         if delta.stride(-1) != 1:
+             delta = delta.contiguous()
+         if D is not None:
+             D = D.contiguous()
+         if B.stride(-1) != 1:
+             B = B.contiguous()
+         if C.stride(-1) != 1:
+             C = C.contiguous()
+         if z is not None and z.stride(-1) != 1:
+             z = z.contiguous()
+         if B.dim() == 3:
+             B = rearrange(B, "b dstate l -> b 1 dstate l")
+             ctx.squeeze_B = True
+         if C.dim() == 3:
+             C = rearrange(C, "b dstate l -> b 1 dstate l")
+             ctx.squeeze_C = True
+         out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
+         ctx.delta_softplus = delta_softplus
+         ctx.has_z = z is not None
+         last_state = x[:, :, -1, 1::2]  # (batch, dim, dstate)
+         if not ctx.has_z:
+             ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
+             return out if not return_last_state else (out, last_state)
+         else:
+             ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
+             out_z = rest[0]
+             return out_z if not return_last_state else (out_z, last_state)
+
+     @staticmethod
+     def backward(ctx, dout, *args):
+         if not ctx.has_z:
+             u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
+             z = None
+             out = None
+         else:
+             u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
+         if dout.stride(-1) != 1:
+             dout = dout.contiguous()
+         # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
+         # backward of selective_scan_cuda with the backward of chunk).
+         # Here we just pass in None and dz will be allocated in the C++ code.
+         du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
+             u,
+             delta,
+             A,
+             B,
+             C,
+             D,
+             z,
+             delta_bias,
+             dout,
+             x,
+             out,
+             None,
+             ctx.delta_softplus,
+             False,  # option to recompute out_z, not used here
+         )
+         dz = rest[0] if ctx.has_z else None
+         dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
+         dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
+         return (
+             du,
+             ddelta,
+             dA,
+             dB,
+             dC,
+             dD if D is not None else None,
+             dz,
+             ddelta_bias if delta_bias is not None else None,
+             None,
+             None,
+         )
+
+
+ def rms_norm_forward(
+     x,
+     weight,
+     bias,
+     eps=1e-6,
+     is_rms_norm=True,
+ ):
+     # x (b l) d
+     if x.stride(-1) != 1:
+         x = x.contiguous()
+     weight = weight.contiguous()
+     if bias is not None:
+         bias = bias.contiguous()
+     y = _layer_norm_fwd(x, weight, bias, eps, None, residual_dtype=None, is_rms_norm=is_rms_norm)[0]
+     # y (b l) d
+     return y
+
+
+ def selective_scan_fn(
+     u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False
+ ):
+     """if return_last_state is True, returns (out, last_state)
+     last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
+     not considered in the backward pass.
+     """
+     return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
+
+
+ def selective_scan_ref(
+     u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False
+ ):
+     """
+     u: r(B D L)
+     delta: r(B D L)
+     A: c(D N) or r(D N)
+     B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
+     C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
+     D: r(D)
+     z: r(B D L)
+     delta_bias: r(D), fp32
+
+     out: r(B D L)
+     last_state (optional): r(B D dstate) or c(B D dstate)
+     """
+     dtype_in = u.dtype
+     u = u.float()
+     delta = delta.float()
+     if delta_bias is not None:
+         delta = delta + delta_bias[..., None].float()
+     if delta_softplus:
+         delta = F.softplus(delta)
+     batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
+     is_variable_B = B.dim() >= 3
+     is_variable_C = C.dim() >= 3
+     if A.is_complex():
+         if is_variable_B:
+             B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
+         if is_variable_C:
+             C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
+     else:
+         B = B.float()
+         C = C.float()
+     x = A.new_zeros((batch, dim, dstate))
+     ys = []
+     deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A))
+     if not is_variable_B:
+         deltaB_u = torch.einsum("bdl,dn,bdl->bdln", delta, B, u)
+     else:
+         if B.dim() == 3:
+             deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u)
+         else:
+             B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
+             deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
+     if is_variable_C and C.dim() == 4:
+         C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
+     last_state = None
+     for i in range(u.shape[2]):
+         x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
+         if not is_variable_C:
+             y = torch.einsum("bdn,dn->bd", x, C)
+         else:
+             if C.dim() == 3:
+                 y = torch.einsum("bdn,bn->bd", x, C[:, :, i])
+             else:
+                 y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i])
+         if i == u.shape[2] - 1:
+             last_state = x
+         if y.is_complex():
+             y = y.real * 2
+         ys.append(y)
+     y = torch.stack(ys, dim=2)  # (batch dim L)
+     out = y if D is None else y + u * rearrange(D, "d -> d 1")
+     if z is not None:
+         out = out * F.silu(z)
+     out = out.to(dtype=dtype_in)
+     return out if not return_last_state else (out, last_state)
+ return out if not return_last_state else (out, last_state)
216
+
217
+
218
+ class MambaInnerFn(torch.autograd.Function):
219
+ @staticmethod
220
+ @custom_fwd
221
+ def forward(
222
+ ctx,
223
+ xz,
224
+ conv1d_weight,
225
+ conv1d_bias,
226
+ x_proj_weight,
227
+ delta_proj_weight,
228
+ out_proj_weight,
229
+ out_proj_bias,
230
+ A,
231
+ B=None,
232
+ C=None,
233
+ D=None,
234
+ delta_bias=None,
235
+ B_proj_bias=None,
236
+ C_proj_bias=None,
237
+ delta_softplus=True,
238
+ checkpoint_lvl=1,
239
+ b_rms_weight=None,
240
+ c_rms_weight=None,
241
+ dt_rms_weight=None,
242
+ b_c_dt_rms_eps=1e-6,
243
+ ):
244
+ """
245
+ xz: (batch, dim, seqlen)
246
+ """
247
+ assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
248
+ assert checkpoint_lvl in [0, 1]
249
+ L = xz.shape[-1]
250
+ delta_rank = delta_proj_weight.shape[1]
251
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
252
+ if torch.is_autocast_enabled():
253
+ x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
254
+ delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
255
+ out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
256
+ out_proj_bias = (
257
+ out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype()) if out_proj_bias is not None else None
258
+ )
259
+ if xz.stride(-1) != 1:
260
+ xz = xz.contiguous()
261
+ conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
262
+ x, z = xz.chunk(2, dim=1)
263
+ conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
264
+ conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, None, None, True)
265
+ # We're being very careful here about the layout, to avoid extra transposes.
266
+ # We want delta to have d as the slowest moving dimension
267
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
268
+ x_dbl = F.linear(rearrange(conv1d_out, "b d l -> (b l) d"), x_proj_weight) # (bl d)
269
+ delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
270
+ ctx.is_variable_B = B is None
271
+ ctx.is_variable_C = C is None
272
+ ctx.B_proj_bias_is_None = B_proj_bias is None
273
+ ctx.C_proj_bias_is_None = C_proj_bias is None
274
+ if B is None: # variable B
275
+ B = x_dbl[:, delta_rank : delta_rank + d_state] # (bl dstate)
276
+ if B_proj_bias is not None:
277
+ B = B + B_proj_bias.to(dtype=B.dtype)
278
+ if not A.is_complex():
279
+ # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
280
+ B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
281
+ else:
282
+ B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
283
+ else:
284
+ if B.stride(-1) != 1:
285
+ B = B.contiguous()
286
+ if C is None: # variable C
287
+ C = x_dbl[:, -d_state:] # (bl dstate)
288
+ if C_proj_bias is not None:
289
+ C = C + C_proj_bias.to(dtype=C.dtype)
290
+ if not A.is_complex():
291
+ # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
292
+ C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
293
+ else:
294
+ C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
295
+ else:
296
+ if C.stride(-1) != 1:
297
+ C = C.contiguous()
298
+ if D is not None:
299
+ D = D.contiguous()
300
+
301
+ if b_rms_weight is not None:
302
+ B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
303
+ B = rms_norm_forward(B, b_rms_weight, bias=None, eps=b_c_dt_rms_eps)
304
+ B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
305
+ if c_rms_weight is not None:
306
+ C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
307
+ C = rms_norm_forward(C, c_rms_weight, bias=None, eps=b_c_dt_rms_eps)
308
+ C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
309
+ if dt_rms_weight is not None:
310
+ delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
311
+ delta = rms_norm_forward(delta, dt_rms_weight, bias=None, eps=b_c_dt_rms_eps)
312
+ delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
313
+
314
+ out, scan_intermediates, out_z = selective_scan_cuda.fwd(
315
+ conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
316
+ )
317
+ ctx.delta_softplus = delta_softplus
318
+ ctx.out_proj_bias_is_None = out_proj_bias is None
319
+ ctx.checkpoint_lvl = checkpoint_lvl
320
+ ctx.b_rms_weight = b_rms_weight
321
+ ctx.c_rms_weight = c_rms_weight
322
+ ctx.dt_rms_weight = dt_rms_weight
323
+ ctx.b_c_dt_rms_eps = b_c_dt_rms_eps
324
+ if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass
325
+ conv1d_out, delta = None, None
326
+ ctx.save_for_backward(
327
+ xz,
328
+ conv1d_weight,
329
+ conv1d_bias,
330
+ x_dbl,
331
+ x_proj_weight,
332
+ delta_proj_weight,
333
+ out_proj_weight,
334
+ conv1d_out,
335
+ delta,
336
+ A,
337
+ B,
338
+ C,
339
+ D,
340
+ delta_bias,
341
+ scan_intermediates,
342
+ b_rms_weight,
343
+ c_rms_weight,
344
+ dt_rms_weight,
345
+ out,
346
+ )
347
+ return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
348
+
349
+ @staticmethod
350
+ @custom_bwd
351
+ def backward(ctx, dout):
352
+ # dout: (batch, seqlen, dim)
353
+ assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
354
+ (
355
+ xz,
356
+ conv1d_weight,
357
+ conv1d_bias,
358
+ x_dbl,
359
+ x_proj_weight,
360
+ delta_proj_weight,
361
+ out_proj_weight,
362
+ conv1d_out,
363
+ delta,
364
+ A,
365
+ B,
366
+ C,
367
+ D,
368
+ delta_bias,
369
+ scan_intermediates,
370
+ b_rms_weight,
371
+ c_rms_weight,
372
+ dt_rms_weight,
373
+ out,
374
+ ) = ctx.saved_tensors
375
+ L = xz.shape[-1]
376
+ delta_rank = delta_proj_weight.shape[1]
377
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
378
+ x, z = xz.chunk(2, dim=1)
379
+ if dout.stride(-1) != 1:
380
+ dout = dout.contiguous()
381
+ if ctx.checkpoint_lvl == 1:
382
+ conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, None, None, True)
383
+ delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
384
+ if dt_rms_weight is not None:
385
+ delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
386
+ delta = rms_norm_forward(delta, ctx.dt_rms_weight, None, ctx.b_c_dt_rms_eps)
387
+ delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
388
+ if b_rms_weight is not None:
389
+ # Recompute & RMSNorm B
390
+ B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
391
+ B = rms_norm_forward(B, ctx.b_rms_weight, None, ctx.b_c_dt_rms_eps)
392
+ B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
393
+ if c_rms_weight is not None:
394
+ # Recompute & RMSNorm C
395
+ C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
396
+ C = rms_norm_forward(C, ctx.c_rms_weight, None, ctx.b_c_dt_rms_eps)
397
+ C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
398
+
399
+ # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
400
+ # backward of selective_scan_cuda with the backward of chunk).
401
+ dxz = torch.empty_like(xz) # (batch, dim, seqlen)
402
+ dx, dz = dxz.chunk(2, dim=1)
403
+ dout = rearrange(dout, "b l e -> e (b l)")
404
+ dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
405
+ dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
406
+ conv1d_out,
407
+ delta,
408
+ A,
409
+ B,
410
+ C,
411
+ D,
412
+ z,
413
+ delta_bias,
414
+ dout_y,
415
+ scan_intermediates,
416
+ out,
417
+ dz,
418
+ ctx.delta_softplus,
419
+ True, # option to recompute out_z
420
+ )
421
+ dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
422
+ dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
423
+ dD = dD if D is not None else None
424
+ dx_dbl = torch.empty_like(x_dbl)
425
+ dB_proj_bias = None
426
+ if ctx.is_variable_B:
427
+ if not A.is_complex():
428
+ dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
429
+ else:
430
+ dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
431
+ dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
432
+ dx_dbl[:, delta_rank : delta_rank + d_state] = dB # (bl d)
433
+ dB = None
434
+ dC_proj_bias = None
435
+ if ctx.is_variable_C:
436
+ if not A.is_complex():
437
+ dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
438
+ else:
439
+ dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
440
+ dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
441
+ dx_dbl[:, -d_state:] = dC # (bl d)
442
+ dC = None
443
+ ddelta = rearrange(ddelta, "b d l -> d (b l)")
444
+ ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
445
+ dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
446
+ dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
447
+ dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
448
+ dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
449
+ dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
450
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
451
+ # backward of conv1d with the backward of chunk).
452
+ dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
453
+ x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
454
+ )
455
+ dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
456
+ dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
457
+ return (
458
+ dxz,
459
+ dconv1d_weight,
460
+ dconv1d_bias,
461
+ dx_proj_weight,
462
+ ddelta_proj_weight,
463
+ dout_proj_weight,
464
+ dout_proj_bias,
465
+ dA,
466
+ dB,
467
+ dC,
468
+ dD,
469
+ ddelta_bias if delta_bias is not None else None,
470
+ # 6-None are delta_softplus, checkpoint_lvl, b_rms_weight, c_rms_weight, dt_rms_weight, b_c_dt_rms_eps
471
+ dB_proj_bias,
472
+ dC_proj_bias,
473
+ None,
474
+ None,
475
+ None,
476
+ None,
477
+ None,
478
+ None,
479
+ )
480
+
481
+
482
+ def mamba_inner_fn(
483
+ xz,
484
+ conv1d_weight,
485
+ conv1d_bias,
486
+ x_proj_weight,
487
+ delta_proj_weight,
488
+ out_proj_weight,
489
+ out_proj_bias,
490
+ A,
491
+ B=None,
492
+ C=None,
493
+ D=None,
494
+ delta_bias=None,
495
+ B_proj_bias=None,
496
+ C_proj_bias=None,
497
+ delta_softplus=True,
498
+ checkpoint_lvl=1,
499
+ b_rms_weight=None,
500
+ c_rms_weight=None,
501
+ dt_rms_weight=None,
502
+ b_c_dt_rms_eps=1e-6,
503
+ ):
504
+ return MambaInnerFn.apply(
505
+ xz,
506
+ conv1d_weight,
507
+ conv1d_bias,
508
+ x_proj_weight,
509
+ delta_proj_weight,
510
+ out_proj_weight,
511
+ out_proj_bias,
512
+ A,
513
+ B,
514
+ C,
515
+ D,
516
+ delta_bias,
517
+ B_proj_bias,
518
+ C_proj_bias,
519
+ delta_softplus,
520
+ checkpoint_lvl,
521
+ b_rms_weight,
522
+ c_rms_weight,
523
+ dt_rms_weight,
524
+ b_c_dt_rms_eps,
525
+ )
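
A hedged call sketch for the fused path (CUDA-only; needs working mamba-ssm, selective-scan and causal-conv1d builds; every shape below is an assumption spelled out inline, not a transformers default):

import torch

batch, d_inner, seqlen, d_state, dt_rank, d_conv = 2, 64, 128, 16, 4, 4
xz = torch.randn(batch, 2 * d_inner, seqlen, device="cuda")  # x and z stacked on dim 1
conv1d_weight = torch.randn(d_inner, 1, d_conv, device="cuda")
conv1d_bias = torch.randn(d_inner, device="cuda")
x_proj_weight = torch.randn(dt_rank + 2 * d_state, d_inner, device="cuda")
delta_proj_weight = torch.randn(d_inner, dt_rank, device="cuda")
out_proj_weight = torch.randn(d_inner, d_inner, device="cuda")
A = -torch.rand(d_inner, d_state, device="cuda")
D = torch.ones(d_inner, device="cuda")
y = mamba_inner_fn(xz, conv1d_weight, conv1d_bias, x_proj_weight,
                   delta_proj_weight, out_proj_weight, None, A, D=D)
print(y.shape)  # (batch, seqlen, d_inner)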
phivenv/Lib/site-packages/transformers/kernels/mra/cuda_kernel.cu ADDED
@@ -0,0 +1,383 @@
+ #include "cuda_kernel.h"
+
+ //////////////////////////////////////////////////////////////////////////////////////////////////
+ //////////////////////////////////////////////////////////////////////////////////////////////////
+
+ __global__ void index_max_cuda_kernel(
+     float *index_vals,        // [batch_size, 32, num_block]
+     int   *indices,           // [batch_size, num_block]
+     float *max_vals,          // [batch_size, A_num_block * 32]
+     float *max_vals_scatter,  // [batch_size, 32, num_block]
+     long batch_size,
+     long A_num_block,
+     long B_num_block,
+     long num_block
+ ) {
+
+     long batch_idx = blockIdx.x;
+
+     long thread_idx = threadIdx.x;
+     long num_thread = blockDim.x;
+
+     extern __shared__ float buffer[];
+     int *max_buffer = (int*)buffer;
+
+     for (int i = 0; i < A_num_block * 32; i = i + num_thread) {
+         int idx = i + thread_idx;
+         if (idx < A_num_block * 32) {
+             max_buffer[idx] = -1e8;
+         }
+     }
+     __syncthreads();
+
+     int *indices_pt = &indices[batch_idx * num_block];
+     float *index_vals_pt = &index_vals[batch_idx * num_block * 32];
+
+     for (int idx_start = 0; idx_start < 32 * num_block; idx_start = idx_start + num_thread) {
+         int idx = idx_start + thread_idx;
+         int A_block_idx = indices_pt[idx % num_block] / B_num_block;
+         atomicMax(&max_buffer[A_block_idx * 32 + idx / num_block], (int)(index_vals_pt[idx] * 1000));
+     }
+     __syncthreads();
+
+     float *max_vals_pt = &max_vals[batch_idx * A_num_block * 32];
+     for (int i = 0; i < A_num_block * 32; i = i + num_thread) {
+         int idx = i + thread_idx;
+         if (idx < A_num_block * 32) {
+             max_vals_pt[idx] = (float)max_buffer[idx] / 1000.;
+         }
+     }
+
+     float *max_vals_scatter_pt = &max_vals_scatter[batch_idx * num_block * 32];
+     for (int idx_start = 0; idx_start < 32 * num_block; idx_start = idx_start + num_thread) {
+         int idx = idx_start + thread_idx;
+         int A_block_idx = indices_pt[idx % num_block] / B_num_block;
+         max_vals_scatter_pt[idx] = (float)max_buffer[A_block_idx * 32 + idx / num_block] / 1000.;
+     }
+
+ }
+ }
59
+
60
+ __global__ void mm_to_sparse_cuda_kernel(
61
+ float *dense_A, // [batch_size, A_num_block, dim, 32]
62
+ float *dense_B, // [batch_size, B_num_block, dim, 32]
63
+ int *indices, // [batch_size, num_block]
64
+ float *sparse_C, // [batch_size, num_block, 32, 32]
65
+ long batch_size,
66
+ long A_num_block,
67
+ long B_num_block,
68
+ long dim,
69
+ long num_block
70
+ ) {
71
+
72
+ long batch_idx = blockIdx.y;
73
+ long block_idx = blockIdx.x * blockDim.y + threadIdx.y;
74
+
75
+ long thread_idx = threadIdx.x;
76
+
77
+ __shared__ float buffer[4096];
78
+ float *A_buffer = &buffer[threadIdx.y * 1024]; // [2, 8, 32]
79
+ float *B_buffer = &buffer[threadIdx.y * 1024 + 512]; // [2, 8, 32]
80
+
81
+ long batch_idx__block_idx = batch_idx * num_block + block_idx;
82
+
83
+ long AB_block_idx = indices[batch_idx__block_idx];
84
+ float *dense_A_pt = &dense_A[(batch_idx * A_num_block + AB_block_idx / B_num_block) * dim * 32];
85
+ float *dense_B_pt = &dense_B[(batch_idx * B_num_block + AB_block_idx % B_num_block) * dim * 32];
86
+
87
+ int reg_1_idx = thread_idx / 8; // [0000000011111111222222223333333344444444555555556666666677777777]
88
+ int reg_2_idx = thread_idx % 8; // [0123456701234567012345670123456701234567012345670123456701234567]
89
+
90
+ float reg_1[8];
91
+ float reg_2[8];
92
+
93
+ float reg_array[16] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
94
+
95
+ #pragma unroll
96
+ for (int i = 0; i < 4; i++) {
97
+ A_buffer[i * 64 + thread_idx] = dense_A_pt[i * 64 + thread_idx];
98
+ B_buffer[i * 64 + thread_idx] = dense_B_pt[i * 64 + thread_idx];
99
+ }
100
+
101
+ __syncthreads();
102
+
103
+ #pragma unroll
104
+ for (int i = 0; i < 4; i++) {
105
+ reg_1[i] = A_buffer[reg_1_idx * 4 + i];
106
+ reg_2[i] = B_buffer[reg_2_idx * 4 + i];
107
+ }
108
+
109
+ for (int dim_stride = 1; dim_stride < (dim / 8); dim_stride++) {
110
+
111
+ #pragma unroll
112
+ for (int i = 0; i < 4; i++) {
113
+ A_buffer[(dim_stride % 2) * 256 + i * 64 + thread_idx] = dense_A_pt[dim_stride * 256 + i * 64 + thread_idx];
114
+ B_buffer[(dim_stride % 2) * 256 + i * 64 + thread_idx] = dense_B_pt[dim_stride * 256 + i * 64 + thread_idx];
115
+ }
116
+
117
+ #pragma unroll
118
+ for (int mini_dim_idx = 1; mini_dim_idx < 8; mini_dim_idx++) {
119
+ #pragma unroll
120
+ for (int i = 0; i < 4; i++) {
121
+ reg_1[(mini_dim_idx % 2) * 4 + i] = A_buffer[((dim_stride - 1) % 2) * 256 + mini_dim_idx * 32 + reg_1_idx * 4 + i];
122
+ reg_2[(mini_dim_idx % 2) * 4 + i] = B_buffer[((dim_stride - 1) % 2) * 256 + mini_dim_idx * 32 + reg_2_idx * 4 + i];
123
+ }
124
+ #pragma unroll
125
+ for (int i = 0; i < 4; i++) {
126
+ #pragma unroll
127
+ for (int j = 0; j < 4; j++) {
128
+ reg_array[i * 4 + j] += reg_1[((mini_dim_idx - 1) % 2) * 4 + i] * reg_2[((mini_dim_idx - 1) % 2) * 4 + j];
129
+ }
130
+ }
131
+ }
132
+
133
+ __syncthreads();
134
+
135
+ #pragma unroll
136
+ for (int i = 0; i < 4; i++) {
137
+ reg_1[i] = A_buffer[(dim_stride % 2) * 256 + reg_1_idx * 4 + i];
138
+ reg_2[i] = B_buffer[(dim_stride % 2) * 256 + reg_2_idx * 4 + i];
139
+ }
140
+
141
+ #pragma unroll
142
+ for (int i = 0; i < 4; i++) {
143
+ #pragma unroll
144
+ for (int j = 0; j < 4; j++) {
145
+ reg_array[i * 4 + j] += reg_1[4 + i] * reg_2[4 + j];
146
+ }
147
+ }
148
+
149
+ }
150
+
151
+ #pragma unroll
152
+ for (int mini_dim_idx = 1; mini_dim_idx < 8; mini_dim_idx++) {
153
+ #pragma unroll
154
+ for (int i = 0; i < 4; i++) {
155
+ reg_1[(mini_dim_idx % 2) * 4 + i] = A_buffer[256 + mini_dim_idx * 32 + reg_1_idx * 4 + i];
156
+ reg_2[(mini_dim_idx % 2) * 4 + i] = B_buffer[256 + mini_dim_idx * 32 + reg_2_idx * 4 + i];
157
+ }
158
+ #pragma unroll
159
+ for (int i = 0; i < 4; i++) {
160
+ #pragma unroll
161
+ for (int j = 0; j < 4; j++) {
162
+ reg_array[i * 4 + j] += reg_1[((mini_dim_idx - 1) % 2) * 4 + i] * reg_2[((mini_dim_idx - 1) % 2) * 4 + j];
163
+ }
164
+ }
165
+ }
166
+ #pragma unroll
167
+ for (int i = 0; i < 4; i++) {
168
+ #pragma unroll
169
+ for (int j = 0; j < 4; j++) {
170
+ reg_array[i * 4 + j] += reg_1[4 + i] * reg_2[4 + j];
171
+ }
172
+ }
173
+ __syncthreads();
174
+
175
+ float *C_buffer = &buffer[threadIdx.y * 1024]; // [32, 32]
176
+
177
+ #pragma unroll
178
+ for (int i = 0; i < 4; i++) {
179
+ #pragma unroll
180
+ for (int j = 0; j < 4; j++) {
181
+ C_buffer[(reg_2_idx * 4 + j) * 32 + reg_1_idx * 4 + i] = reg_array[i * 4 + j];
182
+ }
183
+ }
184
+ __syncthreads();
185
+
186
+ float *sparse_C_pt = &sparse_C[batch_idx__block_idx * 1024];
187
+
188
+ #pragma unroll
189
+ for (int i = 0; i < 16; i++) {
190
+ sparse_C_pt[i * 64 + thread_idx] = C_buffer[i * 64 + thread_idx];
191
+ }
192
+
193
+ }
194
+
195
+ __global__ void sparse_dense_mm_cuda_kernel(
196
+ float *sparse_A, // [batch_size, num_block, 32, 32]
197
+ int *indices, // [batch_size, num_block]
198
+ float *dense_B, // [batch_size, B_num_block, dim, 32]
199
+ float *dense_C, // [batch_size, A_num_block, dim, 32]
200
+ long batch_size,
201
+ long A_num_block,
202
+ long B_num_block,
203
+ long dim,
204
+ long num_block
205
+ ) {
206
+
207
+ long batch_idx = blockIdx.y;
208
+ long block_idx = blockIdx.x * blockDim.y + threadIdx.y;
209
+
210
+ long thread_idx = threadIdx.x;
211
+
212
+ __shared__ float buffer[6144];
213
+ float *A_buffer = &buffer[threadIdx.y * 3072]; // [32, 32]
214
+ float *B_buffer = &buffer[threadIdx.y * 3072 + 1024]; // [32, 64]
215
+
216
+ long batch_idx__block_idx = batch_idx * num_block + block_idx;
217
+
218
+ float *sparse_A_pt = &sparse_A[batch_idx__block_idx * 1024];
219
+ #pragma unroll
220
+ for (int i = 0; i < 8; i++) {
221
+ A_buffer[i * 128 + thread_idx] = sparse_A_pt[i * 128 + thread_idx];
222
+ }
223
+
224
+ long AB_block_idx = indices[batch_idx__block_idx];
225
+ float *dense_B_pt = &dense_B[(batch_idx * B_num_block + AB_block_idx % B_num_block) * 32 * dim];
226
+ float *dense_C_pt = &dense_C[(batch_idx * A_num_block + AB_block_idx / B_num_block) * 32 * dim];
227
+
228
+ // [0000000011111111222222223333333344444444555555556666666677777777]
229
+ // [0123456701234567012345670123456701234567012345670123456701234567]
230
+ int reg_1_idx = thread_idx / 8;
231
+ int reg_2_idx = thread_idx % 8;
232
+
233
+ float reg_1[8];
234
+ float reg_2[8];
235
+
236
+ float reg_array[16];
237
+
238
+ for (int dim_stride = 0; dim_stride < dim; dim_stride = dim_stride + 64) {
239
+
240
+ #pragma unroll
241
+ for (int i = 0; i < 16; i++) {
242
+ B_buffer[i * 128 + thread_idx] = dense_B_pt[dim_stride * 32 + i * 128 + thread_idx];
243
+ }
244
+
245
+ #pragma unroll
246
+ for (int i = 0; i < 16; i++) {
247
+ reg_array[i] = 0;
248
+ }
249
+
250
+ __syncthreads();
251
+
252
+ #pragma unroll
253
+ for (int i = 0; i < 4; i++) {
254
+ reg_1[i] = B_buffer[(reg_1_idx * 4 + i) * 32];
255
+ reg_2[i] = A_buffer[reg_2_idx * 4 + i];
256
+ }
257
+
258
+ #pragma unroll
259
+ for (int mini_dim_idx = 1; mini_dim_idx < 32; mini_dim_idx++) {
260
+ #pragma unroll
261
+ for (int i = 0; i < 4; i++) {
262
+ reg_1[(mini_dim_idx % 2) * 4 + i] = B_buffer[(reg_1_idx * 4 + i) * 32 + mini_dim_idx];
263
+ reg_2[(mini_dim_idx % 2) * 4 + i] = A_buffer[mini_dim_idx * 32 + reg_2_idx * 4 + i];
264
+ }
265
+ #pragma unroll
266
+ for (int i = 0; i < 4; i++) {
267
+ #pragma unroll
268
+ for (int j = 0; j < 4; j++) {
269
+ reg_array[i * 4 + j] += reg_1[((mini_dim_idx - 1) % 2) * 4 + i] * reg_2[((mini_dim_idx - 1) % 2) * 4 + j];
270
+ }
271
+ }
272
+ }
273
+
274
+ #pragma unroll
275
+ for (int i = 0; i < 4; i++) {
276
+ #pragma unroll
277
+ for (int j = 0; j < 4; j++) {
278
+ reg_array[i * 4 + j] += reg_1[4 + i] * reg_2[4 + j];
279
+ }
280
+ }
281
+
282
+ __syncthreads();
283
+
284
+ float *C_buffer = &buffer[threadIdx.y * 3072 + 1024]; // [64, 32]
285
+
286
+ #pragma unroll
287
+ for (int i = 0; i < 4; i++) {
288
+ #pragma unroll
289
+ for (int j = 0; j < 4; j++) {
290
+ C_buffer[(reg_1_idx * 4 + i) * 32 + reg_2_idx * 4 + j] = reg_array[i * 4 + j];
291
+ }
292
+ }
293
+ __syncthreads();
294
+
295
+ #pragma unroll
296
+ for (int i = 0; i < 16; i++) {
297
+ atomicAdd(&dense_C_pt[dim_stride * 32 + i * 128 + thread_idx], C_buffer[i * 128 + thread_idx]);
298
+ }
299
+ __syncthreads();
300
+
301
+ }
302
+
303
+ }
304
+
305
+
306
+ __global__ void reduce_sum_cuda_kernel(
307
+ float *sparse_A, // [batch_size, num_block, 32, 32]
308
+ int *indices, // [batch_size, num_block]
309
+ float *dense_C, // [batch_size, A_num_block, 32]
310
+ long batch_size,
311
+ long A_num_block,
312
+ long B_num_block,
313
+ long num_block
314
+ ) {
315
+
316
+ long batch_idx = blockIdx.y;
317
+ long block_idx = blockIdx.x * blockDim.y + threadIdx.y;
318
+
319
+ long thread_idx = threadIdx.x;
320
+
321
+ long batch_idx__block_idx = batch_idx * num_block + block_idx;
322
+
323
+ long AB_block_idx = indices[batch_idx__block_idx];
324
+ float *sparse_A_pt = &sparse_A[batch_idx__block_idx * 1024];
325
+
326
+ float reg_array[16];
327
+ float value = 0;
328
+
329
+ #pragma unroll
330
+ for (int i = 0; i < 8; i++) {
331
+ reg_array[i] = sparse_A_pt[i * 32 + thread_idx];
332
+ }
333
+ #pragma unroll
334
+ for (int stride = 8; stride < 32; stride = stride + 8) {
335
+ #pragma unroll
336
+ for (int i = 0; i < 8; i++) {
337
+ reg_array[(stride + i) % 16] = sparse_A_pt[(stride + i) * 32 + thread_idx];
338
+ }
339
+ #pragma unroll
340
+ for (int i = 0; i < 8; i++) {
341
+ value = value + reg_array[(stride - 8 + i) % 16];
342
+ }
343
+ }
344
+ #pragma unroll
345
+ for (int i = 0; i < 8; i++) {
346
+ value = value + reg_array[8 + i];
347
+ }
348
+
349
+ float *dense_C_pt = &dense_C[(batch_idx * A_num_block + AB_block_idx / B_num_block) * 32];
350
+
351
+ atomicAdd(&dense_C_pt[thread_idx], value);
352
+
353
+ }
354
+
355
+ __global__ void scatter_cuda_kernel(
356
+ float *dense_A, // [batch_size, A_num_block, 32]
357
+ int *indices, // [batch_size, num_block]
358
+ float *sparse_C, // [batch_size, num_block, 32, 32]
359
+ long batch_size,
360
+ long A_num_block,
361
+ long B_num_block,
362
+ long num_block
363
+ ) {
364
+
365
+ long batch_idx = blockIdx.y;
366
+ long block_idx = blockIdx.x * blockDim.y + threadIdx.y;
367
+
368
+ long thread_idx = threadIdx.x;
369
+
370
+ long batch_idx__block_idx = batch_idx * num_block + block_idx;
371
+
372
+ long AB_block_idx = indices[batch_idx__block_idx];
373
+ float *dense_A_pt = &dense_A[(batch_idx * A_num_block + AB_block_idx / B_num_block) * 32];
374
+ float *sparse_C_pt = &sparse_C[(batch_idx * num_block + block_idx) * 1024];
375
+
376
+ float value = dense_A_pt[thread_idx];
377
+
378
+ #pragma unroll
379
+ for (int i = 0; i < 32; i++) {
380
+ sparse_C_pt[i * 32 + thread_idx] = value;
381
+ }
382
+
383
+ }
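
An einsum sketch of mm_to_sparse_cuda_kernel's contract (a reading aid, not part of the extension). Each selected (A-block, B-block) pair yields one 32x32 tile contracted over dim; note the kernel writes the tile transposed, indexed [B-token, A-token], which is the layout sparse_dense_mm and reduce_sum consume:

import torch

def mm_to_sparse_ref(dense_A, dense_B, indices, B_num_block):
    # dense_A: (batch, A_num_block, dim, 32); dense_B: (batch, B_num_block, dim, 32)
    batch_size, num_block = indices.shape
    tiles = []
    for b in range(batch_size):
        a_blk = (indices[b] // B_num_block).long()
        b_blk = (indices[b] % B_num_block).long()
        A_g = dense_A[b, a_blk]                       # (num_block, dim, 32)
        B_g = dense_B[b, b_blk]                       # (num_block, dim, 32)
        tiles.append(torch.einsum("ndi,ndj->nji", A_g, B_g))
    return torch.stack(tiles)                         # (batch, num_block, 32, 32)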
phivenv/Lib/site-packages/transformers/kernels/mra/cuda_kernel.h ADDED
@@ -0,0 +1,59 @@
+
+ #define WARP_SIZE 32
+ #define FULL_MASK 0xffffffff
+ #define OPTIMAL_THREADS 256
+
+ __global__ void index_max_cuda_kernel(
+     float *index_vals,        // [batch_size, 32, num_block]
+     int   *indices,           // [batch_size, num_block]
+     float *max_vals,          // [batch_size, A_num_block * 32]
+     float *max_vals_scatter,  // [batch_size, 32, num_block]
+     long batch_size,
+     long A_num_block,
+     long B_num_block,
+     long num_block
+ );
+
+ __global__ void mm_to_sparse_cuda_kernel(
+     float *dense_A,   // [batch_size, A_num_block, dim, 32]
+     float *dense_B,   // [batch_size, B_num_block, dim, 32]
+     int   *indices,   // [batch_size, num_block]
+     float *sparse_C,  // [batch_size, num_block, 32, 32]
+     long batch_size,
+     long A_num_block,
+     long B_num_block,
+     long dim,
+     long num_block
+ );
+
+ __global__ void sparse_dense_mm_cuda_kernel(
+     float *sparse_A,  // [batch_size, num_block, 32, 32]
+     int   *indices,   // [batch_size, num_block]
+     float *dense_B,   // [batch_size, B_num_block, dim, 32]
+     float *dense_C,   // [batch_size, A_num_block, dim, 32]
+     long batch_size,
+     long A_num_block,
+     long B_num_block,
+     long dim,
+     long num_block
+ );
+
+ __global__ void reduce_sum_cuda_kernel(
+     float *sparse_A,  // [batch_size, num_block, 32, 32]
+     int   *indices,   // [batch_size, num_block]
+     float *dense_C,   // [batch_size, A_num_block, 32]
+     long batch_size,
+     long A_num_block,
+     long B_num_block,
+     long num_block
+ );
+
+ __global__ void scatter_cuda_kernel(
+     float *dense_A,   // [batch_size, A_num_block, 32]
+     int   *indices,   // [batch_size, num_block]
+     float *sparse_C,  // [batch_size, num_block, 32, 32]
+     long batch_size,
+     long A_num_block,
+     long B_num_block,
+     long num_block
+ );
phivenv/Lib/site-packages/transformers/kernels/mra/cuda_launch.cu ADDED
@@ -0,0 +1,154 @@
+ #include <torch/extension.h>
+ #include <ATen/ATen.h>
+ #include "cuda_launch.h"
+ #include "cuda_kernel.h"
+ #include <vector>
+
+ //////////////////////////////////////////////////////////////////////////////////////////////////
+ //////////////////////////////////////////////////////////////////////////////////////////////////
+
+ std::vector<at::Tensor> index_max_kernel(
+     at::Tensor index_vals,  // [batch_size, 32, num_block]
+     at::Tensor indices,     // [batch_size, num_block]
+     int A_num_block,
+     int B_num_block
+ ) {
+     int batch_size = indices.size(0);
+     int num_block = indices.size(1);
+
+     at::Tensor max_vals = at::zeros({batch_size, A_num_block * 32}, index_vals.options());
+     at::Tensor max_vals_scatter = at::zeros({batch_size, 32, num_block}, index_vals.options());
+
+     dim3 threads(256);
+     dim3 blocks(batch_size);
+     int shared_mem = A_num_block * 32 * sizeof(float);
+
+     index_max_cuda_kernel<<<blocks, threads, shared_mem>>>(
+         index_vals.data_ptr<float>(),
+         indices.data_ptr<int>(),
+         max_vals.data_ptr<float>(),
+         max_vals_scatter.data_ptr<float>(),
+         batch_size,
+         A_num_block,
+         B_num_block,
+         num_block
+     );
+
+     return {max_vals, max_vals_scatter};
+ }
+
+ at::Tensor mm_to_sparse_kernel(
+     at::Tensor dense_A,  // [batch_size, A_num_block, dim, 32]
+     at::Tensor dense_B,  // [batch_size, B_num_block, dim, 32]
+     at::Tensor indices   // [batch_size, num_block]
+ ) {
+     int batch_size = dense_A.size(0);
+     int A_num_block = dense_A.size(1);
+     int B_num_block = dense_B.size(1);
+     int dim = dense_A.size(2);
+     int num_block = indices.size(1);
+
+     at::Tensor sparse_C = at::zeros({batch_size, num_block, 32, 32}, dense_A.options());
+
+     dim3 threads(64, 4);
+     dim3 blocks(num_block / 4, batch_size);
+
+     mm_to_sparse_cuda_kernel<<<blocks, threads>>>(
+         dense_A.data_ptr<float>(),
+         dense_B.data_ptr<float>(),
+         indices.data_ptr<int>(),
+         sparse_C.data_ptr<float>(),
+         batch_size,
+         A_num_block,
+         B_num_block,
+         dim,
+         num_block
+     );
+
+     return sparse_C;
+ }
+
+ at::Tensor sparse_dense_mm_kernel(
+     at::Tensor sparse_A,  // [batch_size, num_block, 32, 32]
+     at::Tensor indices,   // [batch_size, num_block]
+     at::Tensor dense_B,   // [batch_size, B_num_block, dim, 32]
+     int A_num_block
+ ) {
+     int batch_size = sparse_A.size(0);
+     int num_block = sparse_A.size(1);
+     int B_num_block = dense_B.size(1);
+     int dim = dense_B.size(2);
+
+     at::Tensor dense_C = at::zeros({batch_size, A_num_block, dim, 32}, dense_B.options());
+
+     dim3 threads(128, 2);
+     dim3 blocks(num_block / 2, batch_size);
+
+     sparse_dense_mm_cuda_kernel<<<blocks, threads>>>(
+         sparse_A.data_ptr<float>(),
+         indices.data_ptr<int>(),
+         dense_B.data_ptr<float>(),
+         dense_C.data_ptr<float>(),
+         batch_size,
+         A_num_block,
+         B_num_block,
+         dim,
+         num_block
+     );
+
+     return dense_C;
+ }
+
+ at::Tensor reduce_sum_kernel(
+     at::Tensor sparse_A,  // [batch_size, num_block, 32, 32]
+     at::Tensor indices,   // [batch_size, num_block]
+     int A_num_block,
+     int B_num_block
+ ) {
+     int batch_size = sparse_A.size(0);
+     int num_block = sparse_A.size(1);
+
+     at::Tensor dense_C = at::zeros({batch_size, A_num_block, 32}, sparse_A.options());
+
+     dim3 threads(32, 4);
+     dim3 blocks(num_block / 4, batch_size);
+
+     reduce_sum_cuda_kernel<<<blocks, threads>>>(
+         sparse_A.data_ptr<float>(),
+         indices.data_ptr<int>(),
+         dense_C.data_ptr<float>(),
+         batch_size,
+         A_num_block,
+         B_num_block,
+         num_block
+     );
+
+     return dense_C;
+ }
+
+ at::Tensor scatter_kernel(
+     at::Tensor dense_A,  // [batch_size, A_num_block, 32]
+     at::Tensor indices,  // [batch_size, num_block]
+     int B_num_block
+ ) {
+     int batch_size = dense_A.size(0);
+     int A_num_block = dense_A.size(1);
+     int num_block = indices.size(1);
+
+     at::Tensor sparse_C = at::zeros({batch_size, num_block, 32, 32}, dense_A.options());
+
+     dim3 threads(32, 4);
+     dim3 blocks(num_block / 4, batch_size);
+
+     scatter_cuda_kernel<<<blocks, threads>>>(
+         dense_A.data_ptr<float>(),
+         indices.data_ptr<int>(),
+         sparse_C.data_ptr<float>(),
+         batch_size,
+         A_num_block,
+         B_num_block,
+         num_block
+     );
+
+     return sparse_C;
+ }
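
A bookkeeping sketch of the launch arithmetic above (illustrative; the function name is hypothetical). The grids divide num_block by blockDim.y with integer division and the kernels never bounds-check block_idx, so num_block should be a multiple of the rows-per-block or trailing blocks are silently skipped:

def mra_launch_shapes(batch_size, num_block, A_num_block):
    assert num_block % 4 == 0, "grid arithmetic assumes num_block divisible by 4"
    return {
        "index_max":       {"grid": (batch_size,),                "block": (256,),   "shared_bytes": A_num_block * 32 * 4},
        "mm_to_sparse":    {"grid": (num_block // 4, batch_size), "block": (64, 4)},
        "sparse_dense_mm": {"grid": (num_block // 2, batch_size), "block": (128, 2)},
        "reduce_sum":      {"grid": (num_block // 4, batch_size), "block": (32, 4)},
        "scatter":         {"grid": (num_block // 4, batch_size), "block": (32, 4)},
    }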
phivenv/Lib/site-packages/transformers/kernels/mra/cuda_launch.h ADDED
@@ -0,0 +1,39 @@
+ #include <torch/extension.h>
+ #include <ATen/ATen.h>
+ #include <vector>
+
+ #define min(a, b) ((a)<(b)?(a):(b))
+ #define max(a, b) ((a)>(b)?(a):(b))
+
+ std::vector<at::Tensor> index_max_kernel(
+ at::Tensor index_vals,
+ at::Tensor indices,
+ int A_num_block,
+ int B_num_block
+ );
+
+ at::Tensor mm_to_sparse_kernel(
+ at::Tensor dense_A,
+ at::Tensor dense_B,
+ at::Tensor indices
+ );
+
+ at::Tensor sparse_dense_mm_kernel(
+ at::Tensor sparse_A,
+ at::Tensor indices,
+ at::Tensor dense_B,
+ int A_num_block
+ );
+
+ at::Tensor reduce_sum_kernel(
+ at::Tensor sparse_A,
+ at::Tensor indices,
+ int A_num_block,
+ int B_num_block
+ );
+
+ at::Tensor scatter_kernel(
+ at::Tensor dense_A,
+ at::Tensor indices,
+ int B_num_block
+ );
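
Note: defining min and max as object-like macros in a shared header leaks them into every translation unit that includes cuda_launch.h and can clash with std::min/std::max or ATen code included afterwards. A macro-free alternative sketch (min_of/max_of are hypothetical names, not used by the extension):

template <typename T>
constexpr T min_of(T a, T b) { return a < b ? a : b; }  // same result, no textual substitution

template <typename T>
constexpr T max_of(T a, T b) { return a > b ? a : b; }

static_assert(min_of(3, 7) == 3, "matches the macro for plain ints");
static_assert(max_of(3, 7) == 7, "matches the macro for plain ints");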
phivenv/Lib/site-packages/transformers/kernels/mra/torch_extension.cpp ADDED
@@ -0,0 +1,78 @@
+ #include <torch/extension.h>
+ #include <ATen/ATen.h>
+ #include "cuda_launch.h"
+ #include <vector>
+
+ std::vector<at::Tensor> index_max(
+ at::Tensor index_vals,
+ at::Tensor indices,
+ int A_num_block,
+ int B_num_block
+ ) {
+ return index_max_kernel(
+ index_vals,
+ indices,
+ A_num_block,
+ B_num_block
+ );
+ }
+
+ at::Tensor mm_to_sparse(
+ at::Tensor dense_A,
+ at::Tensor dense_B,
+ at::Tensor indices
+ ) {
+ return mm_to_sparse_kernel(
+ dense_A,
+ dense_B,
+ indices
+ );
+ }
+
+ at::Tensor sparse_dense_mm(
+ at::Tensor sparse_A,
+ at::Tensor indices,
+ at::Tensor dense_B,
+ int A_num_block
+ ) {
+ return sparse_dense_mm_kernel(
+ sparse_A,
+ indices,
+ dense_B,
+ A_num_block
+ );
+ }
+
+ at::Tensor reduce_sum(
+ at::Tensor sparse_A,
+ at::Tensor indices,
+ int A_num_block,
+ int B_num_block
+ ) {
+ return reduce_sum_kernel(
+ sparse_A,
+ indices,
+ A_num_block,
+ B_num_block
+ );
+ }
+
+ at::Tensor scatter(
+ at::Tensor dense_A,
+ at::Tensor indices,
+ int B_num_block
+ ) {
+ return scatter_kernel(
+ dense_A,
+ indices,
+ B_num_block
+ );
+ }
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("index_max", &index_max, "index_max (CUDA)");
+ m.def("mm_to_sparse", &mm_to_sparse, "mm_to_sparse (CUDA)");
+ m.def("sparse_dense_mm", &sparse_dense_mm, "sparse_dense_mm (CUDA)");
+ m.def("reduce_sum", &reduce_sum, "reduce_sum (CUDA)");
+ m.def("scatter", &scatter, "scatter (CUDA)");
+ }
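
Note: these wrappers hand their tensors straight to the CUDA launchers without checking device, dtype, or contiguity, so a CPU or non-contiguous input fails inside the kernel rather than at the Python boundary. A sketch of a guard one might add, assuming the same torch extension headers (check_cuda_float is a hypothetical helper, not in the file):

#include <torch/extension.h>

// Raise a Python-visible error before launching if the tensor cannot be
// consumed by the float32 CUDA kernels above.
void check_cuda_float(const at::Tensor &t, const char *name) {
    TORCH_CHECK(t.is_cuda(), name, " must be a CUDA tensor");
    TORCH_CHECK(t.is_contiguous(), name, " must be contiguous");
    TORCH_CHECK(t.scalar_type() == at::kFloat, name, " must be float32");
}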
phivenv/Lib/site-packages/transformers/kernels/rwkv/wkv_cuda.cu ADDED
@@ -0,0 +1,187 @@
+ #include <stdio.h>
+ #include <assert.h>
+
+ #define MIN_VALUE (-1e38)
+
+ template <typename F>
+ __global__ void kernel_forward(
+ const int B, const int T, const int C, const F *__restrict__ const _w, const F *__restrict__ const _u,
+ const F *__restrict__ const _k, const F *__restrict__ const _v, F *__restrict__ const _y
+ ) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ const int _b = idx / C;
+ const int _c = idx % C;
+ const int _offset = _b * T * C + _c;
+
+ F u = _u[_c];
+ F w = _w[_c];
+ const F *__restrict__ const k = _k + _offset;
+ const F *__restrict__ const v = _v + _offset;
+ F *__restrict__ const y = _y + _offset;
+
+ // aa and bb are running sums divided by exp(pp) (to avoid overflow)
+ F aa = 0, bb = 0, pp = MIN_VALUE;
+ for (int i = 0; i < T; i++) {
+ const int ii = i * C;
+ const F kk = k[ii];
+ const F vv = v[ii];
+
+ F ww = u + kk;
+ F p = max(pp, ww);
+ F e1 = exp(pp - p);
+ F e2 = exp(ww - p);
+ y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);
+
+ ww = w + pp;
+ p = max(ww, kk);
+ e1 = exp(ww - p);
+ e2 = exp(kk - p);
+ aa = e1 * aa + e2 * vv;
+ bb = e1 * bb + e2;
+ pp = p;
+ }
+ }
+
+ template <typename F>
+ __global__ void kernel_forward_with_state(
+ const int B, const int T, const int C, const F *__restrict__ const _w, const F *__restrict__ const _u,
+ const F *__restrict__ const _k, const F *__restrict__ const _v, F *__restrict__ const _y, F *__restrict__ const _s
+ ) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ const int _b = idx / C;
+ const int _c = idx % C;
+ const int _offset_s = _b * C * 3 + _c * 3;
+ const int _offset = _b * T * C + _c;
+
+ F u = _u[_c];
+ F w = _w[_c];
+ const F *__restrict__ const k = _k + _offset;
+ const F *__restrict__ const v = _v + _offset;
+ F *__restrict__ const y = _y + _offset;
+ F *__restrict__ const s = _s + _offset_s;
+
+ // aa and bb are running sums divided by exp(pp) (to avoid overflow)
+ F aa = s[0], bb = s[1], pp = s[2];
+ for (int i = 0; i < T; i++) {
+ const int ii = i * C;
+ const F kk = k[ii];
+ const F vv = v[ii];
+
+ F ww = u + kk;
+ F p = max(pp, ww);
+ F e1 = exp(pp - p);
+ F e2 = exp(ww - p);
+ y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);
+
+ ww = w + pp;
+ p = max(ww, kk);
+ e1 = exp(ww - p);
+ e2 = exp(kk - p);
+ aa = e1 * aa + e2 * vv;
+ bb = e1 * bb + e2;
+ pp = p;
+ }
+ s[0] = aa;
+ s[1] = bb;
+ s[2] = pp;
+ }
+
+ template <typename F>
+ __global__ void kernel_backward(
+ const int B, const int T, const int C, const F *__restrict__ const _w, const F *__restrict__ const _u,
+ const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _y,
+ const F *__restrict__ const _gy, F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk,
+ F *__restrict__ const _gv
+ ) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ const int _b = idx / C;
+ const int _c = idx % C;
+ const int _offset = _b * T * C + _c;
+
+ F u = _u[_c];
+ F w = _w[_c];
+ const F *__restrict__ const k = _k + _offset;
+ const F *__restrict__ const v = _v + _offset;
+ const F *__restrict__ const y = _y + _offset;
+ const F *__restrict__ const gy = _gy + _offset;
+ F *__restrict__ const gk = _gk + _offset;
+ F *__restrict__ const gv = _gv + _offset;
+
+ F q[Tmax], r[Tmax];
+
+ F gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
+ for (int i = 0; i < T; i++) {
+ const int ii = i * C;
+ const F kk = k[ii];
+ const F vv = v[ii];
+ const F yy = y[ii];
+
+ F ww = u + kk;
+ F p = max(pp, ww);
+ F e1 = exp(pp - p);
+ F e2 = exp(ww - p);
+ const F qq = gy[ii] / (e1 * bb + e2);
+ gw += (ga - gb * yy) * e1 * qq;
+ gu += (vv - yy) * e2 * qq;
+ q[i] = qq;
+ r[i] = ww - p;
+
+ ww = w + pp;
+ p = max(ww, kk);
+ e1 = exp(ww - p);
+ e2 = exp(kk - p);
+ ga = e1 * (aa + ga);
+ gb = e1 * (bb + gb);
+ aa = e1 * aa + e2 * vv;
+ bb = e1 * bb + e2;
+ pp = p;
+ }
+ const int _offsetBC = _b * C + _c;
+ _gw[_offsetBC] = gw * _w[_c]; // multiply by w because of w -> -exp(w) in python forward()
+ _gu[_offsetBC] = gu;
+
+ aa = 0, bb = 0, pp = MIN_VALUE;
+ for (int i = T - 1; i >= 0; i--) {
+ const int ii = i * C;
+ const F kk = k[ii];
+ const F vv = v[ii];
+ const F yy = y[ii];
+ const F qq = q[i];
+ const F rr = r[i];
+
+ F e1 = qq * exp(rr);
+ F e2 = exp(kk + pp);
+ gk[ii] = e1 * (vv - yy) + e2 * (aa * vv + bb);
+ gv[ii] = e1 + e2 * aa;
+
+ const F ww = w + pp;
+ const F www = rr - u - kk;
+ const F p = max(ww, www);
+ e1 = exp(ww - p);
+ e2 = qq * exp(www - p);
+ aa = e1 * aa + e2;
+ bb = e1 * bb - e2 * yy;
+ pp = p;
+ }
+ }
+
+ void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
+ dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+ assert(B * C % threadsPerBlock.x == 0);
+ dim3 numBlocks(B * C / threadsPerBlock.x);
+ kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
+ }
+
+ void cuda_forward_with_state(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *s) {
+ dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+ assert(B * C % threadsPerBlock.x == 0);
+ dim3 numBlocks(B * C / threadsPerBlock.x);
+ kernel_forward_with_state<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, s);
+ }
+
+ void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv) {
+ dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+ assert(B * C % threadsPerBlock.x == 0);
+ dim3 numBlocks(B * C / threadsPerBlock.x);
+ kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
+ }
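
Note: per (batch, channel) pair, kernel_forward computes y[t] = (sum over i < t of exp((t-1-i)*w + k[i]) * v[i], plus exp(u + k[t]) * v[t]), divided by the same sum with every v replaced by 1; the numerator and denominator are carried as aa and bb scaled by exp(-pp), where pp tracks the running maximum exponent (the log-sum-exp trick named in the kernel comment). A CPU reference of the same stabilized recurrence, for readability only (wkv_forward_cpu is a hypothetical name, not part of the extension):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// One channel of the WKV forward pass, mirroring kernel_forward line by line.
// aa/bb hold numerator/denominator scaled by exp(-pp); pp is the running max.
std::vector<float> wkv_forward_cpu(float w, float u,
                                   const std::vector<float> &k,
                                   const std::vector<float> &v) {
    const float MIN_VALUE = -1e38f;
    float aa = 0.f, bb = 0.f, pp = MIN_VALUE;
    std::vector<float> y(k.size());
    for (size_t t = 0; t < k.size(); ++t) {
        float ww = u + k[t];
        float p = std::max(pp, ww);
        float e1 = std::exp(pp - p), e2 = std::exp(ww - p);
        y[t] = (e1 * aa + e2 * v[t]) / (e1 * bb + e2);

        ww = w + pp;                     // decay the state by exp(w) per step
        p = std::max(ww, k[t]);
        e1 = std::exp(ww - p);
        e2 = std::exp(k[t] - p);
        aa = e1 * aa + e2 * v[t];
        bb = e1 * bb + e2;
        pp = p;
    }
    return y;
}

int main() {
    std::vector<float> k{0.1f, -0.2f, 0.3f}, v{1.f, 2.f, 3.f};
    for (float yt : wkv_forward_cpu(/*w=*/-0.5f, /*u=*/0.2f, k, v))
        std::printf("%f\n", yt);
    return 0;
}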
phivenv/Lib/site-packages/transformers/kernels/rwkv/wkv_cuda_bf16.cu ADDED
@@ -0,0 +1,186 @@
+ #include <stdio.h>
+ #include <assert.h>
+ #include "ATen/ATen.h"
+ #define MIN_VALUE (-1e38)
+ typedef at::BFloat16 bf16;
+
+ __global__ void kernel_forward_bf16(
+ const int B, const int T, const int C, const float *__restrict__ const _w, const bf16 *__restrict__ const _u,
+ const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, bf16 *__restrict__ const _y
+ ) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ const int _b = idx / C;
+ const int _c = idx % C;
+ const int _offset = _b * T * C + _c;
+
+ float u = float(_u[_c]);
+ float w = _w[_c];
+ const bf16 *__restrict__ const k = _k + _offset;
+ const bf16 *__restrict__ const v = _v + _offset;
+ bf16 *__restrict__ const y = _y + _offset;
+
+ // aa and bb are running sums divided by exp(pp) (to avoid overflow)
+ float aa = 0, bb = 0, pp = MIN_VALUE;
+ for (int i = 0; i < T; i++) {
+ const int ii = i * C;
+ const float kk = float(k[ii]);
+ const float vv = float(v[ii]);
+
+ float ww = u + kk;
+ float p = max(pp, ww);
+ float e1 = exp(pp - p);
+ float e2 = exp(ww - p);
+ y[ii] = bf16((e1 * aa + e2 * vv) / (e1 * bb + e2));
+
+ ww = w + pp;
+ p = max(ww, kk);
+ e1 = exp(ww - p);
+ e2 = exp(kk - p);
+ aa = e1 * aa + e2 * vv;
+ bb = e1 * bb + e2;
+ pp = p;
+ }
+ }
+
+ __global__ void kernel_forward_with_state_bf16(
+ const int B, const int T, const int C, const float *__restrict__ const _w, const bf16 *__restrict__ const _u,
+ const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, bf16 *__restrict__ const _y,
+ float *__restrict__ const _s
+ ) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ const int _b = idx / C;
+ const int _c = idx % C;
+ const int _offset_s = _b * C * 3 + _c * 3;
+ const int _offset = _b * T * C + _c;
+
+ float u = float(_u[_c]);
+ float w = _w[_c];
+ const bf16 *__restrict__ const k = _k + _offset;
+ const bf16 *__restrict__ const v = _v + _offset;
+ bf16 *__restrict__ const y = _y + _offset;
+ float *__restrict__ const s = _s + _offset_s;
+
+ // aa and bb are running sums divided by exp(pp) (to avoid overflow)
+ float aa = s[0], bb = s[1], pp = s[2];
+ for (int i = 0; i < T; i++) {
+ const int ii = i * C;
+ const float kk = float(k[ii]);
+ const float vv = float(v[ii]);
+
+ float ww = u + kk;
+ float p = max(pp, ww);
+ float e1 = exp(pp - p);
+ float e2 = exp(ww - p);
+ y[ii] = bf16((e1 * aa + e2 * vv) / (e1 * bb + e2));
+
+ ww = w + pp;
+ p = max(ww, kk);
+ e1 = exp(ww - p);
+ e2 = exp(kk - p);
+ aa = e1 * aa + e2 * vv;
+ bb = e1 * bb + e2;
+ pp = p;
+ }
+ s[0] = aa;
+ s[1] = bb;
+ s[2] = pp;
+ }
+
+ __global__ void kernel_backward_bf16(
+ const int B, const int T, const int C, const float *__restrict__ const _w, const bf16 *__restrict__ const _u,
+ const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, const bf16 *__restrict__ const _y,
+ const bf16 *__restrict__ const _gy, bf16 *__restrict__ const _gw, bf16 *__restrict__ const _gu,
+ bf16 *__restrict__ const _gk, bf16 *__restrict__ const _gv
+ ) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ const int _b = idx / C;
+ const int _c = idx % C;
+ const int _offset = _b * T * C + _c;
+
+ float u = float(_u[_c]);
+ float w = _w[_c];
+ const bf16 *__restrict__ const k = _k + _offset;
+ const bf16 *__restrict__ const v = _v + _offset;
+ const bf16 *__restrict__ const y = _y + _offset;
+ const bf16 *__restrict__ const gy = _gy + _offset;
+ bf16 *__restrict__ const gk = _gk + _offset;
+ bf16 *__restrict__ const gv = _gv + _offset;
+
+ float q[Tmax], r[Tmax];
+
+ float gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
+ for (int i = 0; i < T; i++) {
+ const int ii = i * C;
+ const float kk = float(k[ii]);
+ const float vv = float(v[ii]);
+ const float yy = float(y[ii]);
+
+ float ww = u + kk;
+ float p = max(pp, ww);
+ float e1 = exp(pp - p);
+ float e2 = exp(ww - p);
+ const float qq = float(gy[ii]) / (e1 * bb + e2);
+ gw += (ga - gb * yy) * e1 * qq;
+ gu += (vv - yy) * e2 * qq;
+ q[i] = qq;
+ r[i] = ww - p;
+
+ ww = w + pp;
+ p = max(ww, kk);
+ e1 = exp(ww - p);
+ e2 = exp(kk - p);
+ ga = e1 * (aa + ga);
+ gb = e1 * (bb + gb);
+ aa = e1 * aa + e2 * vv;
+ bb = e1 * bb + e2;
+ pp = p;
+ }
+ const int _offsetBC = _b * C + _c;
+ _gw[_offsetBC] = bf16(gw * _w[_c]); // multiply by w because of w -> -exp(w) in python forward()
+ _gu[_offsetBC] = bf16(gu);
+
+ aa = 0, bb = 0, pp = MIN_VALUE;
+ for (int i = T - 1; i >= 0; i--) {
+ const int ii = i * C;
+ const float kk = float(k[ii]);
+ const float vv = float(v[ii]);
+ const float yy = float(y[ii]);
+ const float qq = q[i];
+ const float rr = r[i];
+
+ float e1 = qq * exp(rr);
+ float e2 = exp(kk + pp);
+ gk[ii] = bf16(e1 * (vv - yy) + e2 * (aa * vv + bb));
+ gv[ii] = bf16(e1 + e2 * aa);
+
+ const float ww = w + pp;
+ const float www = rr - u - kk;
+ const float p = max(ww, www);
+ e1 = exp(ww - p);
+ e2 = qq * exp(www - p);
+ aa = e1 * aa + e2;
+ bb = e1 * bb - e2 * yy;
+ pp = p;
+ }
+ }
+
+ void cuda_forward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y) {
+ dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+ assert(B * C % threadsPerBlock.x == 0);
+ dim3 numBlocks(B * C / threadsPerBlock.x);
+ kernel_forward_bf16<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
+ }
+
+ void cuda_forward_with_state_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, float *s) {
+ dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+ assert(B * C % threadsPerBlock.x == 0);
+ dim3 numBlocks(B * C / threadsPerBlock.x);
+ kernel_forward_with_state_bf16<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, s);
+ }
+
+ void cuda_backward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv) {
+ dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+ assert(B * C % threadsPerBlock.x == 0);
+ dim3 numBlocks(B * C / threadsPerBlock.x);
+ kernel_backward_bf16<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
+ }
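
Note: the bf16 variants read and write activations in bfloat16 but keep w, the recurrent state (aa, bb, pp), and all intermediate arithmetic in float; bfloat16 carries only 8 mantissa bits, so accumulating the running sums directly in it would quickly swallow the small exp terms. A standalone sketch of that effect (to_bf16 is a round-to-nearest-even stand-in for bf16 storage, not the at::BFloat16 implementation):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Emulate bf16 storage: keep the top 16 bits of a float32, rounding to
// nearest even (NaN/Inf handling omitted for brevity).
float to_bf16(float x) {
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof bits);
    uint32_t rounded = bits + 0x7FFFu + ((bits >> 16) & 1u);
    rounded &= 0xFFFF0000u;
    float out;
    std::memcpy(&out, &rounded, sizeof out);
    return out;
}

int main() {
    float acc_f32 = 0.f, acc_bf16 = 0.f;
    for (int i = 0; i < 1000; ++i) {
        acc_f32 += 0.001f;
        acc_bf16 = to_bf16(acc_bf16 + 0.001f);  // truncate the state every step
    }
    // The bf16 accumulator drifts visibly away from the float32 one.
    std::printf("float32 sum: %f\nbf16 sum:    %f\n", acc_f32, acc_bf16);
    return 0;
}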
phivenv/Lib/site-packages/transformers/kernels/rwkv/wkv_op.cpp ADDED
@@ -0,0 +1,66 @@
+ #include <torch/extension.h>
+ #include "ATen/ATen.h"
+ typedef at::BFloat16 bf16;
+
+ void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
+ void cuda_forward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y);
+ void cuda_forward_with_state(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *s);
+ void cuda_forward_with_state_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, float *s);
+ void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv);
+ void cuda_backward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv);
+
+ void forward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
+ const int B = k.size(0);
+ const int T = k.size(1);
+ const int C = k.size(2);
+ cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
+ }
+ void forward_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
+ const int B = k.size(0);
+ const int T = k.size(1);
+ const int C = k.size(2);
+ cuda_forward_bf16(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>());
+ }
+ void forward_with_state(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &s) {
+ const int B = k.size(0);
+ const int T = k.size(1);
+ const int C = k.size(2);
+ cuda_forward_with_state(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), s.data_ptr<float>());
+ }
+ void forward_with_state_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &s) {
+ const int B = k.size(0);
+ const int T = k.size(1);
+ const int C = k.size(2);
+ cuda_forward_with_state_bf16(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>(), s.data_ptr<float>());
+ }
+ void backward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
+ const int B = k.size(0);
+ const int T = k.size(1);
+ const int C = k.size(2);
+ cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
+ }
+ void backward_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
+ const int B = k.size(0);
+ const int T = k.size(1);
+ const int C = k.size(2);
+ cuda_backward_bf16(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>(),
+ gy.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>());
+ }
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("forward", &forward, "wkv forward");
+ m.def("forward_bf16", &forward_bf16, "wkv forward bf16");
+ m.def("forward_with_state", &forward_with_state, "wkv forward with state");
+ m.def("forward_with_state_bf16", &forward_with_state_bf16, "wkv forward with state bf16");
+ m.def("backward", &backward, "wkv backward");
+ m.def("backward_bf16", &backward_bf16, "wkv backward bf16");
+ }
+
+ TORCH_LIBRARY(wkv, m) {
+ m.def("forward", forward);
+ m.def("forward_bf16", forward_bf16);
+ m.def("forward_with_state", forward_with_state);
+ m.def("forward_with_state_bf16", forward_with_state_bf16);
+ m.def("backward", backward);
+ m.def("backward_bf16", backward_bf16);
+ }
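
Note: the ops are registered twice, via pybind11 for a plain module import and via TORCH_LIBRARY so they are also reachable as torch.ops.wkv.*; each wrapper re-derives B, T, C from k's shape, and the caller chooses the fp32 or bf16 entry point. A hypothetical dispatch helper that picks by dtype instead, assuming it were appended to this file (forward_auto is not part of the extension):

// Route to the fp32 or bf16 entry point from k's dtype instead of making
// the Python caller choose; relies on forward/forward_bf16 defined above.
void forward_auto(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k,
                  torch::Tensor &v, torch::Tensor &y) {
    if (k.scalar_type() == at::kBFloat16) {
        forward_bf16(w, u, k, v, y);
    } else {
        TORCH_CHECK(k.scalar_type() == at::kFloat, "wkv: expected float32 or bfloat16");
        forward(w, u, k, v, y);
    }
}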
phivenv/Lib/site-packages/transformers/kernels/yoso/common.h ADDED
@@ -0,0 +1,10 @@
+
+ #define min(a, b) ((a)<(b)?(a):(b))
+ #define max(a, b) ((a)>(b)?(a):(b))
+ #define ceil_divide(a, b) ((a)/(b)+((a)%(b)!=0))
+ #define select(cond, a, b) ((cond)?(a):(b))
+ #define PI 3.141592
+ #define EPSILON 1e-8
+ #define MAX_VAL 1e12
+ #define MIN_VAL -1e12
+ #define EMPTY_VALUE -1
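
Note: ceil_divide computes the integer ceiling of a/b for non-negative operands as a/b + (a%b != 0); unlike the common (a + b - 1)/b form, it cannot overflow for large a. A standalone check of the identity (sketch only; the macro is copied from the header above):

#include <cassert>

#define ceil_divide(a, b) ((a)/(b)+((a)%(b)!=0))

int main() {
    assert(ceil_divide(10, 5) == 2);  // exact division: no bump
    assert(ceil_divide(11, 5) == 3);  // any remainder bumps the quotient
    assert(ceil_divide(0, 5) == 0);   // zero stays zero
    return 0;
}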