koichi12 commited on
Commit
7754566
·
verified ·
1 Parent(s): 98595c0

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. .venv/lib/python3.11/site-packages/pkg_resources/__pycache__/__init__.cpython-311.pyc +3 -0
  3. .venv/lib/python3.11/site-packages/pkg_resources/_vendor/pyparsing/__pycache__/core.cpython-311.pyc +3 -0
  4. .venv/lib/python3.11/site-packages/transformers/__init__.py +0 -0
  5. .venv/lib/python3.11/site-packages/transformers/activations.py +239 -0
  6. .venv/lib/python3.11/site-packages/transformers/activations_tf.py +147 -0
  7. .venv/lib/python3.11/site-packages/transformers/audio_utils.py +1123 -0
  8. .venv/lib/python3.11/site-packages/transformers/cache_utils.py +0 -0
  9. .venv/lib/python3.11/site-packages/transformers/configuration_utils.py +1187 -0
  10. .venv/lib/python3.11/site-packages/transformers/convert_graph_to_onnx.py +551 -0
  11. .venv/lib/python3.11/site-packages/transformers/convert_pytorch_checkpoint_to_tf2.py +446 -0
  12. .venv/lib/python3.11/site-packages/transformers/convert_slow_tokenizer.py +1642 -0
  13. .venv/lib/python3.11/site-packages/transformers/convert_slow_tokenizers_checkpoints_to_fast.py +130 -0
  14. .venv/lib/python3.11/site-packages/transformers/convert_tf_hub_seq_to_seq_bert_to_pytorch.py +87 -0
  15. .venv/lib/python3.11/site-packages/transformers/debug_utils.py +346 -0
  16. .venv/lib/python3.11/site-packages/transformers/dependency_versions_check.py +63 -0
  17. .venv/lib/python3.11/site-packages/transformers/dependency_versions_table.py +102 -0
  18. .venv/lib/python3.11/site-packages/transformers/dynamic_module_utils.py +685 -0
  19. .venv/lib/python3.11/site-packages/transformers/feature_extraction_sequence_utils.py +372 -0
  20. .venv/lib/python3.11/site-packages/transformers/hf_argparser.py +437 -0
  21. .venv/lib/python3.11/site-packages/transformers/hyperparameter_search.py +141 -0
  22. .venv/lib/python3.11/site-packages/transformers/image_processing_base.py +559 -0
  23. .venv/lib/python3.11/site-packages/transformers/image_processing_utils.py +287 -0
  24. .venv/lib/python3.11/site-packages/transformers/image_processing_utils_fast.py +133 -0
  25. .venv/lib/python3.11/site-packages/transformers/image_transforms.py +860 -0
  26. .venv/lib/python3.11/site-packages/transformers/image_utils.py +871 -0
  27. .venv/lib/python3.11/site-packages/transformers/keras_callbacks.py +413 -0
  28. .venv/lib/python3.11/site-packages/transformers/kernels/__init__.py +0 -0
  29. .venv/lib/python3.11/site-packages/transformers/kernels/deformable_detr/cuda/ms_deform_im2col_cuda.cuh +1327 -0
  30. .venv/lib/python3.11/site-packages/transformers/kernels/deformable_detr/ms_deform_attn.h +61 -0
  31. .venv/lib/python3.11/site-packages/transformers/kernels/deta/cpu/ms_deform_attn_cpu.cpp +40 -0
  32. .venv/lib/python3.11/site-packages/transformers/kernels/deta/cpu/ms_deform_attn_cpu.h +32 -0
  33. .venv/lib/python3.11/site-packages/transformers/kernels/deta/cuda/ms_deform_attn_cuda.cu +156 -0
  34. .venv/lib/python3.11/site-packages/transformers/kernels/deta/cuda/ms_deform_attn_cuda.cuh +1467 -0
  35. .venv/lib/python3.11/site-packages/transformers/kernels/deta/cuda/ms_deform_attn_cuda.h +29 -0
  36. .venv/lib/python3.11/site-packages/transformers/kernels/deta/cuda/ms_deform_im2col_cuda.cuh +1327 -0
  37. .venv/lib/python3.11/site-packages/transformers/kernels/deta/ms_deform_attn.h +61 -0
  38. .venv/lib/python3.11/site-packages/transformers/kernels/deta/vision.cpp +16 -0
  39. .venv/lib/python3.11/site-packages/transformers/kernels/falcon_mamba/__pycache__/selective_scan_with_ln_interface.cpython-311.pyc +0 -0
  40. .venv/lib/python3.11/site-packages/transformers/kernels/rwkv/wkv_cuda.cu +187 -0
  41. .venv/lib/python3.11/site-packages/transformers/kernels/rwkv/wkv_cuda_bf16.cu +186 -0
  42. .venv/lib/python3.11/site-packages/transformers/kernels/rwkv/wkv_op.cpp +66 -0
  43. .venv/lib/python3.11/site-packages/transformers/kernels/yoso/common.h +10 -0
  44. .venv/lib/python3.11/site-packages/transformers/kernels/yoso/fast_lsh_cumulation.cu +588 -0
  45. .venv/lib/python3.11/site-packages/transformers/kernels/yoso/fast_lsh_cumulation.h +71 -0
  46. .venv/lib/python3.11/site-packages/transformers/kernels/yoso/fast_lsh_cumulation_cuda.cu +825 -0
  47. .venv/lib/python3.11/site-packages/transformers/kernels/yoso/fast_lsh_cumulation_cuda.h +157 -0
  48. .venv/lib/python3.11/site-packages/transformers/kernels/yoso/fast_lsh_cumulation_torch.cpp +128 -0
  49. .venv/lib/python3.11/site-packages/transformers/modelcard.py +908 -0
  50. .venv/lib/python3.11/site-packages/transformers/modeling_attn_mask_utils.py +481 -0
.gitattributes CHANGED
@@ -419,3 +419,5 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/
419
  .venv/lib/python3.11/site-packages/jiter/jiter.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
420
  .venv/lib/python3.11/site-packages/idna/__pycache__/uts46data.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
421
  .venv/lib/python3.11/site-packages/idna/__pycache__/idnadata.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
 
 
 
419
  .venv/lib/python3.11/site-packages/jiter/jiter.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
420
  .venv/lib/python3.11/site-packages/idna/__pycache__/uts46data.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
421
  .venv/lib/python3.11/site-packages/idna/__pycache__/idnadata.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
422
+ .venv/lib/python3.11/site-packages/pkg_resources/_vendor/pyparsing/__pycache__/core.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
423
+ .venv/lib/python3.11/site-packages/pkg_resources/__pycache__/__init__.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/pkg_resources/__pycache__/__init__.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ae8fd4ca816177bb2c6f471e0b6c7334eb9caa4704a71b0935e8abf1ca1a36d2
3
+ size 159566
.venv/lib/python3.11/site-packages/pkg_resources/_vendor/pyparsing/__pycache__/core.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c669d4b4c31b91773ae2bd09aa2fd0eb809698068c77850d9ca93283b9acc875
3
+ size 277639
.venv/lib/python3.11/site-packages/transformers/__init__.py ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.11/site-packages/transformers/activations.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from collections import OrderedDict
17
+
18
+ import torch
19
+ from packaging import version
20
+ from torch import Tensor, nn
21
+
22
+ from .utils import logging
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
class PytorchGELUTanh(nn.Module):
    """
    Tanh-approximated GELU delegated to PyTorch's fast C implementation. See
    https://arxiv.org/abs/1606.08415.

    Matches NewGELU and FastGELU up to rounding error while being much faster.
    """

    def __init__(self):
        super().__init__()
        # The `approximate` kwarg of nn.functional.gelu only exists from torch 1.12.
        if version.parse(torch.__version__) < version.parse("1.12.0"):
            raise ImportError(
                f"You are using torch=={torch.__version__}, but torch>=1.12.0 is required to use "
                "PytorchGELUTanh. Please upgrade torch."
            )

    def forward(self, input: Tensor) -> Tensor:
        return nn.functional.gelu(input, approximate="tanh")
47
+
48
+
49
class NewGELUActivation(nn.Module):
    """
    GELU as implemented in the Google BERT repo (identical to OpenAI GPT), using the tanh
    approximation from the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
    """

    def forward(self, input: Tensor) -> Tensor:
        inner = math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))
        return 0.5 * input * (1.0 + torch.tanh(inner))
57
+
58
+
59
class GELUActivation(nn.Module):
    """
    Exact (erf-based) GELU. The original pure-python BERT implementation is kept behind
    `use_gelu_python=True`; otherwise the C implementation in nn.functional is used.
    OpenAI GPT's tanh-approximated variant gives slightly different results.
    See https://arxiv.org/abs/1606.08415.
    """

    def __init__(self, use_gelu_python: bool = False):
        super().__init__()
        # Bind the chosen implementation once at construction time.
        self.act = self._gelu_python if use_gelu_python else nn.functional.gelu

    def _gelu_python(self, input: Tensor) -> Tensor:
        return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))

    def forward(self, input: Tensor) -> Tensor:
        return self.act(input)
79
+
80
+
81
class FastGELUActivation(nn.Module):
    """
    GELU approximation: slower than QuickGELU but more accurate.
    See: https://github.com/hendrycks/GELUs
    """

    def forward(self, input: Tensor) -> Tensor:
        inner = input * 0.7978845608 * (1.0 + 0.044715 * input * input)
        return 0.5 * input * (1.0 + torch.tanh(inner))
88
+
89
+
90
class QuickGELUActivation(nn.Module):
    """
    GELU approximation that is fast but somewhat inaccurate.
    See: https://github.com/hendrycks/GELUs
    """

    def forward(self, input: Tensor) -> Tensor:
        # sigmoid(1.702 x) approximates the Gaussian CDF used by exact GELU.
        gate = torch.sigmoid(1.702 * input)
        return input * gate
97
+
98
+
99
class ClippedGELUActivation(nn.Module):
    """
    GELU whose output is clipped to the range [min, max]. Useful for quantization, since it
    bounds the negative tail of the GELU spectrum; see https://arxiv.org/abs/2004.09602.

    Uses the module-level `gelu` instance (exact, erf-based GELU) defined at the bottom of
    this file. See https://arxiv.org/abs/1606.08415 for GELU itself.
    """

    def __init__(self, min: float, max: float):
        # Validate before any Module state is set up.
        if min > max:
            raise ValueError(f"min should be < max (got min: {min}, max: {max})")

        super().__init__()
        self.min = min
        self.max = max

    def forward(self, x: Tensor) -> Tensor:
        # NOTE: `gelu` resolves at call time to the module-level singleton below.
        return torch.clip(gelu(x), self.min, self.max)
122
+
123
+
124
class AccurateGELUActivation(nn.Module):
    """
    GELU approximation faster than the default and more accurate than QuickGELU.
    See: https://github.com/hendrycks/GELUs

    Implemented along with MEGA (Moving Average Equipped Gated Attention).
    """

    def __init__(self):
        super().__init__()
        # sqrt(2/pi) is loop-invariant, so compute it once up front.
        self.precomputed_constant = math.sqrt(2 / math.pi)

    def forward(self, input: Tensor) -> Tensor:
        cubic_term = input + 0.044715 * torch.pow(input, 3)
        return 0.5 * input * (1 + torch.tanh(self.precomputed_constant * cubic_term))
138
+
139
+
140
class MishActivation(nn.Module):
    """
    Mish activation: x * tanh(softplus(x)).
    See "Mish: A Self-Regularized Non-Monotonic Activation Function" (Misra,
    https://arxiv.org/abs/1908.08681) and https://github.com/digantamisra98/Mish
    """

    def __init__(self):
        super().__init__()
        # torch ships a native C implementation of mish from 1.9.0 onwards;
        # keep the pure-python fallback for older versions.
        if version.parse(torch.__version__) >= version.parse("1.9.0"):
            self.act = nn.functional.mish
        else:
            self.act = self._mish_python

    def _mish_python(self, input: Tensor) -> Tensor:
        return input * torch.tanh(nn.functional.softplus(input))

    def forward(self, input: Tensor) -> Tensor:
        return self.act(input)
158
+
159
+
160
class LinearActivation(nn.Module):
    """
    Identity activation: the input is forwarded to the output unchanged.
    """

    def forward(self, input: Tensor) -> Tensor:
        return input
167
+
168
+
169
class LaplaceActivation(nn.Module):
    """
    Elementwise activation based on the Laplace (Gaussian CDF) function, introduced in MEGA
    as an attention activation. See https://arxiv.org/abs/2209.10655

    Inspired by squared ReLU, but with bounded range and gradient for better stability.
    The defaults mu=0.707107 (~1/sqrt(2)) and sigma=0.282095 (~1/sqrt(2*pi)) calibrate the
    curve to resemble squared ReLU.
    """

    def forward(self, input, mu=0.707107, sigma=0.282095):
        standardized = (input - mu).div(sigma * math.sqrt(2.0))
        return 0.5 * (1.0 + torch.erf(standardized))
180
+
181
+
182
class ReLUSquaredActivation(nn.Module):
    """
    relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
    """

    def forward(self, input):
        return torch.square(nn.functional.relu(input))
191
+
192
+
193
class ClassInstantier(OrderedDict):
    """
    OrderedDict whose values are classes (optionally paired with kwargs in a tuple);
    `__getitem__` returns a fresh instance of the stored class on every lookup.
    """

    def __getitem__(self, key):
        entry = super().__getitem__(key)
        if isinstance(entry, tuple):
            cls, kwargs = entry
        else:
            cls, kwargs = entry, {}
        return cls(**kwargs)
198
+
199
+
200
# Registry mapping activation-name strings to the implementing nn.Module class,
# optionally paired with the kwargs used to instantiate it.
ACT2CLS = {
    "gelu": GELUActivation,
    "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
    "gelu_fast": FastGELUActivation,
    "gelu_new": NewGELUActivation,
    "gelu_python": (GELUActivation, {"use_gelu_python": True}),
    "gelu_pytorch_tanh": PytorchGELUTanh,
    "gelu_accurate": AccurateGELUActivation,
    "laplace": LaplaceActivation,
    "leaky_relu": nn.LeakyReLU,
    "linear": LinearActivation,
    "mish": MishActivation,
    "quick_gelu": QuickGELUActivation,
    "relu": nn.ReLU,
    "relu2": ReLUSquaredActivation,
    "relu6": nn.ReLU6,
    "sigmoid": nn.Sigmoid,
    "silu": nn.SiLU,
    "swish": nn.SiLU,
    "tanh": nn.Tanh,
}
# Dict-like view over ACT2CLS that instantiates the stored class on each lookup,
# so `ACT2FN["gelu"]` yields a ready-to-use module.
ACT2FN = ClassInstantier(ACT2CLS)
222
+
223
+
224
def get_activation(activation_string):
    """Return a fresh instance of the activation registered under `activation_string`.

    Raises:
        KeyError: if the name is not present in the `ACT2FN` registry.
    """
    if activation_string not in ACT2FN:
        raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
    return ACT2FN[activation_string]
229
+
230
+
231
# For backwards compatibility with: from activations import gelu_python
# Each of these is a module-level singleton instance of the corresponding activation,
# shared by every importer of this module.
gelu_python = get_activation("gelu_python")
gelu_new = get_activation("gelu_new")
gelu = get_activation("gelu")
gelu_fast = get_activation("gelu_fast")
quick_gelu = get_activation("quick_gelu")
silu = get_activation("silu")
mish = get_activation("mish")
linear_act = get_activation("linear")
.venv/lib/python3.11/site-packages/transformers/activations_tf.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+
17
+ import tensorflow as tf
18
+ from packaging.version import parse
19
+
20
+
21
# Prefer the tf-keras compatibility shim; otherwise fall back to plain `keras`
# and verify it is a Keras 2.x release, since Keras 3 is not supported here.
try:
    import tf_keras as keras
except (ModuleNotFoundError, ImportError):
    import keras

    if parse(keras.__version__).major > 2:
        raise ValueError(
            "Your currently installed version of Keras is Keras 3, but this is not yet supported in "
            "Transformers. Please install the backwards-compatible tf-keras package with "
            "`pip install tf-keras`."
        )
32
+
33
+
34
def _gelu(x):
    """
    Exact (erf-based) Gaussian Error Linear Unit, matching the original implementation in
    the Google BERT repo. OpenAI GPT's gelu is slightly different (and gives slightly
    different results): 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
    See https://arxiv.org/abs/1606.08415
    """
    x = tf.convert_to_tensor(x)
    root_two = tf.cast(tf.sqrt(2.0), x.dtype)
    gaussian_cdf = 0.5 * (1.0 + tf.math.erf(x / root_two))
    return x * gaussian_cdf
45
+
46
+
47
def _gelu_new(x):
    """
    Smoother (tanh-approximated) version of the GELU.
    Original paper: https://arxiv.org/abs/1606.08415

    Args:
        x: float Tensor to perform activation

    Returns:
        `x` with the GELU activation applied.
    """
    x = tf.convert_to_tensor(x)
    pi = tf.cast(math.pi, x.dtype)
    coeff = tf.cast(0.044715, x.dtype)
    inner = tf.sqrt(2.0 / pi) * (x + coeff * tf.pow(x, 3))
    return x * (0.5 * (1.0 + tf.tanh(inner)))
63
+
64
+
65
def mish(x):
    """Mish activation: x * tanh(softplus(x)). See https://arxiv.org/abs/1908.08681"""
    x = tf.convert_to_tensor(x)
    gate = tf.tanh(tf.math.softplus(x))
    return x * gate
69
+
70
+
71
def gelu_fast(x):
    """Fast tanh-based GELU approximation. See https://github.com/hendrycks/GELUs"""
    x = tf.convert_to_tensor(x)
    cubic_coeff = tf.cast(0.044715, x.dtype)
    sqrt_2_over_pi = tf.cast(0.7978845608, x.dtype)
    inner = x * sqrt_2_over_pi * (1.0 + cubic_coeff * x * x)
    return 0.5 * x * (1.0 + tf.tanh(inner))
77
+
78
+
79
def quick_gelu(x):
    """Sigmoid-based GELU approximation: x * sigmoid(1.702 * x)."""
    x = tf.convert_to_tensor(x)
    scale = tf.cast(1.702, x.dtype)
    return x * tf.math.sigmoid(scale * x)
83
+
84
+
85
def gelu_10(x):
    """
    Exact GELU with outputs clipped to the range [-10, 10]. Especially useful for
    quantization, since it bounds the negative tail of the GELU spectrum; for more
    information on this trick, please refer to https://arxiv.org/abs/2004.09602

    Delegates the activation itself to `_gelu` (the original erf-based BERT GELU);
    see https://arxiv.org/abs/1606.08415 :param x: :return:
    """
    activated = _gelu(x)
    return tf.clip_by_value(activated, -10, 10)
97
+
98
+
99
def glu(x, axis=-1):
    """
    Gated Linear Unit as defined in the original paper (https://arxiv.org/abs/1612.08083):
    the input `x` is split in two halves A and B across `axis`, returning A * sigmoid(B).

    Args:
        `x`: float Tensor to perform activation
        `axis`: dimension across which `x` be split in half

    Returns:
        `x` with the GLU activation applied (with its size halved across the dimension `axis`).
    """
    value_half, gate_half = tf.split(x, 2, axis=axis)
    return value_half * tf.math.sigmoid(gate_half)
113
+
114
+
115
# keras.activations.gelu (with the `approximate` kwarg) only exists from TF 2.4 onwards;
# on older TF fall back to the pure-tf implementations defined above.
if parse(tf.version.VERSION) >= parse("2.4"):

    def approximate_gelu_wrap(x):
        # tanh-approximated gelu, matching the semantics of `_gelu_new`.
        return keras.activations.gelu(x, approximate=True)

    gelu = keras.activations.gelu
    gelu_new = approximate_gelu_wrap
else:
    gelu = _gelu
    gelu_new = _gelu_new
125
+
126
+
127
# Registry mapping activation-name strings to the TF/Keras callables defined above.
# Unlike the torch registry, values here are functions, not classes.
ACT2FN = {
    "gelu": gelu,
    "gelu_10": gelu_10,
    "gelu_fast": gelu_fast,
    "gelu_new": gelu_new,
    "glu": glu,
    "mish": mish,
    "quick_gelu": quick_gelu,
    "relu": keras.activations.relu,
    "sigmoid": keras.activations.sigmoid,
    "silu": keras.activations.swish,
    "swish": keras.activations.swish,
    "tanh": keras.activations.tanh,
}
141
+
142
+
143
def get_tf_activation(activation_string):
    """Look up a TF/Keras activation callable by its string name.

    Raises:
        KeyError: if the name is not present in the `ACT2FN` registry.
    """
    if activation_string not in ACT2FN:
        raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
    return ACT2FN[activation_string]
.venv/lib/python3.11/site-packages/transformers/audio_utils.py ADDED
@@ -0,0 +1,1123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team and the librosa & torchaudio authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Audio processing functions to extract features from audio waveforms. This code is pure numpy to support all frameworks
17
+ and remove unnecessary dependencies.
18
+ """
19
+
20
+ import warnings
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import numpy as np
24
+
25
+
26
def hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]:
    """
    Convert frequency from hertz to mels.

    Args:
        freq (`float` or `np.ndarray`):
            The frequency, or multiple frequencies, in hertz (Hz).
        mel_scale (`str`, *optional*, defaults to `"htk"`):
            The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.

    Returns:
        `float` or `np.ndarray`: The frequencies on the mel scale.
    """
    if mel_scale not in ["slaney", "htk", "kaldi"]:
        raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".')

    # htk and kaldi are purely logarithmic warps (base-10 vs natural log).
    if mel_scale == "htk":
        return 2595.0 * np.log10(1.0 + (freq / 700.0))
    if mel_scale == "kaldi":
        return 1127.0 * np.log(1.0 + (freq / 700.0))

    # "slaney": linear below 1 kHz, logarithmic above.
    break_hertz = 1000.0
    break_mel = 15.0
    log_slope = 27.0 / np.log(6.4)
    mels = 3.0 * freq / 200.0

    if isinstance(freq, np.ndarray):
        log_region = freq >= break_hertz
        mels[log_region] = break_mel + np.log(freq[log_region] / break_hertz) * log_slope
    elif freq >= break_hertz:
        mels = break_mel + np.log(freq / break_hertz) * log_slope

    return mels
60
+
61
+
62
def mel_to_hertz(mels: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]:
    """
    Convert frequency from mels to hertz.

    Args:
        mels (`float` or `np.ndarray`):
            The frequency, or multiple frequencies, in mels.
        mel_scale (`str`, *optional*, `"htk"`):
            The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.

    Returns:
        `float` or `np.ndarray`: The frequencies in hertz.
    """
    if mel_scale not in ["slaney", "htk", "kaldi"]:
        raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".')

    # htk and kaldi invert the corresponding log warps in `hertz_to_mel`.
    if mel_scale == "htk":
        return 700.0 * (np.power(10, mels / 2595.0) - 1.0)
    if mel_scale == "kaldi":
        return 700.0 * (np.exp(mels / 1127.0) - 1.0)

    # "slaney": linear below 15 mel (1 kHz), exponential above.
    break_hertz = 1000.0
    break_mel = 15.0
    log_slope = np.log(6.4) / 27.0
    freq = 200.0 * mels / 3.0

    if isinstance(mels, np.ndarray):
        log_region = mels >= break_mel
        freq[log_region] = break_hertz * np.exp(log_slope * (mels[log_region] - break_mel))
    elif mels >= break_mel:
        freq = break_hertz * np.exp(log_slope * (mels - break_mel))

    return freq
96
+
97
+
98
def hertz_to_octave(
    freq: Union[float, np.ndarray], tuning: Optional[float] = 0.0, bins_per_octave: Optional[int] = 12
):
    """
    Convert frequency from hertz to fractional octave numbers.
    Adapted from *librosa*.

    Args:
        freq (`float` or `np.ndarray`):
            The frequency, or multiple frequencies, in hertz (Hz).
        tuning (`float`, defaults to `0.`):
            Tuning deviation from the Stuttgart pitch (A440) in (fractional) bins per octave.
        bins_per_octave (`int`, defaults to `12`):
            Number of bins per octave.

    Returns:
        `float` or `np.ndarray`: The frequencies on the octave scale.
    """
    # Tuning-adjusted A440, then anchor octave 0 four octaves (a factor of 16) below it.
    stuttgart_pitch = 440.0 * 2.0 ** (tuning / bins_per_octave)
    reference = float(stuttgart_pitch) / 16
    return np.log2(freq / reference)
119
+
120
+
121
+ def _create_triangular_filter_bank(fft_freqs: np.ndarray, filter_freqs: np.ndarray) -> np.ndarray:
122
+ """
123
+ Creates a triangular filter bank.
124
+
125
+ Adapted from *torchaudio* and *librosa*.
126
+
127
+ Args:
128
+ fft_freqs (`np.ndarray` of shape `(num_frequency_bins,)`):
129
+ Discrete frequencies of the FFT bins in Hz.
130
+ filter_freqs (`np.ndarray` of shape `(num_mel_filters,)`):
131
+ Center frequencies of the triangular filters to create, in Hz.
132
+
133
+ Returns:
134
+ `np.ndarray` of shape `(num_frequency_bins, num_mel_filters)`
135
+ """
136
+ filter_diff = np.diff(filter_freqs)
137
+ slopes = np.expand_dims(filter_freqs, 0) - np.expand_dims(fft_freqs, 1)
138
+ down_slopes = -slopes[:, :-2] / filter_diff[:-1]
139
+ up_slopes = slopes[:, 2:] / filter_diff[1:]
140
+ return np.maximum(np.zeros(1), np.minimum(down_slopes, up_slopes))
141
+
142
+
143
def chroma_filter_bank(
    num_frequency_bins: int,
    num_chroma: int,
    sampling_rate: int,
    tuning: float = 0.0,
    power: Optional[float] = 2.0,
    weighting_parameters: Optional[Tuple[float]] = (5.0, 2),
    start_at_c_chroma: Optional[bool] = True,
):
    """
    Creates a chroma filter bank, i.e a linear transformation to project spectrogram bins onto chroma bins.

    Adapted from *librosa*.

    Args:
        num_frequency_bins (`int`):
            Number of frequencies used to compute the spectrogram (should be the same as in `stft`).
        num_chroma (`int`):
            Number of chroma bins (i.e pitch classes).
        sampling_rate (`float`):
            Sample rate of the audio waveform.
        tuning (`float`):
            Tuning deviation from A440 in fractions of a chroma bin.
        power (`float`, *optional*, defaults to 2.0):
            If 2.0, normalizes each column with their L2 norm. If 1.0, normalizes each column with their L1 norm.
        weighting_parameters (`Tuple[float]`, *optional*, defaults to `(5., 2.)`):
            If specified, apply a Gaussian weighting parameterized by the first element of the tuple being the center and
            the second element being the Gaussian half-width.
        start_at_c_chroma (`float`, *optional*, defaults to `True`):
            If True, the filter bank will start at the 'C' pitch class. Otherwise, it will start at 'A'.
    Returns:
        `np.ndarray` of shape `(num_frequency_bins, num_chroma)`
    """
    # Get the FFT bins, not counting the DC component
    frequencies = np.linspace(0, sampling_rate, num_frequency_bins, endpoint=False)[1:]

    # Fractional chroma-bin position of each FFT bin (octave number scaled by num_chroma).
    freq_bins = num_chroma * hertz_to_octave(frequencies, tuning=tuning, bins_per_octave=num_chroma)

    # make up a value for the 0 Hz bin = 1.5 octaves below bin 1
    # (so chroma is 50% rotated from bin 1, and bin width is broad)
    freq_bins = np.concatenate(([freq_bins[0] - 1.5 * num_chroma], freq_bins))

    # Width of each bin in chroma units, floored at 1; last bin reuses width 1.
    bins_width = np.concatenate((np.maximum(freq_bins[1:] - freq_bins[:-1], 1.0), [1]))

    # Pairwise distance from every chroma class to every FFT bin position.
    chroma_filters = np.subtract.outer(freq_bins, np.arange(0, num_chroma, dtype="d")).T

    num_chroma2 = np.round(float(num_chroma) / 2)

    # Project into range -num_chroma/2 .. num_chroma/2
    # add on fixed offset of 10*num_chroma to ensure all values passed to
    # rem are positive
    chroma_filters = np.remainder(chroma_filters + num_chroma2 + 10 * num_chroma, num_chroma) - num_chroma2

    # Gaussian bumps - 2*D to make them narrower
    chroma_filters = np.exp(-0.5 * (2 * chroma_filters / np.tile(bins_width, (num_chroma, 1))) ** 2)

    # normalize each column
    if power is not None:
        chroma_filters = chroma_filters / np.sum(chroma_filters**power, axis=0, keepdims=True) ** (1.0 / power)

    # Maybe apply scaling for fft bins
    if weighting_parameters is not None:
        center, half_width = weighting_parameters
        chroma_filters *= np.tile(
            np.exp(-0.5 * (((freq_bins / num_chroma - center) / half_width) ** 2)),
            (num_chroma, 1),
        )

    if start_at_c_chroma:
        # Rotate so the filter bank starts at 'C' (3 semitones below 'A' per 12-bin octave).
        chroma_filters = np.roll(chroma_filters, -3 * (num_chroma // 12), axis=0)

    # remove aliasing columns, copy to ensure row-contiguity
    return np.ascontiguousarray(chroma_filters[:, : int(1 + num_frequency_bins / 2)])
216
+
217
+
218
def mel_filter_bank(
    num_frequency_bins: int,
    num_mel_filters: int,
    min_frequency: float,
    max_frequency: float,
    sampling_rate: int,
    norm: Optional[str] = None,
    mel_scale: str = "htk",
    triangularize_in_mel_space: bool = False,
) -> np.ndarray:
    """
    Creates a frequency bin conversion matrix used to obtain a mel spectrogram, known as a *mel filter bank*. Many
    variants exist in the literature, differing in the number of filters, their shape and spacing, their bandwidth,
    and how the spectrum is warped; all aim to approximate the non-linear human perception of pitch.

    Supported historical variants include:

    - MFCC FB-20: introduced in 1980 by Davis and Mermelstein, it assumes a sampling frequency of 10 kHz and a speech
      bandwidth of `[0, 4600]` Hz.
    - MFCC FB-24 HTK: from the Cambridge HMM Toolkit (HTK) (1995) uses a filter bank of 24 filters for a speech
      bandwidth of `[0, 8000]` Hz. This assumes sampling rate ≥ 16 kHz.
    - MFCC FB-40: from the Auditory Toolbox for MATLAB written by Slaney in 1998, assumes a sampling rate of 16 kHz and
      speech bandwidth of `[133, 6854]` Hz. This version also includes area normalization.
    - HFCC-E FB-29 (Human Factor Cepstral Coefficients) of Skowronski and Harris (2004), assumes a sampling rate of
      12.5 kHz and speech bandwidth of `[0, 6250]` Hz.

    This code is adapted from *torchaudio* and *librosa*. Note that the default parameters of torchaudio's
    `melscale_fbanks` implement the `"htk"` filters while librosa uses the `"slaney"` implementation.

    Args:
        num_frequency_bins (`int`):
            Number of frequencies used to compute the spectrogram (should be the same as in `stft`).
        num_mel_filters (`int`):
            Number of mel filters to generate.
        min_frequency (`float`):
            Lowest frequency of interest in Hz.
        max_frequency (`float`):
            Highest frequency of interest in Hz. This should not exceed `sampling_rate / 2`.
        sampling_rate (`int`):
            Sample rate of the audio waveform.
        norm (`str`, *optional*):
            If `"slaney"`, divide the triangular mel weights by the width of the mel band (area normalization).
        mel_scale (`str`, *optional*, defaults to `"htk"`):
            The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
        triangularize_in_mel_space (`bool`, *optional*, defaults to `False`):
            If this option is enabled, the triangular filter is applied in mel space rather than frequency space. This
            should be set to `true` in order to get the same results as `torchaudio` when computing mel filters.

    Returns:
        `np.ndarray` of shape (`num_frequency_bins`, `num_mel_filters`): Triangular filter bank matrix. This is a
        projection matrix to go from a spectrogram to a mel spectrogram.
    """
    if norm is not None and norm != "slaney":
        raise ValueError('norm must be one of None or "slaney"')

    # The filter edges: num_mel_filters + 2 points evenly spaced on the chosen mel scale,
    # converted back to Hz. Edge i..i+2 delimits triangular filter i.
    mel_edges = np.linspace(
        hertz_to_mel(min_frequency, mel_scale=mel_scale),
        hertz_to_mel(max_frequency, mel_scale=mel_scale),
        num_mel_filters + 2,
    )
    filter_freqs = mel_to_hertz(mel_edges, mel_scale=mel_scale)

    if triangularize_in_mel_space:
        # Express the FFT bin centers in mel units and keep the edges in mel units too,
        # so the triangles are built in mel space (torchaudio-compatible behavior).
        fft_bin_width = sampling_rate / (num_frequency_bins * 2)
        fft_freqs = hertz_to_mel(fft_bin_width * np.arange(num_frequency_bins), mel_scale=mel_scale)
        filter_freqs = mel_edges
    else:
        # FFT bin centers in plain Hz.
        fft_freqs = np.linspace(0, sampling_rate // 2, num_frequency_bins)

    mel_filters = _create_triangular_filter_bank(fft_freqs, filter_freqs)

    if norm == "slaney":
        # Slaney-style area normalization: scale each filter so the channels carry
        # approximately constant energy.
        bandwidths = filter_freqs[2 : num_mel_filters + 2] - filter_freqs[:num_mel_filters]
        mel_filters *= np.expand_dims(2.0 / bandwidths, 0)

    # A filter whose peak is zero never sees any spectrogram energy — warn about bad sizing.
    if (mel_filters.max(axis=0) == 0.0).any():
        warnings.warn(
            "At least one mel filter has all zero values. "
            f"The value for `num_mel_filters` ({num_mel_filters}) may be set too high. "
            f"Or, the value for `num_frequency_bins` ({num_frequency_bins}) may be set too low."
        )

    return mel_filters
304
+
305
+
306
def optimal_fft_length(window_length: int) -> int:
    """
    Finds the best FFT input size for a given `window_length`: the smallest power of two that is greater than or
    equal to `window_length`.

    The FFT algorithm works fastest when the length of the input is a power of two, which may be larger than the size
    of the window or analysis frame. For example, if the window is 400 samples, using an FFT input size of 512 samples
    is more optimal than an FFT size of 400 samples. Using a larger FFT size does not affect the detected frequencies,
    it simply gives a higher frequency resolution (i.e. the frequency bins are smaller).
    """
    exponent = int(np.ceil(np.log2(window_length)))
    return 1 << exponent
317
+
318
+
319
def window_function(
    window_length: int,
    name: str = "hann",
    periodic: bool = True,
    frame_length: Optional[int] = None,
    center: bool = True,
) -> np.ndarray:
    """
    Returns an array containing the specified window. This window is intended to be used with `stft`.

    The following window types are supported:

    - `"boxcar"`: a rectangular window
    - `"hamming"`: the Hamming window
    - `"hann"`: the Hann window
    - `"povey"`: the Povey window

    Args:
        window_length (`int`):
            The length of the window in samples.
        name (`str`, *optional*, defaults to `"hann"`):
            The name of the window function.
        periodic (`bool`, *optional*, defaults to `True`):
            Whether the window is periodic or symmetric.
        frame_length (`int`, *optional*):
            The length of the analysis frames in samples. Provide a value for `frame_length` if the window is smaller
            than the frame length, so that it will be zero-padded.
        center (`bool`, *optional*, defaults to `True`):
            Whether to center the window inside the FFT buffer. Only used when `frame_length` is provided.

    Returns:
        `np.ndarray` of shape `(window_length,)` or `(frame_length,)` containing the window.
    """
    # A periodic window is obtained by computing a symmetric window one sample longer
    # and dropping the final sample.
    size = window_length + 1 if periodic else window_length

    if name == "boxcar":
        window = np.ones(size)
    elif name in ("hamming", "hamming_window"):
        window = np.hamming(size)
    elif name in ("hann", "hann_window"):
        window = np.hanning(size)
    elif name == "povey":
        # Povey window: a Hann window raised to the power 0.85 (used by Kaldi).
        window = np.hanning(size) ** 0.85
    else:
        raise ValueError(f"Unknown window function '{name}'")

    if periodic:
        window = window[:-1]

    if frame_length is None:
        return window

    if window_length > frame_length:
        raise ValueError(
            f"Length of the window ({window_length}) may not be larger than frame_length ({frame_length})"
        )

    # Zero-pad the window up to frame_length, optionally centering it in the buffer.
    padded = np.zeros(frame_length)
    start = (frame_length - window_length) // 2 if center else 0
    padded[start : start + window_length] = window
    return padded
380
+
381
+
382
+ # TODO This method does not support batching yet as we are mainly focused on inference.
383
def spectrogram(
    waveform: np.ndarray,
    window: np.ndarray,
    frame_length: int,
    hop_length: int,
    fft_length: Optional[int] = None,
    power: Optional[float] = 1.0,
    center: bool = True,
    pad_mode: str = "reflect",
    onesided: bool = True,
    preemphasis: Optional[float] = None,
    mel_filters: Optional[np.ndarray] = None,
    mel_floor: float = 1e-10,
    log_mel: Optional[str] = None,
    reference: float = 1.0,
    min_value: float = 1e-10,
    db_range: Optional[float] = None,
    remove_dc_offset: Optional[bool] = None,
    dtype: np.dtype = np.float32,
) -> np.ndarray:
    """
    Calculates a spectrogram over one waveform using the Short-Time Fourier Transform.

    This function can create the following kinds of spectrograms:

      - amplitude spectrogram (`power = 1.0`)
      - power spectrogram (`power = 2.0`)
      - complex-valued spectrogram (`power = None`)
      - log spectrogram (use `log_mel` argument)
      - mel spectrogram (provide `mel_filters`)
      - log-mel spectrogram (provide `mel_filters` and `log_mel`)

    How this works:

      1. The input waveform is split into frames of size `frame_length` that are partially overlapping by
         `frame_length - hop_length` samples.
      2. Each frame is multiplied by the window and placed into a buffer of size `fft_length`.
      3. The DFT is taken of each windowed frame.
      4. The results are stacked into a spectrogram.

    We make a distinction between the following "blocks" of sample data, each of which may have a different lengths:

      - The analysis frame. This is the size of the time slices that the input waveform is split into.
      - The window. Each analysis frame is multiplied by the window to avoid spectral leakage.
      - The FFT input buffer. The length of this determines how many frequency bins are in the spectrogram.

    In this implementation, the window is assumed to be zero-padded to have the same size as the analysis frame. A
    padded window can be obtained from `window_function()`. The FFT input buffer may be larger than the analysis
    frame, typically the next power of two.

    Note: This function is not optimized for speed yet. It should be mostly compatible with `librosa.stft` and
    `torchaudio.functional.transforms.Spectrogram`, although it is more flexible due to the different ways
    spectrograms can be constructed.

    Args:
        waveform (`np.ndarray` of shape `(length,)`):
            The input waveform. This must be a single real-valued, mono waveform.
        window (`np.ndarray` of shape `(frame_length,)`):
            The windowing function to apply, including zero-padding if necessary. The actual window length may be
            shorter than `frame_length`, but we're assuming the array has already been zero-padded.
        frame_length (`int`):
            The length of the analysis frames in samples. With librosa this is always equal to `fft_length` but we
            also allow smaller sizes.
        hop_length (`int`):
            The stride between successive analysis frames in samples.
        fft_length (`int`, *optional*):
            The size of the FFT buffer in samples. This determines how many frequency bins the spectrogram will have.
            For optimal speed, this should be a power of two. If `None`, uses `frame_length`.
        power (`float`, *optional*, defaults to 1.0):
            If 1.0, returns the amplitude spectrogram. If 2.0, returns the power spectrogram. If `None`, returns
            complex numbers.
        center (`bool`, *optional*, defaults to `True`):
            Whether to pad the waveform so that frame `t` is centered around time `t * hop_length`. If `False`, frame
            `t` will start at time `t * hop_length`.
        pad_mode (`str`, *optional*, defaults to `"reflect"`):
            Padding mode used when `center` is `True`. Possible values are: `"constant"` (pad with zeros), `"edge"`
            (pad with edge values), `"reflect"` (pads with mirrored values).
        onesided (`bool`, *optional*, defaults to `True`):
            If True, only computes the positive frequencies and returns a spectrogram containing `fft_length // 2 + 1`
            frequency bins. If False, also computes the negative frequencies and returns `fft_length` frequency bins.
        preemphasis (`float`, *optional*):
            Coefficient for a low-pass filter that applies pre-emphasis before the DFT.
        mel_filters (`np.ndarray` of shape `(num_freq_bins, num_mel_filters)`, *optional*):
            The mel filter bank. If supplied, applies this filter bank to create a mel spectrogram.
        mel_floor (`float`, *optional*, defaults to 1e-10):
            Minimum value of mel frequency banks.
        log_mel (`str`, *optional*):
            How to convert the spectrogram to log scale. Possible options are: `None` (don't convert), `"log"` (take
            the natural logarithm) `"log10"` (take the base-10 logarithm), `"dB"` (convert to decibels). Can only be
            used when `power` is not `None`.
        reference (`float`, *optional*, defaults to 1.0):
            Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
            the loudest part to 0 dB. Must be greater than zero.
        min_value (`float`, *optional*, defaults to `1e-10`):
            The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
            `log(0)`. For a power spectrogram, the default of `1e-10` corresponds to a minimum of -100 dB. For an
            amplitude spectrogram, the value `1e-5` corresponds to -100 dB. Must be greater than zero.
        db_range (`float`, *optional*):
            Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
            peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
        remove_dc_offset (`bool`, *optional*):
            Subtract mean from waveform on each frame, applied before pre-emphasis. This should be set to `true` in
            order to get the same results as `torchaudio.compliance.kaldi.fbank` when computing mel filters.
        dtype (`np.dtype`, *optional*, defaults to `np.float32`):
            Data type of the spectrogram tensor. If `power` is None, this argument is ignored and the dtype will be
            `np.complex64`.

    Returns:
        `np.ndarray` containing a spectrogram of shape `(num_frequency_bins, length)` for a regular spectrogram or
        shape `(num_mel_filters, length)` for a mel spectrogram.
    """
    window_length = len(window)

    if fft_length is None:
        fft_length = frame_length

    if frame_length > fft_length:
        raise ValueError(f"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})")

    if window_length != frame_length:
        raise ValueError(f"Length of the window ({window_length}) must equal frame_length ({frame_length})")

    if hop_length <= 0:
        raise ValueError("hop_length must be greater than zero")

    if waveform.ndim != 1:
        raise ValueError(f"Input waveform must have only one dimension, shape is {waveform.shape}")

    if np.iscomplexobj(waveform):
        raise ValueError("Complex-valued input waveforms are not currently supported")

    if power is None and mel_filters is not None:
        raise ValueError(
            "You have provided `mel_filters` but `power` is `None`. Mel spectrogram computation is not yet supported for complex-valued spectrogram."
            "Specify `power` to fix this issue."
        )

    # center pad the waveform
    if center:
        padding = [(int(frame_length // 2), int(frame_length // 2))]
        waveform = np.pad(waveform, padding, mode=pad_mode)

    # promote to float64, since np.fft uses float64 internally
    waveform = waveform.astype(np.float64)
    window = window.astype(np.float64)

    # split waveform into frames of frame_length size
    num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))

    num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length
    spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)

    # rfft is faster than fft
    fft_func = np.fft.rfft if onesided else np.fft.fft
    buffer = np.zeros(fft_length)

    timestep = 0
    for frame_idx in range(num_frames):
        buffer[:frame_length] = waveform[timestep : timestep + frame_length]

        if remove_dc_offset:
            buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()

        if preemphasis is not None:
            buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]
            buffer[0] *= 1 - preemphasis

        buffer[:frame_length] *= window

        spectrogram[frame_idx] = fft_func(buffer)
        timestep += hop_length

    # note: ** is much faster than np.power
    if power is not None:
        spectrogram = np.abs(spectrogram, dtype=np.float64) ** power

    spectrogram = spectrogram.T

    if mel_filters is not None:
        spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))

    if power is not None and log_mel is not None:
        if log_mel == "log":
            spectrogram = np.log(spectrogram)
        elif log_mel == "log10":
            spectrogram = np.log10(spectrogram)
        elif log_mel == "dB":
            if power == 1.0:
                spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)
            elif power == 2.0:
                spectrogram = power_to_db(spectrogram, reference, min_value, db_range)
            else:
                raise ValueError(f"Cannot use log_mel option '{log_mel}' with power {power}")
        else:
            raise ValueError(f"Unknown log_mel option: {log_mel}")

    if power is None:
        # Honor the documented contract: a complex-valued spectrogram ignores `dtype`.
        # Casting to the (float) `dtype` here would silently discard the imaginary
        # part (NumPy only emits a ComplexWarning for this unsafe cast).
        spectrogram = np.asarray(spectrogram, np.complex64)
    else:
        spectrogram = np.asarray(spectrogram, dtype)

    return spectrogram
582
+
583
+
584
def spectrogram_batch(
    waveform_list: List[np.ndarray],
    window: np.ndarray,
    frame_length: int,
    hop_length: int,
    fft_length: Optional[int] = None,
    power: Optional[float] = 1.0,
    center: bool = True,
    pad_mode: str = "reflect",
    onesided: bool = True,
    preemphasis: Optional[float] = None,
    mel_filters: Optional[np.ndarray] = None,
    mel_floor: float = 1e-10,
    log_mel: Optional[str] = None,
    reference: float = 1.0,
    min_value: float = 1e-10,
    db_range: Optional[float] = None,
    remove_dc_offset: Optional[bool] = None,
    dtype: np.dtype = np.float32,
) -> List[np.ndarray]:
    """
    Calculates spectrograms for a list of waveforms using the Short-Time Fourier Transform, optimized for batch
    processing. This function extends the capabilities of the `spectrogram` function to handle multiple waveforms
    efficiently by leveraging broadcasting.

    It supports generating various types of spectrograms:

      - amplitude spectrogram (`power = 1.0`)
      - power spectrogram (`power = 2.0`)
      - complex-valued spectrogram (`power = None`)
      - log spectrogram (use `log_mel` argument)
      - mel spectrogram (provide `mel_filters`)
      - log-mel spectrogram (provide `mel_filters` and `log_mel`)

    How this works:

      1. The input waveform is split into frames of size `frame_length` that are partially overlapping by
         `frame_length - hop_length` samples.
      2. Each frame is multiplied by the window and placed into a buffer of size `fft_length`.
      3. The DFT is taken of each windowed frame.
      4. The results are stacked into a spectrogram.

    We make a distinction between the following "blocks" of sample data, each of which may have a different lengths:

      - The analysis frame. This is the size of the time slices that the input waveform is split into.
      - The window. Each analysis frame is multiplied by the window to avoid spectral leakage.
      - The FFT input buffer. The length of this determines how many frequency bins are in the spectrogram.

    In this implementation, the window is assumed to be zero-padded to have the same size as the analysis frame. A
    padded window can be obtained from `window_function()`. The FFT input buffer may be larger than the analysis
    frame, typically the next power of two.

    Note: This function is designed for efficient batch processing of multiple waveforms but retains compatibility
    with individual waveform processing methods like `librosa.stft`.

    Args:
        waveform_list (`List[np.ndarray]` with arrays of shape `(length,)`):
            The list of input waveforms, each a single-channel (mono) signal.
        window (`np.ndarray` of shape `(frame_length,)`):
            The windowing function to apply, including zero-padding if necessary.
        frame_length (`int`):
            The length of each frame for analysis.
        hop_length (`int`):
            The step size between successive frames.
        fft_length (`int`, *optional*):
            The size of the FFT buffer, defining frequency bin resolution.
        power (`float`, *optional*, defaults to 1.0):
            Determines the type of spectrogram: 1.0 for amplitude, 2.0 for power, None for complex.
        center (`bool`, *optional*, defaults to `True`):
            Whether to center-pad the waveform frames.
        pad_mode (`str`, *optional*, defaults to `"reflect"`):
            The padding strategy when `center` is `True`.
        onesided (`bool`, *optional*, defaults to `True`):
            If True, returns a one-sided spectrogram for real input signals.
        preemphasis (`float`, *optional*):
            Applies a pre-emphasis filter to each frame.
        mel_filters (`np.ndarray`, *optional*):
            Mel filter bank for converting to mel spectrogram.
        mel_floor (`float`, *optional*, defaults to 1e-10):
            Floor value for mel spectrogram to avoid log(0).
        log_mel (`str`, *optional*):
            Specifies log scaling strategy; options are None, "log", "log10", "dB".
        reference (`float`, *optional*, defaults to 1.0):
            Reference value for dB conversion in log_mel.
        min_value (`float`, *optional*, defaults to 1e-10):
            Minimum floor value for log scale conversions.
        db_range (`float`, *optional*):
            Dynamic range for dB scale spectrograms.
        remove_dc_offset (`bool`, *optional*):
            Whether to remove the DC offset from each frame.
        dtype (`np.dtype`, *optional*, defaults to `np.float32`):
            Data type of the output spectrogram. Ignored when `power` is `None`; the output is then `np.complex64`.

    Returns:
        List[`np.ndarray`]: A list of spectrogram arrays, one for each input waveform.
    """
    window_length = len(window)

    if fft_length is None:
        fft_length = frame_length

    if frame_length > fft_length:
        raise ValueError(f"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})")

    if window_length != frame_length:
        raise ValueError(f"Length of the window ({window_length}) must equal frame_length ({frame_length})")

    if hop_length <= 0:
        raise ValueError("hop_length must be greater than zero")

    # Check the dimensions of the waveform , and if waveform is complex
    for waveform in waveform_list:
        if waveform.ndim != 1:
            raise ValueError(f"Input waveform must have only one dimension, shape is {waveform.shape}")
        if np.iscomplexobj(waveform):
            raise ValueError("Complex-valued input waveforms are not currently supported")

    # Center pad the waveform
    if center:
        padding = [(int(frame_length // 2), int(frame_length // 2))]
        waveform_list = [
            np.pad(
                waveform,
                padding,
                mode=pad_mode,
            )
            for waveform in waveform_list
        ]
    original_waveform_lengths = [
        len(waveform) for waveform in waveform_list
    ]  # these lengths will be used to remove padding later

    # Batch pad the waveform
    max_length = max(original_waveform_lengths)
    padded_waveform_batch = np.array(
        [
            np.pad(waveform, (0, max_length - len(waveform)), mode="constant", constant_values=0)
            for waveform in waveform_list
        ],
        dtype=dtype,
    )

    # Promote to float64, since np.fft uses float64 internally
    padded_waveform_batch = padded_waveform_batch.astype(np.float64)
    window = window.astype(np.float64)

    # Split waveform into frames of frame_length size
    num_frames = int(1 + np.floor((padded_waveform_batch.shape[1] - frame_length) / hop_length))
    # these lengths will be used to remove padding later
    true_num_frames = [int(1 + np.floor((length - frame_length) / hop_length)) for length in original_waveform_lengths]
    num_batches = padded_waveform_batch.shape[0]

    num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length
    spectrogram = np.empty((num_batches, num_frames, num_frequency_bins), dtype=np.complex64)

    # rfft is faster than fft
    fft_func = np.fft.rfft if onesided else np.fft.fft
    buffer = np.zeros((num_batches, fft_length))

    for frame_idx in range(num_frames):
        timestep = frame_idx * hop_length
        buffer[:, :frame_length] = padded_waveform_batch[:, timestep : timestep + frame_length]

        if remove_dc_offset:
            buffer[:, :frame_length] -= buffer[:, :frame_length].mean(axis=1, keepdims=True)

        if preemphasis is not None:
            buffer[:, 1:frame_length] -= preemphasis * buffer[:, : frame_length - 1]
            buffer[:, 0] *= 1 - preemphasis

        buffer[:, :frame_length] *= window

        spectrogram[:, frame_idx] = fft_func(buffer)

    # Note: ** is much faster than np.power
    if power is not None:
        spectrogram = np.abs(spectrogram, dtype=np.float64) ** power

    # Apply mel filters if provided
    if mel_filters is not None:
        result = np.tensordot(spectrogram, mel_filters.T, axes=([2], [1]))
        spectrogram = np.maximum(mel_floor, result)

    # Convert to log scale if specified
    if power is not None and log_mel is not None:
        if log_mel == "log":
            spectrogram = np.log(spectrogram)
        elif log_mel == "log10":
            spectrogram = np.log10(spectrogram)
        elif log_mel == "dB":
            if power == 1.0:
                spectrogram = amplitude_to_db_batch(spectrogram, reference, min_value, db_range)
            elif power == 2.0:
                spectrogram = power_to_db_batch(spectrogram, reference, min_value, db_range)
            else:
                raise ValueError(f"Cannot use log_mel option '{log_mel}' with power {power}")
        else:
            raise ValueError(f"Unknown log_mel option: {log_mel}")

    if power is None:
        # Keep the complex-valued output complex: casting to the (float) `dtype` would
        # silently discard the imaginary part (NumPy only emits a ComplexWarning).
        # This mirrors the `spectrogram` function's documented behavior.
        spectrogram = np.asarray(spectrogram, np.complex64)
    else:
        spectrogram = np.asarray(spectrogram, dtype)

    # Trim each spectrogram back to its own (unpadded) number of frames and transpose
    # to shape (num_frequency_bins, num_frames).
    spectrogram_list = [spectrogram[i, : true_num_frames[i], :].T for i in range(len(true_num_frames))]

    return spectrogram_list
785
+
786
+
787
def power_to_db(
    spectrogram: np.ndarray,
    reference: float = 1.0,
    min_value: float = 1e-10,
    db_range: Optional[float] = None,
) -> np.ndarray:
    """
    Converts a power spectrogram to the decibel scale, i.e. `10 * log10(spectrogram / reference)`, computed via the
    difference of logarithms for numerical stability.

    Humans perceive loudness logarithmically rather than linearly, so compressing a (mel) spectrogram with a log
    brings the features closer to what is actually heard: large energy differences in already-loud regions matter
    less perceptually than the raw values suggest.

    Based on the implementation of `librosa.power_to_db`.

    Args:
        spectrogram (`np.ndarray`):
            The input power (mel) spectrogram. Note that a power spectrogram has the amplitudes squared!
        reference (`float`, *optional*, defaults to 1.0):
            Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
            the loudest part to 0 dB. Must be greater than zero.
        min_value (`float`, *optional*, defaults to `1e-10`):
            The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
            `log(0)`. The default of `1e-10` corresponds to a minimum of -100 dB. Must be greater than zero.
        db_range (`float`, *optional*):
            Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
            peak value and the smallest value will never be more than 80 dB. Must be greater than zero.

    Returns:
        `np.ndarray`: the spectrogram in decibels
    """
    if reference <= 0.0:
        raise ValueError("reference must be greater than zero")
    if min_value <= 0.0:
        raise ValueError("min_value must be greater than zero")

    # The reference itself is floored, so the 0 dB point is never below the clip level.
    ref = max(min_value, reference)
    clipped = np.clip(spectrogram, a_min=min_value, a_max=None)
    db_spectrogram = 10.0 * (np.log10(clipped) - np.log10(ref))

    if db_range is not None:
        if db_range <= 0.0:
            raise ValueError("db_range must be greater than zero")
        # Limit the dynamic range relative to the global peak.
        db_spectrogram = np.clip(db_spectrogram, a_min=db_spectrogram.max() - db_range, a_max=None)

    return db_spectrogram
836
+
837
+
838
def power_to_db_batch(
    spectrogram: np.ndarray,
    reference: float = 1.0,
    min_value: float = 1e-10,
    db_range: Optional[float] = None,
) -> np.ndarray:
    """
    Converts a batch of power spectrograms to the decibel scale, i.e. `10 * log10(spectrogram / reference)`,
    computed via the difference of logarithms for numerical stability.

    Batch variant of `power_to_db`: each item along the first axis is treated as an individual power (mel)
    spectrogram, and the `db_range` clipping is applied per item.

    Args:
        spectrogram (`np.ndarray`):
            The input batch of power (mel) spectrograms. Expected shape is (batch_size, *spectrogram_shape).
            Note that a power spectrogram has the amplitudes squared!
        reference (`float`, *optional*, defaults to 1.0):
            Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
            the loudest part to 0 dB. Must be greater than zero.
        min_value (`float`, *optional*, defaults to `1e-10`):
            The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
            `log(0)`. The default of `1e-10` corresponds to a minimum of -100 dB. Must be greater than zero.
        db_range (`float`, *optional*):
            Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
            peak value and the smallest value will never be more than 80 dB. Must be greater than zero.

    Returns:
        `np.ndarray`: the batch of spectrograms in decibels
    """
    if reference <= 0.0:
        raise ValueError("reference must be greater than zero")
    if min_value <= 0.0:
        raise ValueError("min_value must be greater than zero")

    # Floor the reference so the 0 dB point never drops below the clip level.
    ref = max(min_value, reference)
    clipped = np.clip(spectrogram, a_min=min_value, a_max=None)
    db_spectrogram = 10.0 * (np.log10(clipped) - np.log10(ref))

    if db_range is not None:
        if db_range <= 0.0:
            raise ValueError("db_range must be greater than zero")
        # Per-item dynamic-range clipping: each batch item is limited relative to
        # its own peak (reduction over the two trailing spectrogram axes).
        per_item_peak = db_spectrogram.max(axis=(1, 2), keepdims=True)
        db_spectrogram = np.clip(db_spectrogram, a_min=per_item_peak - db_range, a_max=None)

    return db_spectrogram
885
+
886
+
887
def amplitude_to_db(
    spectrogram: np.ndarray,
    reference: float = 1.0,
    min_value: float = 1e-5,
    db_range: Optional[float] = None,
) -> np.ndarray:
    """
    Converts an amplitude spectrogram to the decibel scale, i.e. `20 * log10(spectrogram / reference)`, computed via
    the difference of logarithms for numerical stability.

    Humans perceive loudness logarithmically rather than linearly, so compressing a (mel) spectrogram with a log
    brings the features closer to what is actually heard: large energy differences in already-loud regions matter
    less perceptually than the raw values suggest.

    Args:
        spectrogram (`np.ndarray`):
            The input amplitude (mel) spectrogram.
        reference (`float`, *optional*, defaults to 1.0):
            Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
            the loudest part to 0 dB. Must be greater than zero.
        min_value (`float`, *optional*, defaults to `1e-5`):
            The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
            `log(0)`. The default of `1e-5` corresponds to a minimum of -100 dB. Must be greater than zero.
        db_range (`float`, *optional*):
            Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
            peak value and the smallest value will never be more than 80 dB. Must be greater than zero.

    Returns:
        `np.ndarray`: the spectrogram in decibels
    """
    if reference <= 0.0:
        raise ValueError("reference must be greater than zero")
    if min_value <= 0.0:
        raise ValueError("min_value must be greater than zero")

    # The reference itself is floored, so the 0 dB point is never below the clip level.
    ref = max(min_value, reference)
    clipped = np.clip(spectrogram, a_min=min_value, a_max=None)
    db_spectrogram = 20.0 * (np.log10(clipped) - np.log10(ref))

    if db_range is not None:
        if db_range <= 0.0:
            raise ValueError("db_range must be greater than zero")
        # Limit the dynamic range relative to the global peak.
        db_spectrogram = np.clip(db_spectrogram, a_min=db_spectrogram.max() - db_range, a_max=None)

    return db_spectrogram
934
+
935
+
936
+ def amplitude_to_db_batch(
937
+ spectrogram: np.ndarray, reference: float = 1.0, min_value: float = 1e-5, db_range: Optional[float] = None
938
+ ) -> np.ndarray:
939
+ """
940
+ Converts a batch of amplitude spectrograms to the decibel scale. This computes `20 * log10(spectrogram / reference)`,
941
+ using basic logarithm properties for numerical stability.
942
+
943
+ The function supports batch processing, where each item in the batch is an individual amplitude (mel) spectrogram.
944
+
945
+ Args:
946
+ spectrogram (`np.ndarray`):
947
+ The input batch of amplitude (mel) spectrograms. Expected shape is (batch_size, *spectrogram_shape).
948
+ reference (`float`, *optional*, defaults to 1.0):
949
+ Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
950
+ the loudest part to 0 dB. Must be greater than zero.
951
+ min_value (`float`, *optional*, defaults to `1e-5`):
952
+ The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
953
+ `log(0)`. The default of `1e-5` corresponds to a minimum of -100 dB. Must be greater than zero.
954
+ db_range (`float`, *optional*):
955
+ Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
956
+ peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
957
+
958
+ Returns:
959
+ `np.ndarray`: the batch of spectrograms in decibels
960
+ """
961
+ if reference <= 0.0:
962
+ raise ValueError("reference must be greater than zero")
963
+ if min_value <= 0.0:
964
+ raise ValueError("min_value must be greater than zero")
965
+
966
+ reference = max(min_value, reference)
967
+
968
+ spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None)
969
+ spectrogram = 20.0 * (np.log10(spectrogram) - np.log10(reference))
970
+
971
+ if db_range is not None:
972
+ if db_range <= 0.0:
973
+ raise ValueError("db_range must be greater than zero")
974
+ # Apply db_range clipping per batch item
975
+ max_values = spectrogram.max(axis=(1, 2), keepdims=True)
976
+ spectrogram = np.clip(spectrogram, a_min=max_values - db_range, a_max=None)
977
+
978
+ return spectrogram
979
+
980
+
981
+ ### deprecated functions below this line ###
982
+
983
+
984
+ def get_mel_filter_banks(
985
+ nb_frequency_bins: int,
986
+ nb_mel_filters: int,
987
+ frequency_min: float,
988
+ frequency_max: float,
989
+ sample_rate: int,
990
+ norm: Optional[str] = None,
991
+ mel_scale: str = "htk",
992
+ ) -> np.array:
993
+ warnings.warn(
994
+ "The function `get_mel_filter_banks` is deprecated and will be removed in version 4.31.0 of Transformers",
995
+ FutureWarning,
996
+ )
997
+ return mel_filter_bank(
998
+ num_frequency_bins=nb_frequency_bins,
999
+ num_mel_filters=nb_mel_filters,
1000
+ min_frequency=frequency_min,
1001
+ max_frequency=frequency_max,
1002
+ sampling_rate=sample_rate,
1003
+ norm=norm,
1004
+ mel_scale=mel_scale,
1005
+ )
1006
+
1007
+
1008
+ def fram_wave(waveform: np.array, hop_length: int = 160, fft_window_size: int = 400, center: bool = True):
1009
+ """
1010
+ In order to compute the short time fourier transform, the waveform needs to be split in overlapping windowed
1011
+ segments called `frames`.
1012
+
1013
+ The window length (window_length) defines how much of the signal is contained in each frame, while the hop length
1014
+ defines the step between the beginning of each new frame.
1015
+
1016
+
1017
+ Args:
1018
+ waveform (`np.array` of shape `(sample_length,)`):
1019
+ The raw waveform which will be split into smaller chunks.
1020
+ hop_length (`int`, *optional*, defaults to 160):
1021
+ Step between each window of the waveform.
1022
+ fft_window_size (`int`, *optional*, defaults to 400):
1023
+ Defines the size of the window.
1024
+ center (`bool`, defaults to `True`):
1025
+ Whether or not to center each frame around the middle of the frame. Centering is done by reflecting the
1026
+ waveform on the left and on the right.
1027
+
1028
+ Return:
1029
+ framed_waveform (`np.array` of shape `(waveform.shape // hop_length , fft_window_size)`):
1030
+ The framed waveforms that can be fed to `np.fft`.
1031
+ """
1032
+ warnings.warn(
1033
+ "The function `fram_wave` is deprecated and will be removed in version 4.31.0 of Transformers",
1034
+ FutureWarning,
1035
+ )
1036
+ frames = []
1037
+ for i in range(0, waveform.shape[0] + 1, hop_length):
1038
+ if center:
1039
+ half_window = (fft_window_size - 1) // 2 + 1
1040
+ start = i - half_window if i > half_window else 0
1041
+ end = i + half_window if i < waveform.shape[0] - half_window else waveform.shape[0]
1042
+ frame = waveform[start:end]
1043
+ if start == 0:
1044
+ padd_width = (-i + half_window, 0)
1045
+ frame = np.pad(frame, pad_width=padd_width, mode="reflect")
1046
+
1047
+ elif end == waveform.shape[0]:
1048
+ padd_width = (0, (i - waveform.shape[0] + half_window))
1049
+ frame = np.pad(frame, pad_width=padd_width, mode="reflect")
1050
+
1051
+ else:
1052
+ frame = waveform[i : i + fft_window_size]
1053
+ frame_width = frame.shape[0]
1054
+ if frame_width < waveform.shape[0]:
1055
+ frame = np.lib.pad(
1056
+ frame, pad_width=(0, fft_window_size - frame_width), mode="constant", constant_values=0
1057
+ )
1058
+ frames.append(frame)
1059
+
1060
+ frames = np.stack(frames, 0)
1061
+ return frames
1062
+
1063
+
1064
+ def stft(frames: np.array, windowing_function: np.array, fft_window_size: int = None):
1065
+ """
1066
+ Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal. Should give the same results
1067
+ as `torch.stft`.
1068
+
1069
+ Args:
1070
+ frames (`np.array` of dimension `(num_frames, fft_window_size)`):
1071
+ A framed audio signal obtained using `audio_utils.fram_wav`.
1072
+ windowing_function (`np.array` of dimension `(nb_frequency_bins, nb_mel_filters)`:
1073
+ A array representing the function that will be used to reduces the amplitude of the discontinuities at the
1074
+ boundaries of each frame when computing the STFT. Each frame will be multiplied by the windowing_function.
1075
+ For more information on the discontinuities, called *Spectral leakage*, refer to [this
1076
+ tutorial]https://download.ni.com/evaluation/pxi/Understanding%20FFTs%20and%20Windowing.pdf
1077
+ fft_window_size (`int`, *optional*):
1078
+ Size of the window om which the Fourier transform is applied. This controls the frequency resolution of the
1079
+ spectrogram. 400 means that the fourrier transform is computed on windows of 400 samples. The number of
1080
+ frequency bins (`nb_frequency_bins`) used to divide the window into equal strips is equal to
1081
+ `(1+fft_window_size)//2`. An increase of the fft_window_size slows the calculus time proportionnally.
1082
+
1083
+ Example:
1084
+
1085
+ ```python
1086
+ >>> from transformers.audio_utils import stft, fram_wave
1087
+ >>> import numpy as np
1088
+
1089
+ >>> audio = np.random.rand(50)
1090
+ >>> fft_window_size = 10
1091
+ >>> hop_length = 2
1092
+ >>> framed_audio = fram_wave(audio, hop_length, fft_window_size)
1093
+ >>> spectrogram = stft(framed_audio, np.hanning(fft_window_size + 1))
1094
+ ```
1095
+
1096
+ Returns:
1097
+ spectrogram (`np.ndarray`):
1098
+ A spectrogram of shape `(num_frames, nb_frequency_bins)` obtained using the STFT algorithm
1099
+ """
1100
+ warnings.warn(
1101
+ "The function `stft` is deprecated and will be removed in version 4.31.0 of Transformers",
1102
+ FutureWarning,
1103
+ )
1104
+ frame_size = frames.shape[1]
1105
+
1106
+ if fft_window_size is None:
1107
+ fft_window_size = frame_size
1108
+
1109
+ if fft_window_size < frame_size:
1110
+ raise ValueError("FFT size must greater or equal the frame size")
1111
+ # number of FFT bins to store
1112
+ nb_frequency_bins = (fft_window_size >> 1) + 1
1113
+
1114
+ spectrogram = np.empty((len(frames), nb_frequency_bins), dtype=np.complex64)
1115
+ fft_signal = np.zeros(fft_window_size)
1116
+
1117
+ for f, frame in enumerate(frames):
1118
+ if windowing_function is not None:
1119
+ np.multiply(frame, windowing_function, out=fft_signal[:frame_size])
1120
+ else:
1121
+ fft_signal[:frame_size] = frame
1122
+ spectrogram[f] = np.fft.fft(fft_signal, axis=0)[:nb_frequency_bins]
1123
+ return spectrogram.T
.venv/lib/python3.11/site-packages/transformers/cache_utils.py ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.11/site-packages/transformers/configuration_utils.py ADDED
@@ -0,0 +1,1187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """Configuration base class and utilities."""
17
+
18
+ import copy
19
+ import json
20
+ import os
21
+ import re
22
+ import warnings
23
+ from typing import Any, Dict, List, Optional, Tuple, Union
24
+
25
+ from packaging import version
26
+
27
+ from . import __version__
28
+ from .dynamic_module_utils import custom_object_save
29
+ from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
30
+ from .utils import (
31
+ CONFIG_NAME,
32
+ PushToHubMixin,
33
+ add_model_info_to_auto_map,
34
+ add_model_info_to_custom_pipelines,
35
+ cached_file,
36
+ copy_func,
37
+ download_url,
38
+ extract_commit_hash,
39
+ is_remote_url,
40
+ is_torch_available,
41
+ logging,
42
+ )
43
+ from .utils.generic import is_timm_config_dict
44
+
45
+
46
+ logger = logging.get_logger(__name__)
47
+
48
+ _re_configuration_file = re.compile(r"config\.(.*)\.json")
49
+
50
+
51
+ class PretrainedConfig(PushToHubMixin):
52
+ # no-format
53
+ r"""
54
+ Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
55
+ methods for loading/downloading/saving configurations.
56
+
57
+ <Tip>
58
+
59
+ A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to
60
+ initialize a model does **not** load the model weights. It only affects the model's configuration.
61
+
62
+ </Tip>
63
+
64
+ Class attributes (overridden by derived classes):
65
+
66
+ - **model_type** (`str`) -- An identifier for the model type, serialized into the JSON file, and used to recreate
67
+ the correct object in [`~transformers.AutoConfig`].
68
+ - **is_composition** (`bool`) -- Whether the config class is composed of multiple sub-configs. In this case the
69
+ config has to be initialized from two or more configs of type [`~transformers.PretrainedConfig`] like:
70
+ [`~transformers.EncoderDecoderConfig`] or [`~RagConfig`].
71
+ - **keys_to_ignore_at_inference** (`List[str]`) -- A list of keys to ignore by default when looking at dictionary
72
+ outputs of the model during inference.
73
+ - **attribute_map** (`Dict[str, str]`) -- A dict that maps model specific attribute names to the standardized
74
+ naming of attributes.
75
+ - **base_model_tp_plan** (`Dict[str, Any]`) -- A dict that maps sub-modules FQNs of a base model to a tensor
76
+ parallel plan applied to the sub-module when `model.tensor_parallel` is called.
77
+
78
+ Common attributes (present in all subclasses):
79
+
80
+ - **vocab_size** (`int`) -- The number of tokens in the vocabulary, which is also the first dimension of the
81
+ embeddings matrix (this attribute may be missing for models that don't have a text modality like ViT).
82
+ - **hidden_size** (`int`) -- The hidden size of the model.
83
+ - **num_attention_heads** (`int`) -- The number of attention heads used in the multi-head attention layers of the
84
+ model.
85
+ - **num_hidden_layers** (`int`) -- The number of blocks in the model.
86
+
87
+ <Tip warning={true}>
88
+
89
+ Setting parameters for sequence generation in the model config is deprecated. For backward compatibility, loading
90
+ some of them will still be possible, but attempting to overwrite them will throw an exception -- you should set
91
+ them in a [~transformers.GenerationConfig]. Check the documentation of [~transformers.GenerationConfig] for more
92
+ information about the individual parameters.
93
+
94
+ </Tip>
95
+
96
+ Arg:
97
+ name_or_path (`str`, *optional*, defaults to `""`):
98
+ Store the string that was passed to [`PreTrainedModel.from_pretrained`] or
99
+ [`TFPreTrainedModel.from_pretrained`] as `pretrained_model_name_or_path` if the configuration was created
100
+ with such a method.
101
+ output_hidden_states (`bool`, *optional*, defaults to `False`):
102
+ Whether or not the model should return all hidden-states.
103
+ output_attentions (`bool`, *optional*, defaults to `False`):
104
+ Whether or not the model should returns all attentions.
105
+ return_dict (`bool`, *optional*, defaults to `True`):
106
+ Whether or not the model should return a [`~transformers.utils.ModelOutput`] instead of a plain tuple.
107
+ is_encoder_decoder (`bool`, *optional*, defaults to `False`):
108
+ Whether the model is used as an encoder/decoder or not.
109
+ is_decoder (`bool`, *optional*, defaults to `False`):
110
+ Whether the model is used as decoder or not (in which case it's used as an encoder).
111
+ cross_attention_hidden_size** (`bool`, *optional*):
112
+ The hidden size of the cross-attention layer in case the model is used as a decoder in an encoder-decoder
113
+ setting and the cross-attention hidden dimension differs from `self.config.hidden_size`.
114
+ add_cross_attention (`bool`, *optional*, defaults to `False`):
115
+ Whether cross-attention layers should be added to the model. Note, this option is only relevant for models
116
+ that can be used as decoder models within the [`EncoderDecoderModel`] class, which consists of all models
117
+ in `AUTO_MODELS_FOR_CAUSAL_LM`.
118
+ tie_encoder_decoder (`bool`, *optional*, defaults to `False`):
119
+ Whether all encoder weights should be tied to their equivalent decoder weights. This requires the encoder
120
+ and decoder model to have the exact same parameter names.
121
+ prune_heads (`Dict[int, List[int]]`, *optional*, defaults to `{}`):
122
+ Pruned heads of the model. The keys are the selected layer indices and the associated values, the list of
123
+ heads to prune in said layer.
124
+
125
+ For instance `{1: [0, 2], 2: [2, 3]}` will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
126
+ chunk_size_feed_forward (`int`, *optional*, defaults to `0`):
127
+ The chunk size of all feed forward layers in the residual attention blocks. A chunk size of `0` means that
128
+ the feed forward layer is not chunked. A chunk size of n means that the feed forward layer processes `n` <
129
+ sequence_length embeddings at a time. For more information on feed forward chunking, see [How does Feed
130
+ Forward Chunking work?](../glossary.html#feed-forward-chunking).
131
+
132
+ > Parameters for fine-tuning tasks
133
+
134
+ architectures (`List[str]`, *optional*):
135
+ Model architectures that can be used with the model pretrained weights.
136
+ finetuning_task (`str`, *optional*):
137
+ Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow
138
+ or PyTorch) checkpoint.
139
+ id2label (`Dict[int, str]`, *optional*):
140
+ A map from index (for instance prediction index, or target index) to label.
141
+ label2id (`Dict[str, int]`, *optional*): A map from label to index for the model.
142
+ num_labels (`int`, *optional*):
143
+ Number of labels to use in the last layer added to the model, typically for a classification task.
144
+ task_specific_params (`Dict[str, Any]`, *optional*):
145
+ Additional keyword arguments to store for the current task.
146
+ problem_type (`str`, *optional*):
147
+ Problem type for `XxxForSequenceClassification` models. Can be one of `"regression"`,
148
+ `"single_label_classification"` or `"multi_label_classification"`.
149
+
150
+ > Parameters linked to the tokenizer
151
+
152
+ tokenizer_class (`str`, *optional*):
153
+ The name of the associated tokenizer class to use (if none is set, will use the tokenizer associated to the
154
+ model by default).
155
+ prefix (`str`, *optional*):
156
+ A specific prompt that should be added at the beginning of each text before calling the model.
157
+ bos_token_id (`int`, *optional*): The id of the _beginning-of-stream_ token.
158
+ pad_token_id (`int`, *optional*): The id of the _padding_ token.
159
+ eos_token_id (`int`, *optional*): The id of the _end-of-stream_ token.
160
+ decoder_start_token_id (`int`, *optional*):
161
+ If an encoder-decoder model starts decoding with a different token than _bos_, the id of that token.
162
+ sep_token_id (`int`, *optional*): The id of the _separation_ token.
163
+
164
+ > PyTorch specific parameters
165
+
166
+ torchscript (`bool`, *optional*, defaults to `False`):
167
+ Whether or not the model should be used with Torchscript.
168
+ tie_word_embeddings (`bool`, *optional*, defaults to `True`):
169
+ Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
170
+ model has a output word embedding layer.
171
+ torch_dtype (`str`, *optional*):
172
+ The `dtype` of the weights. This attribute can be used to initialize the model to a non-default `dtype`
173
+ (which is normally `float32`) and thus allow for optimal storage allocation. For example, if the saved
174
+ model is `float16`, ideally we want to load it back using the minimal amount of memory needed to load
175
+ `float16` weights. Since the config object is stored in plain text, this attribute contains just the
176
+ floating type string without the `torch.` prefix. For example, for `torch.float16` ``torch_dtype` is the
177
+ `"float16"` string.
178
+
179
+ This attribute is currently not being used during model loading time, but this may change in the future
180
+ versions. But we can already start preparing for the future by saving the dtype with save_pretrained.
181
+
182
+ > TensorFlow specific parameters
183
+
184
+ use_bfloat16 (`bool`, *optional*, defaults to `False`):
185
+ Whether or not the model should use BFloat16 scalars (only used by some TensorFlow models).
186
+ tf_legacy_loss (`bool`, *optional*, defaults to `False`):
187
+ Whether the model should use legacy TensorFlow losses. Legacy losses have variable output shapes and may
188
+ not be XLA-compatible. This option is here for backward compatibility and will be removed in Transformers
189
+ v5.
190
+ loss_type (`str`, *optional*):
191
+ The type of loss that the model should use. It should be in `LOSS_MAPPING`'s keys, otherwise the loss will
192
+ be automatically infered from the model architecture.
193
+ """
194
+
195
+ model_type: str = ""
196
+ base_config_key: str = ""
197
+ sub_configs: Dict[str, "PretrainedConfig"] = {}
198
+ is_composition: bool = False
199
+ attribute_map: Dict[str, str] = {}
200
+ base_model_tp_plan: Optional[Dict[str, Any]] = None
201
+ _auto_class: Optional[str] = None
202
+
203
+ def __setattr__(self, key, value):
204
+ if key in super().__getattribute__("attribute_map"):
205
+ key = super().__getattribute__("attribute_map")[key]
206
+ super().__setattr__(key, value)
207
+
208
+ def __getattribute__(self, key):
209
+ if key != "attribute_map" and key in super().__getattribute__("attribute_map"):
210
+ key = super().__getattribute__("attribute_map")[key]
211
+ return super().__getattribute__(key)
212
+
213
+ def __init__(self, **kwargs):
214
+ # Attributes with defaults
215
+ self.return_dict = kwargs.pop("return_dict", True)
216
+ self.output_hidden_states = kwargs.pop("output_hidden_states", False)
217
+ self.output_attentions = kwargs.pop("output_attentions", False)
218
+ self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
219
+ self.torch_dtype = kwargs.pop("torch_dtype", None) # Only used by PyTorch models
220
+ self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
221
+ self.tf_legacy_loss = kwargs.pop("tf_legacy_loss", False) # Only used by TensorFlow models
222
+ self.pruned_heads = kwargs.pop("pruned_heads", {})
223
+ self.tie_word_embeddings = kwargs.pop(
224
+ "tie_word_embeddings", True
225
+ ) # Whether input and output word embeddings should be tied for all MLM, LM and Seq2Seq models.
226
+ self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
227
+
228
+ # Is decoder is used in encoder-decoder models to differentiate encoder from decoder
229
+ self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
230
+ self.is_decoder = kwargs.pop("is_decoder", False)
231
+ self.cross_attention_hidden_size = kwargs.pop("cross_attention_hidden_size", None)
232
+ self.add_cross_attention = kwargs.pop("add_cross_attention", False)
233
+ self.tie_encoder_decoder = kwargs.pop("tie_encoder_decoder", False)
234
+
235
+ # Retrocompatibility: Parameters for sequence generation. While we will keep the ability to load these
236
+ # parameters, saving them will be deprecated. In a distant future, we won't need to load them.
237
+ for parameter_name, default_value in self._get_global_generation_defaults().items():
238
+ setattr(self, parameter_name, kwargs.pop(parameter_name, default_value))
239
+
240
+ # Fine-tuning task arguments
241
+ self.architectures = kwargs.pop("architectures", None)
242
+ self.finetuning_task = kwargs.pop("finetuning_task", None)
243
+ self.id2label = kwargs.pop("id2label", None)
244
+ self.label2id = kwargs.pop("label2id", None)
245
+ if self.label2id is not None and not isinstance(self.label2id, dict):
246
+ raise ValueError("Argument label2id should be a dictionary.")
247
+ if self.id2label is not None:
248
+ if not isinstance(self.id2label, dict):
249
+ raise ValueError("Argument id2label should be a dictionary.")
250
+ num_labels = kwargs.pop("num_labels", None)
251
+ if num_labels is not None and len(self.id2label) != num_labels:
252
+ logger.warning(
253
+ f"You passed along `num_labels={num_labels}` with an incompatible id to label map: "
254
+ f"{self.id2label}. The number of labels wil be overwritten to {self.num_labels}."
255
+ )
256
+ self.id2label = {int(key): value for key, value in self.id2label.items()}
257
+ # Keys are always strings in JSON so convert ids to int here.
258
+ else:
259
+ self.num_labels = kwargs.pop("num_labels", 2)
260
+
261
+ if self.torch_dtype is not None and isinstance(self.torch_dtype, str):
262
+ # we will start using self.torch_dtype in v5, but to be consistent with
263
+ # from_pretrained's torch_dtype arg convert it to an actual torch.dtype object
264
+ if is_torch_available():
265
+ import torch
266
+
267
+ self.torch_dtype = getattr(torch, self.torch_dtype)
268
+
269
+ # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
270
+ self.tokenizer_class = kwargs.pop("tokenizer_class", None)
271
+ self.prefix = kwargs.pop("prefix", None)
272
+ self.bos_token_id = kwargs.pop("bos_token_id", None)
273
+ self.pad_token_id = kwargs.pop("pad_token_id", None)
274
+ self.eos_token_id = kwargs.pop("eos_token_id", None)
275
+ self.sep_token_id = kwargs.pop("sep_token_id", None)
276
+
277
+ self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
278
+
279
+ # task specific arguments
280
+ self.task_specific_params = kwargs.pop("task_specific_params", None)
281
+
282
+ # regression / multi-label classification
283
+ self.problem_type = kwargs.pop("problem_type", None)
284
+ allowed_problem_types = ("regression", "single_label_classification", "multi_label_classification")
285
+ if self.problem_type is not None and self.problem_type not in allowed_problem_types:
286
+ raise ValueError(
287
+ f"The config parameter `problem_type` was not understood: received {self.problem_type} "
288
+ "but only 'regression', 'single_label_classification' and 'multi_label_classification' are valid."
289
+ )
290
+
291
+ # TPU arguments
292
+ if kwargs.pop("xla_device", None) is not None:
293
+ logger.warning(
294
+ "The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can "
295
+ "safely remove it from your `config.json` file."
296
+ )
297
+
298
+ # Name or path to the pretrained checkpoint
299
+ self._name_or_path = str(kwargs.pop("name_or_path", ""))
300
+ # Config hash
301
+ self._commit_hash = kwargs.pop("_commit_hash", None)
302
+
303
+ # Attention implementation to use, if relevant.
304
+ self._attn_implementation_internal = kwargs.pop("attn_implementation", None)
305
+ self._attn_implementation_autoset = False
306
+
307
+ # Drop the transformers version info
308
+ self.transformers_version = kwargs.pop("transformers_version", None)
309
+
310
+ # Deal with gradient checkpointing
311
+ if kwargs.get("gradient_checkpointing", False):
312
+ warnings.warn(
313
+ "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 "
314
+ "Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the "
315
+ "`Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`."
316
+ )
317
+
318
+ # Additional attributes without default values
319
+ for key, value in kwargs.items():
320
+ try:
321
+ setattr(self, key, value)
322
+ except AttributeError as err:
323
+ logger.error(f"Can't set {key} with value {value} for {self}")
324
+ raise err
325
+
326
+ @property
327
+ def name_or_path(self) -> str:
328
+ return getattr(self, "_name_or_path", None)
329
+
330
+ @name_or_path.setter
331
+ def name_or_path(self, value):
332
+ self._name_or_path = str(value) # Make sure that name_or_path is a string (for JSON encoding)
333
+
334
+ @property
335
+ def use_return_dict(self) -> bool:
336
+ """
337
+ `bool`: Whether or not return [`~utils.ModelOutput`] instead of tuples.
338
+ """
339
+ # If torchscript is set, force `return_dict=False` to avoid jit errors
340
+ return self.return_dict and not self.torchscript
341
+
342
+ @property
343
+ def num_labels(self) -> int:
344
+ """
345
+ `int`: The number of labels for classification models.
346
+ """
347
+ return len(self.id2label)
348
+
349
+ @num_labels.setter
350
+ def num_labels(self, num_labels: int):
351
+ if not hasattr(self, "id2label") or self.id2label is None or len(self.id2label) != num_labels:
352
+ self.id2label = {i: f"LABEL_{i}" for i in range(num_labels)}
353
+ self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
354
+
355
+ @property
356
+ def _attn_implementation(self):
357
+ # This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.)
358
+ if hasattr(self, "_attn_implementation_internal"):
359
+ if self._attn_implementation_internal is None:
360
+ # `config.attn_implementation` should never be None, for backward compatibility.
361
+ return "eager"
362
+ else:
363
+ return self._attn_implementation_internal
364
+ else:
365
+ return "eager"
366
+
367
+ @_attn_implementation.setter
368
+ def _attn_implementation(self, value):
369
+ self._attn_implementation_internal = value
370
+
371
+ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
372
+ """
373
+ Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
374
+ [`~PretrainedConfig.from_pretrained`] class method.
375
+
376
+ Args:
377
+ save_directory (`str` or `os.PathLike`):
378
+ Directory where the configuration JSON file will be saved (will be created if it does not exist).
379
+ push_to_hub (`bool`, *optional*, defaults to `False`):
380
+ Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
381
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
382
+ namespace).
383
+ kwargs (`Dict[str, Any]`, *optional*):
384
+ Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
385
+ """
386
+ self._set_token_in_kwargs(kwargs)
387
+
388
+ if os.path.isfile(save_directory):
389
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
390
+
391
+ non_default_generation_parameters = self._get_non_default_generation_parameters()
392
+ if len(non_default_generation_parameters) > 0:
393
+ # TODO (joao): this should be an exception if the user has modified the loaded config. See #33886
394
+ warnings.warn(
395
+ "Some non-default generation parameters are set in the model config. These should go into either a) "
396
+ "`model.generation_config` (as opposed to `model.config`); OR b) a GenerationConfig file "
397
+ "(https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model)."
398
+ "This warning will become an exception in the future."
399
+ f"\nNon-default generation parameters: {str(non_default_generation_parameters)}",
400
+ UserWarning,
401
+ )
402
+
403
+ os.makedirs(save_directory, exist_ok=True)
404
+
405
+ if push_to_hub:
406
+ commit_message = kwargs.pop("commit_message", None)
407
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
408
+ repo_id = self._create_repo(repo_id, **kwargs)
409
+ files_timestamps = self._get_files_timestamps(save_directory)
410
+
411
+ # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
412
+ # loaded from the Hub.
413
+ if self._auto_class is not None:
414
+ custom_object_save(self, save_directory, config=self)
415
+
416
+ # If we save using the predefined names, we can load using `from_pretrained`
417
+ output_config_file = os.path.join(save_directory, CONFIG_NAME)
418
+
419
+ self.to_json_file(output_config_file, use_diff=True)
420
+ logger.info(f"Configuration saved in {output_config_file}")
421
+
422
+ if push_to_hub:
423
+ self._upload_modified_files(
424
+ save_directory,
425
+ repo_id,
426
+ files_timestamps,
427
+ commit_message=commit_message,
428
+ token=kwargs.get("token"),
429
+ )
430
+
431
+ @staticmethod
432
+ def _set_token_in_kwargs(kwargs, token=None):
433
+ """Temporary method to deal with `token` and `use_auth_token`.
434
+
435
+ This method is to avoid apply the same changes in all model config classes that overwrite `from_pretrained`.
436
+
437
+ Need to clean up `use_auth_token` in a follow PR.
438
+ """
439
+ # Some model config classes like CLIP define their own `from_pretrained` without the new argument `token` yet.
440
+ if token is None:
441
+ token = kwargs.pop("token", None)
442
+ use_auth_token = kwargs.pop("use_auth_token", None)
443
+
444
+ if use_auth_token is not None:
445
+ warnings.warn(
446
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
447
+ FutureWarning,
448
+ )
449
+ if token is not None:
450
+ raise ValueError(
451
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
452
+ )
453
+ token = use_auth_token
454
+
455
+ if token is not None:
456
+ kwargs["token"] = token
457
+
458
+ @classmethod
459
+ def from_pretrained(
460
+ cls,
461
+ pretrained_model_name_or_path: Union[str, os.PathLike],
462
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
463
+ force_download: bool = False,
464
+ local_files_only: bool = False,
465
+ token: Optional[Union[str, bool]] = None,
466
+ revision: str = "main",
467
+ **kwargs,
468
+ ) -> "PretrainedConfig":
469
+ r"""
470
+ Instantiate a [`PretrainedConfig`] (or a derived class) from a pretrained model configuration.
471
+
472
+ Args:
473
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
474
+ This can be either:
475
+
476
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
477
+ huggingface.co.
478
+ - a path to a *directory* containing a configuration file saved using the
479
+ [`~PretrainedConfig.save_pretrained`] method, e.g., `./my_model_directory/`.
480
+ - a path or url to a saved configuration JSON *file*, e.g., `./my_model_directory/configuration.json`.
481
+ cache_dir (`str` or `os.PathLike`, *optional*):
482
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
483
+ standard cache should not be used.
484
+ force_download (`bool`, *optional*, defaults to `False`):
485
+ Whether or not to force to (re-)download the configuration files and override the cached versions if
486
+ they exist.
487
+ resume_download:
488
+ Deprecated and ignored. All downloads are now resumed by default when possible.
489
+ Will be removed in v5 of Transformers.
490
+ proxies (`Dict[str, str]`, *optional*):
491
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
492
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
493
+ token (`str` or `bool`, *optional*):
494
+ The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
495
+ the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
496
+ revision (`str`, *optional*, defaults to `"main"`):
497
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
498
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
499
+ identifier allowed by git.
500
+
501
+ <Tip>
502
+
503
+ To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
504
+
505
+ </Tip>
506
+
507
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
508
+ If `False`, then this function returns just the final configuration object.
509
+
510
+ If `True`, then this functions returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a
511
+ dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the
512
+ part of `kwargs` which has not been used to update `config` and is otherwise ignored.
513
+ subfolder (`str`, *optional*, defaults to `""`):
514
+ In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
515
+ specify the folder name here.
516
+ kwargs (`Dict[str, Any]`, *optional*):
517
+ The values in kwargs of any keys which are configuration attributes will be used to override the loaded
518
+ values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
519
+ by the `return_unused_kwargs` keyword parameter.
520
+
521
+ Returns:
522
+ [`PretrainedConfig`]: The configuration object instantiated from this pretrained model.
523
+
524
+ Examples:
525
+
526
+ ```python
527
+ # We can't instantiate directly the base class *PretrainedConfig* so let's show the examples on a
528
+ # derived class: BertConfig
529
+ config = BertConfig.from_pretrained(
530
+ "google-bert/bert-base-uncased"
531
+ ) # Download configuration from huggingface.co and cache.
532
+ config = BertConfig.from_pretrained(
533
+ "./test/saved_model/"
534
+ ) # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')*
535
+ config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
536
+ config = BertConfig.from_pretrained("google-bert/bert-base-uncased", output_attentions=True, foo=False)
537
+ assert config.output_attentions == True
538
+ config, unused_kwargs = BertConfig.from_pretrained(
539
+ "google-bert/bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
540
+ )
541
+ assert config.output_attentions == True
542
+ assert unused_kwargs == {"foo": False}
543
+ ```"""
544
+ kwargs["cache_dir"] = cache_dir
545
+ kwargs["force_download"] = force_download
546
+ kwargs["local_files_only"] = local_files_only
547
+ kwargs["revision"] = revision
548
+
549
+ cls._set_token_in_kwargs(kwargs, token)
550
+
551
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
552
+ if cls.base_config_key and cls.base_config_key in config_dict:
553
+ config_dict = config_dict[cls.base_config_key]
554
+
555
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
556
+ # sometimes the config has no `base_config_key` if the config is used in several composite models
557
+ # e.g. LlamaConfig. In that case we try to see if there is match in `model_type` before raising a warning
558
+ for k, v in config_dict.items():
559
+ if isinstance(v, dict) and v.get("model_type") == cls.model_type:
560
+ config_dict = v
561
+
562
+ # raise warning only if we still can't see a match in `model_type`
563
+ if config_dict["model_type"] != cls.model_type:
564
+ logger.warning(
565
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
566
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
567
+ )
568
+
569
+ return cls.from_dict(config_dict, **kwargs)
570
+
571
+ @classmethod
572
+ def get_config_dict(
573
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
574
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
575
+ """
576
+ From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
577
+ [`PretrainedConfig`] using `from_dict`.
578
+
579
+ Parameters:
580
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
581
+ The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
582
+
583
+ Returns:
584
+ `Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the configuration object.
585
+
586
+ """
587
+ cls._set_token_in_kwargs(kwargs)
588
+
589
+ original_kwargs = copy.deepcopy(kwargs)
590
+ # Get config dict associated with the base config file
591
+ config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
592
+ if config_dict is None:
593
+ return {}, kwargs
594
+ if "_commit_hash" in config_dict:
595
+ original_kwargs["_commit_hash"] = config_dict["_commit_hash"]
596
+
597
+ # That config file may point us toward another config file to use.
598
+ if "configuration_files" in config_dict:
599
+ configuration_file = get_configuration_file(config_dict["configuration_files"])
600
+ config_dict, kwargs = cls._get_config_dict(
601
+ pretrained_model_name_or_path, _configuration_file=configuration_file, **original_kwargs
602
+ )
603
+
604
+ return config_dict, kwargs
605
+
606
+ @classmethod
607
+ def _get_config_dict(
608
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
609
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
610
+ cache_dir = kwargs.pop("cache_dir", None)
611
+ force_download = kwargs.pop("force_download", False)
612
+ resume_download = kwargs.pop("resume_download", None)
613
+ proxies = kwargs.pop("proxies", None)
614
+ token = kwargs.pop("token", None)
615
+ local_files_only = kwargs.pop("local_files_only", False)
616
+ revision = kwargs.pop("revision", None)
617
+ trust_remote_code = kwargs.pop("trust_remote_code", None)
618
+ subfolder = kwargs.pop("subfolder", "")
619
+ from_pipeline = kwargs.pop("_from_pipeline", None)
620
+ from_auto_class = kwargs.pop("_from_auto", False)
621
+ commit_hash = kwargs.pop("_commit_hash", None)
622
+
623
+ gguf_file = kwargs.get("gguf_file", None)
624
+
625
+ if trust_remote_code is True:
626
+ logger.warning(
627
+ "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
628
+ " ignored."
629
+ )
630
+
631
+ user_agent = {"file_type": "config", "from_auto_class": from_auto_class}
632
+ if from_pipeline is not None:
633
+ user_agent["using_pipeline"] = from_pipeline
634
+
635
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
636
+
637
+ is_local = os.path.isdir(pretrained_model_name_or_path)
638
+ if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
639
+ # Special case when pretrained_model_name_or_path is a local file
640
+ resolved_config_file = pretrained_model_name_or_path
641
+ is_local = True
642
+ elif is_remote_url(pretrained_model_name_or_path):
643
+ configuration_file = pretrained_model_name_or_path if gguf_file is None else gguf_file
644
+ resolved_config_file = download_url(pretrained_model_name_or_path)
645
+ else:
646
+ configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME) if gguf_file is None else gguf_file
647
+
648
+ try:
649
+ # Load from local folder or from cache or download from model Hub and cache
650
+ resolved_config_file = cached_file(
651
+ pretrained_model_name_or_path,
652
+ configuration_file,
653
+ cache_dir=cache_dir,
654
+ force_download=force_download,
655
+ proxies=proxies,
656
+ resume_download=resume_download,
657
+ local_files_only=local_files_only,
658
+ token=token,
659
+ user_agent=user_agent,
660
+ revision=revision,
661
+ subfolder=subfolder,
662
+ _commit_hash=commit_hash,
663
+ )
664
+ if resolved_config_file is None:
665
+ return None, kwargs
666
+ commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
667
+ except EnvironmentError:
668
+ # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
669
+ # the original exception.
670
+ raise
671
+ except Exception:
672
+ # For any other exception, we throw a generic error.
673
+ raise EnvironmentError(
674
+ f"Can't load the configuration of '{pretrained_model_name_or_path}'. If you were trying to load it"
675
+ " from 'https://huggingface.co/models', make sure you don't have a local directory with the same"
676
+ f" name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory"
677
+ f" containing a {configuration_file} file"
678
+ )
679
+
680
+ try:
681
+ if gguf_file:
682
+ config_dict = load_gguf_checkpoint(resolved_config_file, return_tensors=False)["config"]
683
+ else:
684
+ # Load config dict
685
+ config_dict = cls._dict_from_json_file(resolved_config_file)
686
+
687
+ config_dict["_commit_hash"] = commit_hash
688
+ except (json.JSONDecodeError, UnicodeDecodeError):
689
+ raise EnvironmentError(
690
+ f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
691
+ )
692
+
693
+ if is_local:
694
+ logger.info(f"loading configuration file {resolved_config_file}")
695
+ else:
696
+ logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")
697
+
698
+ if "auto_map" in config_dict and not is_local:
699
+ config_dict["auto_map"] = add_model_info_to_auto_map(
700
+ config_dict["auto_map"], pretrained_model_name_or_path
701
+ )
702
+ if "custom_pipelines" in config_dict and not is_local:
703
+ config_dict["custom_pipelines"] = add_model_info_to_custom_pipelines(
704
+ config_dict["custom_pipelines"], pretrained_model_name_or_path
705
+ )
706
+
707
+ # timm models are not saved with the model_type in the config file
708
+ if "model_type" not in config_dict and is_timm_config_dict(config_dict):
709
+ config_dict["model_type"] = "timm_wrapper"
710
+
711
+ return config_dict, kwargs
712
+
713
+ @classmethod
714
+ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
715
+ """
716
+ Instantiates a [`PretrainedConfig`] from a Python dictionary of parameters.
717
+
718
+ Args:
719
+ config_dict (`Dict[str, Any]`):
720
+ Dictionary that will be used to instantiate the configuration object. Such a dictionary can be
721
+ retrieved from a pretrained checkpoint by leveraging the [`~PretrainedConfig.get_config_dict`] method.
722
+ kwargs (`Dict[str, Any]`):
723
+ Additional parameters from which to initialize the configuration object.
724
+
725
+ Returns:
726
+ [`PretrainedConfig`]: The configuration object instantiated from those parameters.
727
+ """
728
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
729
+ # Those arguments may be passed along for our internal telemetry.
730
+ # We remove them so they don't appear in `return_unused_kwargs`.
731
+ kwargs.pop("_from_auto", None)
732
+ kwargs.pop("_from_pipeline", None)
733
+ # The commit hash might have been updated in the `config_dict`, we don't want the kwargs to erase that update.
734
+ if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
735
+ kwargs["_commit_hash"] = config_dict["_commit_hash"]
736
+
737
+ # We remove it from kwargs so that it does not appear in `return_unused_kwargs`.
738
+ config_dict["attn_implementation"] = kwargs.pop("attn_implementation", None)
739
+
740
+ config = cls(**config_dict)
741
+
742
+ if hasattr(config, "pruned_heads"):
743
+ config.pruned_heads = {int(key): value for key, value in config.pruned_heads.items()}
744
+
745
+ # Update config with kwargs if needed
746
+ if "num_labels" in kwargs and "id2label" in kwargs:
747
+ num_labels = kwargs["num_labels"]
748
+ id2label = kwargs["id2label"] if kwargs["id2label"] is not None else []
749
+ if len(id2label) != num_labels:
750
+ raise ValueError(
751
+ f"You passed along `num_labels={num_labels }` with an incompatible id to label map: "
752
+ f"{kwargs['id2label']}. Since those arguments are inconsistent with each other, you should remove "
753
+ "one of them."
754
+ )
755
+ to_remove = []
756
+ for key, value in kwargs.items():
757
+ if hasattr(config, key):
758
+ current_attr = getattr(config, key)
759
+ # To authorize passing a custom subconfig as kwarg in models that have nested configs.
760
+ if isinstance(current_attr, PretrainedConfig) and isinstance(value, dict):
761
+ value = current_attr.__class__(**value)
762
+ setattr(config, key, value)
763
+ if key != "torch_dtype":
764
+ to_remove.append(key)
765
+ for key in to_remove:
766
+ kwargs.pop(key, None)
767
+
768
+ logger.info(f"Model config {config}")
769
+ if return_unused_kwargs:
770
+ return config, kwargs
771
+ else:
772
+ return config
773
+
774
+ @classmethod
775
+ def from_json_file(cls, json_file: Union[str, os.PathLike]) -> "PretrainedConfig":
776
+ """
777
+ Instantiates a [`PretrainedConfig`] from the path to a JSON file of parameters.
778
+
779
+ Args:
780
+ json_file (`str` or `os.PathLike`):
781
+ Path to the JSON file containing the parameters.
782
+
783
+ Returns:
784
+ [`PretrainedConfig`]: The configuration object instantiated from that JSON file.
785
+
786
+ """
787
+ config_dict = cls._dict_from_json_file(json_file)
788
+ return cls(**config_dict)
789
+
790
+ @classmethod
791
+ def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
792
+ with open(json_file, "r", encoding="utf-8") as reader:
793
+ text = reader.read()
794
+ return json.loads(text)
795
+
796
+ def __eq__(self, other):
797
+ return isinstance(other, PretrainedConfig) and (self.__dict__ == other.__dict__)
798
+
799
+ def __repr__(self):
800
+ return f"{self.__class__.__name__} {self.to_json_string()}"
801
+
802
+ def __iter__(self):
803
+ for attr in self.__dict__:
804
+ yield attr
805
+
806
+ def to_diff_dict(self) -> Dict[str, Any]:
807
+ """
808
+ Removes all attributes from config which correspond to the default config attributes for better readability and
809
+ serializes to a Python dictionary.
810
+
811
+ Returns:
812
+ `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance,
813
+ """
814
+ config_dict = self.to_dict()
815
+
816
+ # get the default config dict
817
+ default_config_dict = PretrainedConfig().to_dict()
818
+
819
+ # get class specific config dict
820
+ class_config_dict = self.__class__().to_dict() if not self.is_composition else {}
821
+
822
+ serializable_config_dict = {}
823
+
824
+ # only serialize values that differ from the default config
825
+ for key, value in config_dict.items():
826
+ if (
827
+ isinstance(getattr(self, key, None), PretrainedConfig)
828
+ and key in class_config_dict
829
+ and isinstance(class_config_dict[key], dict)
830
+ ):
831
+ # For nested configs we need to clean the diff recursively
832
+ diff = recursive_diff_dict(value, class_config_dict[key], config_obj=getattr(self, key, None))
833
+ if "model_type" in value:
834
+ # Needs to be set even if it's not in the diff
835
+ diff["model_type"] = value["model_type"]
836
+ if len(diff) > 0:
837
+ serializable_config_dict[key] = diff
838
+ elif (
839
+ key not in default_config_dict
840
+ or key == "transformers_version"
841
+ or value != default_config_dict[key]
842
+ or (key in class_config_dict and value != class_config_dict[key])
843
+ ):
844
+ serializable_config_dict[key] = value
845
+
846
+ if hasattr(self, "quantization_config"):
847
+ serializable_config_dict["quantization_config"] = (
848
+ self.quantization_config.to_dict()
849
+ if not isinstance(self.quantization_config, dict)
850
+ else self.quantization_config
851
+ )
852
+
853
+ # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
854
+ _ = serializable_config_dict.pop("_pre_quantization_dtype", None)
855
+
856
+ self.dict_torch_dtype_to_str(serializable_config_dict)
857
+
858
+ if "_attn_implementation_internal" in serializable_config_dict:
859
+ del serializable_config_dict["_attn_implementation_internal"]
860
+ # Do not serialize `base_model_tp_plan` for now
861
+ if "base_model_tp_plan" in serializable_config_dict:
862
+ del serializable_config_dict["base_model_tp_plan"]
863
+
864
+ return serializable_config_dict
865
+
866
+ def to_dict(self) -> Dict[str, Any]:
867
+ """
868
+ Serializes this instance to a Python dictionary.
869
+
870
+ Returns:
871
+ `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
872
+ """
873
+ output = copy.deepcopy(self.__dict__)
874
+ if hasattr(self.__class__, "model_type"):
875
+ output["model_type"] = self.__class__.model_type
876
+ if "_auto_class" in output:
877
+ del output["_auto_class"]
878
+ if "_commit_hash" in output:
879
+ del output["_commit_hash"]
880
+ if "_attn_implementation_internal" in output:
881
+ del output["_attn_implementation_internal"]
882
+ # Do not serialize `base_model_tp_plan` for now
883
+ if "base_model_tp_plan" in output:
884
+ del output["base_model_tp_plan"]
885
+
886
+ # Transformers version when serializing the model
887
+ output["transformers_version"] = __version__
888
+
889
+ for key, value in output.items():
890
+ # Deal with nested configs like CLIP
891
+ if isinstance(value, PretrainedConfig):
892
+ value = value.to_dict()
893
+ del value["transformers_version"]
894
+
895
+ output[key] = value
896
+
897
+ if hasattr(self, "quantization_config"):
898
+ output["quantization_config"] = (
899
+ self.quantization_config.to_dict()
900
+ if not isinstance(self.quantization_config, dict)
901
+ else self.quantization_config
902
+ )
903
+
904
+ # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
905
+ _ = output.pop("_pre_quantization_dtype", None)
906
+
907
+ self.dict_torch_dtype_to_str(output)
908
+
909
+ return output
910
+
911
+ def to_json_string(self, use_diff: bool = True) -> str:
912
+ """
913
+ Serializes this instance to a JSON string.
914
+
915
+ Args:
916
+ use_diff (`bool`, *optional*, defaults to `True`):
917
+ If set to `True`, only the difference between the config instance and the default `PretrainedConfig()`
918
+ is serialized to JSON string.
919
+
920
+ Returns:
921
+ `str`: String containing all the attributes that make up this configuration instance in JSON format.
922
+ """
923
+ if use_diff is True:
924
+ config_dict = self.to_diff_dict()
925
+ else:
926
+ config_dict = self.to_dict()
927
+ return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
928
+
929
+ def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):
930
+ """
931
+ Save this instance to a JSON file.
932
+
933
+ Args:
934
+ json_file_path (`str` or `os.PathLike`):
935
+ Path to the JSON file in which this configuration instance's parameters will be saved.
936
+ use_diff (`bool`, *optional*, defaults to `True`):
937
+ If set to `True`, only the difference between the config instance and the default `PretrainedConfig()`
938
+ is serialized to JSON file.
939
+ """
940
+ with open(json_file_path, "w", encoding="utf-8") as writer:
941
+ writer.write(self.to_json_string(use_diff=use_diff))
942
+
943
+ def update(self, config_dict: Dict[str, Any]):
944
+ """
945
+ Updates attributes of this class with attributes from `config_dict`.
946
+
947
+ Args:
948
+ config_dict (`Dict[str, Any]`): Dictionary of attributes that should be updated for this class.
949
+ """
950
+ for key, value in config_dict.items():
951
+ setattr(self, key, value)
952
+
953
+ def update_from_string(self, update_str: str):
954
+ """
955
+ Updates attributes of this class with attributes from `update_str`.
956
+
957
+ The expected format is ints, floats and strings as is, and for booleans use `true` or `false`. For example:
958
+ "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
959
+
960
+ The keys to change have to already exist in the config object.
961
+
962
+ Args:
963
+ update_str (`str`): String with attributes that should be updated for this class.
964
+
965
+ """
966
+
967
+ d = dict(x.split("=") for x in update_str.split(","))
968
+ for k, v in d.items():
969
+ if not hasattr(self, k):
970
+ raise ValueError(f"key {k} isn't in the original config dict")
971
+
972
+ old_v = getattr(self, k)
973
+ if isinstance(old_v, bool):
974
+ if v.lower() in ["true", "1", "y", "yes"]:
975
+ v = True
976
+ elif v.lower() in ["false", "0", "n", "no"]:
977
+ v = False
978
+ else:
979
+ raise ValueError(f"can't derive true or false from {v} (key {k})")
980
+ elif isinstance(old_v, int):
981
+ v = int(v)
982
+ elif isinstance(old_v, float):
983
+ v = float(v)
984
+ elif not isinstance(old_v, str):
985
+ raise TypeError(
986
+ f"You can only update int, float, bool or string values in the config, got {v} for key {k}"
987
+ )
988
+
989
+ setattr(self, k, v)
990
+
991
+ def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
992
+ """
993
+ Checks whether the passed dictionary and its nested dicts have a *torch_dtype* key and if it's not None,
994
+ converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"*
995
+ string, which can then be stored in the json format.
996
+ """
997
+ if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str):
998
+ d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
999
+ for value in d.values():
1000
+ if isinstance(value, dict):
1001
+ self.dict_torch_dtype_to_str(value)
1002
+
1003
+ @classmethod
1004
+ def register_for_auto_class(cls, auto_class="AutoConfig"):
1005
+ """
1006
+ Register this class with a given auto class. This should only be used for custom configurations as the ones in
1007
+ the library are already mapped with `AutoConfig`.
1008
+
1009
+ <Tip warning={true}>
1010
+
1011
+ This API is experimental and may have some slight breaking changes in the next releases.
1012
+
1013
+ </Tip>
1014
+
1015
+ Args:
1016
+ auto_class (`str` or `type`, *optional*, defaults to `"AutoConfig"`):
1017
+ The auto class to register this new configuration with.
1018
+ """
1019
+ if not isinstance(auto_class, str):
1020
+ auto_class = auto_class.__name__
1021
+
1022
+ import transformers.models.auto as auto_module
1023
+
1024
+ if not hasattr(auto_module, auto_class):
1025
+ raise ValueError(f"{auto_class} is not a valid auto class.")
1026
+
1027
+ cls._auto_class = auto_class
1028
+
1029
+ @staticmethod
1030
+ def _get_global_generation_defaults() -> Dict[str, Any]:
1031
+ return {
1032
+ "max_length": 20,
1033
+ "min_length": 0,
1034
+ "do_sample": False,
1035
+ "early_stopping": False,
1036
+ "num_beams": 1,
1037
+ "num_beam_groups": 1,
1038
+ "diversity_penalty": 0.0,
1039
+ "temperature": 1.0,
1040
+ "top_k": 50,
1041
+ "top_p": 1.0,
1042
+ "typical_p": 1.0,
1043
+ "repetition_penalty": 1.0,
1044
+ "length_penalty": 1.0,
1045
+ "no_repeat_ngram_size": 0,
1046
+ "encoder_no_repeat_ngram_size": 0,
1047
+ "bad_words_ids": None,
1048
+ "num_return_sequences": 1,
1049
+ "output_scores": False,
1050
+ "return_dict_in_generate": False,
1051
+ "forced_bos_token_id": None,
1052
+ "forced_eos_token_id": None,
1053
+ "remove_invalid_values": False,
1054
+ "exponential_decay_length_penalty": None,
1055
+ "suppress_tokens": None,
1056
+ "begin_suppress_tokens": None,
1057
+ }
1058
+
1059
+ def _get_non_default_generation_parameters(self) -> Dict[str, Any]:
1060
+ """
1061
+ Gets the non-default generation parameters on the PretrainedConfig instance
1062
+ """
1063
+ non_default_generation_parameters = {}
1064
+ decoder_attribute_name = None
1065
+
1066
+ # Composite models don't have a default config, use their decoder config as a fallback for default values
1067
+ # If no known pattern is matched, then `default_config = None` -> check against the global generation defaults
1068
+ try:
1069
+ default_config = self.__class__()
1070
+ except ValueError:
1071
+ decoder_config = self.get_text_config(decoder=True)
1072
+ if decoder_config is not self:
1073
+ default_config = decoder_config.__class__()
1074
+ else:
1075
+ default_config = None
1076
+
1077
+ # If it is a composite model, we want to check the subconfig that will be used for generation
1078
+ self_decoder_config = self if decoder_attribute_name is None else getattr(self, decoder_attribute_name)
1079
+
1080
+ for parameter_name, default_global_value in self._get_global_generation_defaults().items():
1081
+ if hasattr(self_decoder_config, parameter_name):
1082
+ is_default_in_config = is_default_generation_value = None
1083
+ parameter_value = getattr(self_decoder_config, parameter_name)
1084
+ # Three cases in which is okay for the model config to hold generation config parameters:
1085
+ # 1. The parameter is set to `None`, effectivelly delegating its value to the generation config
1086
+ if parameter_value is None:
1087
+ continue
1088
+ # 2. If we have a default config, then the instance should hold the same generation defaults
1089
+ if default_config is not None:
1090
+ is_default_in_config = parameter_value == getattr(default_config, parameter_name)
1091
+ # 3. if we don't have a default config, then the instance should hold the global generation defaults
1092
+ else:
1093
+ is_default_generation_value = parameter_value == default_global_value
1094
+
1095
+ is_non_default = (is_default_in_config is False) or (
1096
+ is_default_in_config is None and is_default_generation_value is False
1097
+ )
1098
+ if is_non_default:
1099
+ non_default_generation_parameters[parameter_name] = getattr(self_decoder_config, parameter_name)
1100
+
1101
+ return non_default_generation_parameters
1102
+
1103
+ def get_text_config(self, decoder=False) -> "PretrainedConfig":
1104
+ """
1105
+ Returns the config that is meant to be used with text IO. On most models, it is the original config instance
1106
+ itself. On specific composite models, it is under a set of valid names.
1107
+
1108
+ If `decoder` is set to `True`, then only search for decoder config names.
1109
+ """
1110
+ decoder_possible_text_config_names = ("decoder", "generator", "text_config")
1111
+ encoder_possible_text_config_names = ("text_encoder",)
1112
+ if decoder:
1113
+ possible_text_config_names = decoder_possible_text_config_names
1114
+ else:
1115
+ possible_text_config_names = encoder_possible_text_config_names + decoder_possible_text_config_names
1116
+
1117
+ valid_text_config_names = []
1118
+ for text_config_name in possible_text_config_names:
1119
+ if hasattr(self, text_config_name):
1120
+ text_config = getattr(self, text_config_name, None)
1121
+ if text_config is not None:
1122
+ valid_text_config_names += [text_config_name]
1123
+
1124
+ if len(valid_text_config_names) > 1:
1125
+ raise ValueError(
1126
+ f"Multiple valid text configs were found in the model config: {valid_text_config_names}. In this "
1127
+ "case, using `get_text_config()` would be ambiguous. Please specify the desied text config directly."
1128
+ )
1129
+ elif len(valid_text_config_names) == 1:
1130
+ return getattr(self, valid_text_config_names[0])
1131
+ return self
1132
+
1133
+
1134
+ def get_configuration_file(configuration_files: List[str]) -> str:
1135
+ """
1136
+ Get the configuration file to use for this version of transformers.
1137
+
1138
+ Args:
1139
+ configuration_files (`List[str]`): The list of available configuration files.
1140
+
1141
+ Returns:
1142
+ `str`: The configuration file to use.
1143
+ """
1144
+ configuration_files_map = {}
1145
+ for file_name in configuration_files:
1146
+ search = _re_configuration_file.search(file_name)
1147
+ if search is not None:
1148
+ v = search.groups()[0]
1149
+ configuration_files_map[v] = file_name
1150
+ available_versions = sorted(configuration_files_map.keys())
1151
+
1152
+ # Defaults to FULL_CONFIGURATION_FILE and then try to look at some newer versions.
1153
+ configuration_file = CONFIG_NAME
1154
+ transformers_version = version.parse(__version__)
1155
+ for v in available_versions:
1156
+ if version.parse(v) <= transformers_version:
1157
+ configuration_file = configuration_files_map[v]
1158
+ else:
1159
+ # No point going further since the versions are sorted.
1160
+ break
1161
+
1162
+ return configuration_file
1163
+
1164
+
1165
+ def recursive_diff_dict(dict_a, dict_b, config_obj=None):
1166
+ """
1167
+ Helper function to recursively take the diff between two nested dictionaries. The resulting diff only contains the
1168
+ values from `dict_a` that are different from values in `dict_b`.
1169
+ """
1170
+ diff = {}
1171
+ default = config_obj.__class__().to_dict() if config_obj is not None else {}
1172
+ for key, value in dict_a.items():
1173
+ obj_value = getattr(config_obj, str(key), None)
1174
+ if isinstance(obj_value, PretrainedConfig) and key in dict_b and isinstance(dict_b[key], dict):
1175
+ diff_value = recursive_diff_dict(value, dict_b[key], config_obj=obj_value)
1176
+ if len(diff_value) > 0:
1177
+ diff[key] = diff_value
1178
+ elif key not in dict_b or value != dict_b[key] or key not in default or value != default[key]:
1179
+ diff[key] = value
1180
+ return diff
1181
+
1182
+
1183
+ PretrainedConfig.push_to_hub = copy_func(PretrainedConfig.push_to_hub)
1184
+ if PretrainedConfig.push_to_hub.__doc__ is not None:
1185
+ PretrainedConfig.push_to_hub.__doc__ = PretrainedConfig.push_to_hub.__doc__.format(
1186
+ object="config", object_class="AutoConfig", object_files="configuration file"
1187
+ )
.venv/lib/python3.11/site-packages/transformers/convert_graph_to_onnx.py ADDED
@@ -0,0 +1,551 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import warnings
16
+ from argparse import ArgumentParser
17
+ from os import listdir, makedirs
18
+ from pathlib import Path
19
+ from typing import Dict, List, Optional, Tuple
20
+
21
+ from packaging.version import Version, parse
22
+
23
+ from transformers.pipelines import Pipeline, pipeline
24
+ from transformers.tokenization_utils import BatchEncoding
25
+ from transformers.utils import ModelOutput, is_tf_available, is_torch_available
26
+
27
+
28
# This is the minimal required version to
# support some ONNX Runtime features
ORT_QUANTIZE_MINIMUM_VERSION = parse("1.4.0")


# Pipeline task identifiers this converter knows how to export to ONNX.
SUPPORTED_PIPELINES = [
    "feature-extraction",
    "ner",
    "sentiment-analysis",
    "fill-mask",
    "question-answering",
    "text-generation",
    "translation_en_to_fr",
    "translation_en_to_de",
    "translation_en_to_ro",
]
44
+
45
+
46
class OnnxConverterArgumentParser(ArgumentParser):
    """
    Wraps all the script arguments supported to export transformers models to ONNX IR
    """

    def __init__(self):
        # Program name shown in --help output.
        super().__init__("ONNX Converter")

        # Pipeline task to instantiate around the model being exported.
        self.add_argument(
            "--pipeline",
            type=str,
            choices=SUPPORTED_PIPELINES,
            default="feature-extraction",
        )
        self.add_argument(
            "--model",
            type=str,
            required=True,
            help="Model's id or path (ex: google-bert/bert-base-cased)",
        )
        self.add_argument("--tokenizer", type=str, help="Tokenizer's id or path (ex: google-bert/bert-base-cased)")
        self.add_argument(
            "--framework",
            type=str,
            choices=["pt", "tf"],
            help="Framework for loading the model",
        )
        self.add_argument("--opset", type=int, default=11, help="ONNX opset to use")
        self.add_argument(
            "--check-loading",
            action="store_true",
            help="Check ONNX is able to load the model",
        )
        self.add_argument(
            "--use-external-format",
            action="store_true",
            help="Allow exporting model >= than 2Gb",
        )
        self.add_argument(
            "--quantize",
            action="store_true",
            help="Quantize the neural network to be run with int8",
        )
        # Positional argument: destination path of the exported ONNX file.
        self.add_argument("output")
90
+
91
+
92
def generate_identified_filename(filename: Path, identifier: str) -> Path:
    """
    Append a string-identifier at the end (before the extension, if any) to the provided filepath.

    Args:
        filename: pathlib.Path The actual path object we would like to add an identifier suffix
        identifier: The suffix to add

    Returns: Path with the identifier concatenated at the end of the stem
    """
    # Insert the identifier between the stem and the (possibly empty) extension.
    identified_stem = f"{filename.stem}{identifier}"
    return (filename.parent / identified_stem).with_suffix(filename.suffix)
103
+
104
+
105
def check_onnxruntime_requirements(minimum_version: Version):
    """
    Check onnxruntime is installed and if the installed version match is recent enough.

    Args:
        minimum_version: the minimal onnxruntime version required for the conversion options.

    Raises:
        ImportError: If onnxruntime is not installed or too old version is found
    """
    # Keep only the import inside the try block: in the original code the "too old version"
    # ImportError raised below was itself caught by `except ImportError` and silently replaced
    # with the misleading "not installed" message.
    try:
        import onnxruntime
    except ImportError:
        raise ImportError(
            "onnxruntime doesn't seem to be currently installed. "
            "Please install the onnxruntime by running `pip install onnxruntime`"
            " and relaunch the conversion."
        )

    # Parse the version of the installed onnxruntime
    ort_version = parse(onnxruntime.__version__)

    # Compare against the caller-supplied minimum (the original compared against the module
    # constant while citing `minimum_version` in the message — now consistent).
    if ort_version < minimum_version:
        raise ImportError(
            f"We found an older version of onnxruntime ({onnxruntime.__version__}) "
            f"but we require onnxruntime to be >= {minimum_version} to enable all the conversions options.\n"
            "Please update onnxruntime by running `pip install --upgrade onnxruntime`"
        )
132
+
133
+
134
def ensure_valid_input(model, tokens, input_names):
    """
    Ensure inputs are presented in the correct order, without any None values.

    Args:
        model: The model used to forward the input data
        tokens: BatchEncoding holding the input data
        input_names: The name of the inputs

    Returns: Tuple of (ordered input names, tuple of corresponding input values)
    """
    print("Ensuring inputs are in correct order")

    forward_parameters = model.forward.__code__.co_varnames
    ordered_input_names = []
    model_args = []
    # Skip index 0, which is the "self" argument of the bound forward method.
    for parameter in forward_parameters[1:]:
        if parameter not in input_names:
            # Stop at the first forward() parameter we have no input for, so argument
            # positions stay aligned with the signature.
            print(f"{parameter} is not present in the generated input list.")
            break
        ordered_input_names.append(parameter)
        model_args.append(tokens[parameter])

    print(f"Generated inputs order: {ordered_input_names}")
    return ordered_input_names, tuple(model_args)
160
+
161
+
162
def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], Dict, BatchEncoding]:
    """
    Attempt to infer the static vs dynamic axes for each input and output tensors for a specific model

    Args:
        nlp: The pipeline object holding the model to be exported
        framework: The framework identifier to dispatch to the correct inference scheme (pt/tf)

    Returns:

        - List of the inferred input variable names
        - List of the inferred output variable names
        - Dictionary with input/output variables names as key and shape tensor as value
        - a BatchEncoding reference which was used to infer all the above information
    """

    def build_shape_dict(name: str, tensor, is_input: bool, seq_len: int):
        # Grouped tensors (e.g. past key/values) are handled element-wise.
        if isinstance(tensor, (tuple, list)):
            return [build_shape_dict(name, t, is_input, seq_len) for t in tensor]

        else:
            # Let's assume batch is the first axis with only 1 element (~~ might not be always true ...)
            # NOTE(review): indexing [0] raises IndexError if no axis has size 1 — confirm all supported
            # models produce a size-1 batch axis here.
            axes = {[axis for axis, numel in enumerate(tensor.shape) if numel == 1][0]: "batch"}
            if is_input:
                if len(tensor.shape) == 2:
                    axes[1] = "sequence"
                else:
                    raise ValueError(f"Unable to infer tensor axes ({len(tensor.shape)})")
            else:
                # Any output axis whose size equals the sample sequence length is assumed dynamic.
                seq_axes = [dim for dim, shape in enumerate(tensor.shape) if shape == seq_len]
                axes.update({dim: "sequence" for dim in seq_axes})

        print(f"Found {'input' if is_input else 'output'} {name} with shape: {axes}")
        return axes

    # Run one dummy forward pass to observe concrete input/output shapes.
    tokens = nlp.tokenizer("This is a sample output", return_tensors=framework)
    seq_len = tokens.input_ids.shape[-1]
    # PyTorch models take keyword inputs; TF models take the BatchEncoding positionally.
    outputs = nlp.model(**tokens) if framework == "pt" else nlp.model(tokens)
    if isinstance(outputs, ModelOutput):
        outputs = outputs.to_tuple()
    if not isinstance(outputs, (list, tuple)):
        outputs = (outputs,)

    # Generate input names & axes
    input_vars = list(tokens.keys())
    input_dynamic_axes = {k: build_shape_dict(k, v, True, seq_len) for k, v in tokens.items()}

    # flatten potentially grouped outputs (past for gpt2, attentions)
    outputs_flat = []
    for output in outputs:
        if isinstance(output, (tuple, list)):
            outputs_flat.extend(output)
        else:
            outputs_flat.append(output)

    # Generate output names & axes
    output_names = [f"output_{i}" for i in range(len(outputs_flat))]
    output_dynamic_axes = {k: build_shape_dict(k, v, False, seq_len) for k, v in zip(output_names, outputs_flat)}

    # Create the aggregated axes representation
    dynamic_axes = dict(input_dynamic_axes, **output_dynamic_axes)
    return input_vars, output_names, dynamic_axes, tokens
224
+
225
+
226
def load_graph_from_args(
    pipeline_name: str, framework: str, model: str, tokenizer: Optional[str] = None, **models_kwargs
) -> Pipeline:
    """
    Convert the set of arguments provided through the CLI to an actual pipeline reference (tokenizer + model).

    Args:
        pipeline_name: The kind of pipeline to use (ner, question-answering, etc.)
        framework: The actual model to convert the pipeline from ("pt" or "tf")
        model: The model name which will be loaded by the pipeline
        tokenizer: The tokenizer name which will be loaded by the pipeline, default to the model's value

    Returns: Pipeline object
    """
    # Fall back to the model identifier when no explicit tokenizer was provided.
    tokenizer = model if tokenizer is None else tokenizer

    # Verify the requested backend is actually importable before doing any heavy lifting.
    if framework == "pt" and not is_torch_available():
        raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")
    if framework == "tf" and not is_tf_available():
        raise Exception("Cannot convert because TF is not installed. Please install tensorflow first.")

    print(f"Loading pipeline (model: {model}, tokenizer: {tokenizer})")

    # Allocate tokenizer and model
    return pipeline(pipeline_name, model=model, tokenizer=tokenizer, framework=framework, model_kwargs=models_kwargs)
255
+
256
+
257
def convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format: bool):
    """
    Export a PyTorch backed pipeline to ONNX Intermediate Representation (IR).

    Args:
        nlp: The pipeline to be exported
        opset: The actual version of the ONNX operator set to use
        output: Path where will be stored the generated ONNX model
        use_external_format: Split the model definition from its parameters to allow model bigger than 2GB

    Returns:

    """
    if not is_torch_available():
        raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")

    import torch
    from torch.onnx import export

    print(f"Using framework PyTorch: {torch.__version__}")

    # Disable gradient tracking while tracing the graph for export.
    with torch.no_grad():
        input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "pt")
        # Reorder inputs to match the positional order of the model's forward() signature.
        ordered_input_names, model_args = ensure_valid_input(nlp.model, tokens, input_names)

        # NOTE(review): `use_external_format` is accepted but never forwarded to `export` — confirm
        # whether the >2GB external-data path is still supported by the targeted torch version.
        export(
            nlp.model,
            model_args,
            f=output.as_posix(),
            input_names=ordered_input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            do_constant_folding=True,
            opset_version=opset,
        )
292
+
293
+
294
def convert_tensorflow(nlp: Pipeline, opset: int, output: Path):
    """
    Export a TensorFlow backed pipeline to ONNX Intermediate Representation (IR)

    Args:
        nlp: The pipeline to be exported
        opset: The actual version of the ONNX operator set to use
        output: Path where will be stored the generated ONNX model

    Notes: TensorFlow cannot export model bigger than 2GB due to internal constraint from TensorFlow

    """
    if not is_tf_available():
        raise Exception("Cannot convert because TF is not installed. Please install tensorflow first.")

    print("/!\\ Please note TensorFlow doesn't support exporting model > 2Gb /!\\")

    try:
        import tensorflow as tf
        import tf2onnx
        from tf2onnx import __version__ as t2ov

        print(f"Using framework TensorFlow: {tf.version.VERSION}, tf2onnx: {t2ov}")

        # Build
        input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "tf")

        # Forward pass builds the Keras model so its input signature is concrete before conversion.
        nlp.model.predict(tokens.data)
        input_signature = [tf.TensorSpec.from_tensor(tensor, name=key) for key, tensor in tokens.items()]
        # tf2onnx writes the ONNX file directly to `output`; the returned proto is unused here.
        model_proto, _ = tf2onnx.convert.from_keras(
            nlp.model, input_signature, opset=opset, output_path=output.as_posix()
        )

    except ImportError as e:
        raise Exception(
            f"Cannot import {e.name} required to convert TF model to ONNX. Please install {e.name} first. {e}"
        )
332
+
333
+
334
def convert(
    framework: str,
    model: str,
    output: Path,
    opset: int,
    tokenizer: Optional[str] = None,
    use_external_format: bool = False,
    pipeline_name: str = "feature-extraction",
    **model_kwargs,
):
    """
    Convert the pipeline object to the ONNX Intermediate Representation (IR) format.

    Args:
        framework: The framework the pipeline is backed by ("pt" or "tf")
        model: The name of the model to load for the pipeline
        output: The path where the ONNX graph will be stored
        opset: The actual version of the ONNX operator set to use
        tokenizer: The name of the model to load for the pipeline, default to the model's name if not provided
        use_external_format:
            Split the model definition from its parameters to allow model bigger than 2GB (PyTorch only)
        pipeline_name: The kind of pipeline to instantiate (ner, question-answering, etc.)
        model_kwargs: Keyword arguments to be forwarded to the model constructor
    """
    warnings.warn(
        "The `transformers.convert_graph_to_onnx` package is deprecated and will be removed in version 5 of"
        " Transformers",
        FutureWarning,
    )
    print(f"ONNX opset version set to: {opset}")

    # Load the pipeline
    nlp = load_graph_from_args(pipeline_name, framework, model, tokenizer, **model_kwargs)

    # The destination folder must exist and be empty to avoid clobbering previous exports.
    destination_dir = output.parent
    if not destination_dir.exists():
        print(f"Creating folder {destination_dir}")
        makedirs(destination_dir.as_posix())
    elif len(listdir(destination_dir.as_posix())) > 0:
        raise Exception(f"Folder {destination_dir.as_posix()} is not empty, aborting conversion")

    # Dispatch to the framework-specific exporter.
    if framework == "pt":
        convert_pytorch(nlp, opset, output, use_external_format)
    else:
        convert_tensorflow(nlp, opset, output)
382
+
383
+
384
def optimize(onnx_model_path: Path) -> Path:
    """
    Load the model at the specified path and let onnxruntime look at transformations on the graph to enable all the
    optimizations possible

    Args:
        onnx_model_path: filepath where the model binary description is stored

    Returns: Path where the optimized model binary description has been saved

    """
    from onnxruntime import InferenceSession, SessionOptions

    # Generate model name with suffix "optimized"
    opt_model_path = generate_identified_filename(onnx_model_path, "-optimized")
    sess_option = SessionOptions()
    sess_option.optimized_model_filepath = opt_model_path.as_posix()
    # Instantiating the session runs graph optimization and, as a side effect,
    # writes the optimized model to `optimized_model_filepath`.
    _ = InferenceSession(onnx_model_path.as_posix(), sess_option)

    print(f"Optimized model has been written at {opt_model_path}: \N{HEAVY CHECK MARK}")
    print("/!\\ Optimized model contains hardware specific operators which might not be portable. /!\\")

    return opt_model_path
407
+
408
+
409
def quantize(onnx_model_path: Path) -> Path:
    """
    Quantize the weights of the model from float32 to in8 to allow very efficient inference on modern CPU

    Args:
        onnx_model_path: Path to location the exported ONNX model is stored

    Returns: The Path generated for the quantized
    """
    import onnx
    import onnxruntime
    from onnx.onnx_pb import ModelProto
    from onnxruntime.quantization import QuantizationMode
    from onnxruntime.quantization.onnx_quantizer import ONNXQuantizer
    from onnxruntime.quantization.registry import IntegerOpsRegistry

    # Load the ONNX model
    onnx_model = onnx.load(onnx_model_path.as_posix())

    if parse(onnx.__version__) < parse("1.5.0"):
        print(
            "Models larger than 2GB will fail to quantize due to protobuf constraint.\n"
            "Please upgrade to onnxruntime >= 1.5.0."
        )

    # Copy it so the original proto is left untouched by the quantizer.
    copy_model = ModelProto()
    copy_model.CopyFrom(onnx_model)

    # Construct quantizer
    # onnxruntime renamed input_qType to activation_qType in v1.13.1, so we
    # check the onnxruntime version to ensure backward compatibility.
    # See also: https://github.com/microsoft/onnxruntime/pull/12873
    # NOTE(review): booleans are passed for weight_qType/input_qType where the API documents a
    # quantization type enum — this relies on implicit int coercion; confirm against the pinned
    # onnxruntime version.
    if parse(onnxruntime.__version__) < parse("1.13.1"):
        quantizer = ONNXQuantizer(
            model=copy_model,
            per_channel=False,
            reduce_range=False,
            mode=QuantizationMode.IntegerOps,
            static=False,
            weight_qType=True,
            input_qType=False,
            tensors_range=None,
            nodes_to_quantize=None,
            nodes_to_exclude=None,
            op_types_to_quantize=list(IntegerOpsRegistry),
        )
    else:
        quantizer = ONNXQuantizer(
            model=copy_model,
            per_channel=False,
            reduce_range=False,
            mode=QuantizationMode.IntegerOps,
            static=False,
            weight_qType=True,
            activation_qType=False,
            tensors_range=None,
            nodes_to_quantize=None,
            nodes_to_exclude=None,
            op_types_to_quantize=list(IntegerOpsRegistry),
        )

    # Quantize and export
    quantizer.quantize_model()

    # Append "-quantized" at the end of the model's name
    quantized_model_path = generate_identified_filename(onnx_model_path, "-quantized")

    # Save model
    print(f"Quantized model has been written at {quantized_model_path}: \N{HEAVY CHECK MARK}")
    onnx.save_model(quantizer.model.model, quantized_model_path.as_posix())

    return quantized_model_path
482
+
483
+
484
def verify(path: Path):
    """
    Try to create an onnxruntime inference session from the model at `path`, reporting
    success or the load error on stdout.

    Args:
        path: filepath of the ONNX model to check.
    """
    from onnxruntime import InferenceSession, SessionOptions
    from onnxruntime.capi.onnxruntime_pybind11_state import RuntimeException

    print(f"Checking ONNX model loading from: {path} ...")
    try:
        session_options = SessionOptions()
        _ = InferenceSession(path.as_posix(), session_options, providers=["CPUExecutionProvider"])
    except RuntimeException as re:
        print(f"Error while loading the model {re}: \N{HEAVY BALLOT X}")
    else:
        print(f"Model {path} correctly loaded: \N{HEAVY CHECK MARK}")
495
+
496
+
497
if __name__ == "__main__":
    parser = OnnxConverterArgumentParser()
    args = parser.parse_args()

    # Make sure output is absolute path
    args.output = Path(args.output).absolute()

    try:
        print("\n====== Converting model to ONNX ======")
        # Convert
        convert(
            args.framework,
            args.model,
            args.output,
            args.opset,
            args.tokenizer,
            args.use_external_format,
            args.pipeline,
        )

        if args.quantize:
            # Ensure requirements for quantization on onnxruntime is met
            check_onnxruntime_requirements(ORT_QUANTIZE_MINIMUM_VERSION)

            # onnxruntime optimizations doesn't provide the same level of performances on TensorFlow than PyTorch
            if args.framework == "tf":
                print(
                    "\t Using TensorFlow might not provide the same optimization level compared to PyTorch.\n"
                    "\t For TensorFlow users you can try optimizing the model directly through onnxruntime_tools.\n"
                    "\t For more information, please refer to the onnxruntime documentation:\n"
                    "\t\thttps://github.com/microsoft/onnxruntime/tree/master/onnxruntime/python/tools/transformers\n"
                )

            print("\n====== Optimizing ONNX model ======")

            # Quantization works best when using the optimized version of the model
            args.optimized_output = optimize(args.output)

            # Do the quantization on the right graph
            args.quantized_output = quantize(args.optimized_output)

        # And verify
        if args.check_loading:
            print("\n====== Check exported ONNX model(s) ======")
            verify(args.output)

            # Optimized/quantized artifacts only exist when --quantize was requested.
            if hasattr(args, "optimized_output"):
                verify(args.optimized_output)

            if hasattr(args, "quantized_output"):
                verify(args.quantized_output)

    except Exception as e:
        # Any failure aborts the script with a non-zero exit code for CI friendliness.
        print(f"Error while converting the model: {e}")
        exit(1)
.venv/lib/python3.11/site-packages/transformers/convert_pytorch_checkpoint_to_tf2.py ADDED
@@ -0,0 +1,446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert pytorch checkpoints to TensorFlow"""
16
+
17
+ import argparse
18
+ import os
19
+
20
+ from . import (
21
+ AlbertConfig,
22
+ BartConfig,
23
+ BertConfig,
24
+ CamembertConfig,
25
+ CTRLConfig,
26
+ DistilBertConfig,
27
+ DPRConfig,
28
+ ElectraConfig,
29
+ FlaubertConfig,
30
+ GPT2Config,
31
+ LayoutLMConfig,
32
+ LxmertConfig,
33
+ OpenAIGPTConfig,
34
+ RobertaConfig,
35
+ T5Config,
36
+ TFAlbertForPreTraining,
37
+ TFBartForConditionalGeneration,
38
+ TFBartForSequenceClassification,
39
+ TFBertForPreTraining,
40
+ TFBertForQuestionAnswering,
41
+ TFBertForSequenceClassification,
42
+ TFCamembertForMaskedLM,
43
+ TFCTRLLMHeadModel,
44
+ TFDistilBertForMaskedLM,
45
+ TFDistilBertForQuestionAnswering,
46
+ TFDPRContextEncoder,
47
+ TFDPRQuestionEncoder,
48
+ TFDPRReader,
49
+ TFElectraForPreTraining,
50
+ TFFlaubertWithLMHeadModel,
51
+ TFGPT2LMHeadModel,
52
+ TFLayoutLMForMaskedLM,
53
+ TFLxmertForPreTraining,
54
+ TFLxmertVisualFeatureEncoder,
55
+ TFOpenAIGPTLMHeadModel,
56
+ TFRobertaForCausalLM,
57
+ TFRobertaForMaskedLM,
58
+ TFRobertaForSequenceClassification,
59
+ TFT5ForConditionalGeneration,
60
+ TFTransfoXLLMHeadModel,
61
+ TFWav2Vec2Model,
62
+ TFXLMRobertaForMaskedLM,
63
+ TFXLMWithLMHeadModel,
64
+ TFXLNetLMHeadModel,
65
+ TransfoXLConfig,
66
+ Wav2Vec2Config,
67
+ Wav2Vec2Model,
68
+ XLMConfig,
69
+ XLMRobertaConfig,
70
+ XLNetConfig,
71
+ is_torch_available,
72
+ load_pytorch_checkpoint_in_tf2_model,
73
+ )
74
+ from .utils import CONFIG_NAME, WEIGHTS_NAME, cached_file, logging
75
+
76
+
77
+ if is_torch_available():
78
+ import numpy as np
79
+ import torch
80
+
81
+ from . import (
82
+ AlbertForPreTraining,
83
+ BartForConditionalGeneration,
84
+ BertForPreTraining,
85
+ BertForQuestionAnswering,
86
+ BertForSequenceClassification,
87
+ CamembertForMaskedLM,
88
+ CTRLLMHeadModel,
89
+ DistilBertForMaskedLM,
90
+ DistilBertForQuestionAnswering,
91
+ DPRContextEncoder,
92
+ DPRQuestionEncoder,
93
+ DPRReader,
94
+ ElectraForPreTraining,
95
+ FlaubertWithLMHeadModel,
96
+ GPT2LMHeadModel,
97
+ LayoutLMForMaskedLM,
98
+ LxmertForPreTraining,
99
+ LxmertVisualFeatureEncoder,
100
+ OpenAIGPTLMHeadModel,
101
+ RobertaForMaskedLM,
102
+ RobertaForSequenceClassification,
103
+ T5ForConditionalGeneration,
104
+ TransfoXLLMHeadModel,
105
+ XLMRobertaForMaskedLM,
106
+ XLMWithLMHeadModel,
107
+ XLNetLMHeadModel,
108
+ )
109
+
110
+
111
+ logging.set_verbosity_info()
112
+
113
+ MODEL_CLASSES = {
114
+ "bart": (
115
+ BartConfig,
116
+ TFBartForConditionalGeneration,
117
+ TFBartForSequenceClassification,
118
+ BartForConditionalGeneration,
119
+ ),
120
+ "bert": (
121
+ BertConfig,
122
+ TFBertForPreTraining,
123
+ BertForPreTraining,
124
+ ),
125
+ "google-bert/bert-large-uncased-whole-word-masking-finetuned-squad": (
126
+ BertConfig,
127
+ TFBertForQuestionAnswering,
128
+ BertForQuestionAnswering,
129
+ ),
130
+ "google-bert/bert-large-cased-whole-word-masking-finetuned-squad": (
131
+ BertConfig,
132
+ TFBertForQuestionAnswering,
133
+ BertForQuestionAnswering,
134
+ ),
135
+ "google-bert/bert-base-cased-finetuned-mrpc": (
136
+ BertConfig,
137
+ TFBertForSequenceClassification,
138
+ BertForSequenceClassification,
139
+ ),
140
+ "dpr": (
141
+ DPRConfig,
142
+ TFDPRQuestionEncoder,
143
+ TFDPRContextEncoder,
144
+ TFDPRReader,
145
+ DPRQuestionEncoder,
146
+ DPRContextEncoder,
147
+ DPRReader,
148
+ ),
149
+ "openai-community/gpt2": (
150
+ GPT2Config,
151
+ TFGPT2LMHeadModel,
152
+ GPT2LMHeadModel,
153
+ ),
154
+ "xlnet": (
155
+ XLNetConfig,
156
+ TFXLNetLMHeadModel,
157
+ XLNetLMHeadModel,
158
+ ),
159
+ "xlm": (
160
+ XLMConfig,
161
+ TFXLMWithLMHeadModel,
162
+ XLMWithLMHeadModel,
163
+ ),
164
+ "xlm-roberta": (
165
+ XLMRobertaConfig,
166
+ TFXLMRobertaForMaskedLM,
167
+ XLMRobertaForMaskedLM,
168
+ ),
169
+ "transfo-xl": (
170
+ TransfoXLConfig,
171
+ TFTransfoXLLMHeadModel,
172
+ TransfoXLLMHeadModel,
173
+ ),
174
+ "openai-community/openai-gpt": (
175
+ OpenAIGPTConfig,
176
+ TFOpenAIGPTLMHeadModel,
177
+ OpenAIGPTLMHeadModel,
178
+ ),
179
+ "roberta": (
180
+ RobertaConfig,
181
+ TFRobertaForCausalLM,
182
+ TFRobertaForMaskedLM,
183
+ RobertaForMaskedLM,
184
+ ),
185
+ "layoutlm": (
186
+ LayoutLMConfig,
187
+ TFLayoutLMForMaskedLM,
188
+ LayoutLMForMaskedLM,
189
+ ),
190
+ "FacebookAI/roberta-large-mnli": (
191
+ RobertaConfig,
192
+ TFRobertaForSequenceClassification,
193
+ RobertaForSequenceClassification,
194
+ ),
195
+ "camembert": (
196
+ CamembertConfig,
197
+ TFCamembertForMaskedLM,
198
+ CamembertForMaskedLM,
199
+ ),
200
+ "flaubert": (
201
+ FlaubertConfig,
202
+ TFFlaubertWithLMHeadModel,
203
+ FlaubertWithLMHeadModel,
204
+ ),
205
+ "distilbert": (
206
+ DistilBertConfig,
207
+ TFDistilBertForMaskedLM,
208
+ DistilBertForMaskedLM,
209
+ ),
210
+ "distilbert-base-distilled-squad": (
211
+ DistilBertConfig,
212
+ TFDistilBertForQuestionAnswering,
213
+ DistilBertForQuestionAnswering,
214
+ ),
215
+ "lxmert": (
216
+ LxmertConfig,
217
+ TFLxmertForPreTraining,
218
+ LxmertForPreTraining,
219
+ ),
220
+ "lxmert-visual-feature-encoder": (
221
+ LxmertConfig,
222
+ TFLxmertVisualFeatureEncoder,
223
+ LxmertVisualFeatureEncoder,
224
+ ),
225
+ "Salesforce/ctrl": (
226
+ CTRLConfig,
227
+ TFCTRLLMHeadModel,
228
+ CTRLLMHeadModel,
229
+ ),
230
+ "albert": (
231
+ AlbertConfig,
232
+ TFAlbertForPreTraining,
233
+ AlbertForPreTraining,
234
+ ),
235
+ "t5": (
236
+ T5Config,
237
+ TFT5ForConditionalGeneration,
238
+ T5ForConditionalGeneration,
239
+ ),
240
+ "electra": (
241
+ ElectraConfig,
242
+ TFElectraForPreTraining,
243
+ ElectraForPreTraining,
244
+ ),
245
+ "wav2vec2": (
246
+ Wav2Vec2Config,
247
+ TFWav2Vec2Model,
248
+ Wav2Vec2Model,
249
+ ),
250
+ }
251
+
252
+
253
+ def convert_pt_checkpoint_to_tf(
254
+ model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True
255
+ ):
256
+ if model_type not in MODEL_CLASSES:
257
+ raise ValueError(f"Unrecognized model type, should be one of {list(MODEL_CLASSES.keys())}.")
258
+
259
+ config_class, model_class, pt_model_class, aws_config_map = MODEL_CLASSES[model_type]
260
+
261
+ # Initialise TF model
262
+ if config_file in aws_config_map:
263
+ config_file = cached_file(config_file, CONFIG_NAME, force_download=not use_cached_models)
264
+ config = config_class.from_json_file(config_file)
265
+ config.output_hidden_states = True
266
+ config.output_attentions = True
267
+ print(f"Building TensorFlow model from configuration: {config}")
268
+ tf_model = model_class(config)
269
+
270
+ # Load weights from tf checkpoint
271
+ if pytorch_checkpoint_path in aws_config_map.keys():
272
+ pytorch_checkpoint_path = cached_file(
273
+ pytorch_checkpoint_path, WEIGHTS_NAME, force_download=not use_cached_models
274
+ )
275
+ # Load PyTorch checkpoint in tf2 model:
276
+ tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
277
+
278
+ if compare_with_pt_model:
279
+ tfo = tf_model(tf_model.dummy_inputs, training=False) # build the network
280
+
281
+ weights_only_kwarg = {"weights_only": True}
282
+ state_dict = torch.load(
283
+ pytorch_checkpoint_path,
284
+ map_location="cpu",
285
+ **weights_only_kwarg,
286
+ )
287
+ pt_model = pt_model_class.from_pretrained(
288
+ pretrained_model_name_or_path=None, config=config, state_dict=state_dict
289
+ )
290
+
291
+ with torch.no_grad():
292
+ pto = pt_model(**pt_model.dummy_inputs)
293
+
294
+ np_pt = pto[0].numpy()
295
+ np_tf = tfo[0].numpy()
296
+ diff = np.amax(np.abs(np_pt - np_tf))
297
+ print(f"Max absolute difference between models outputs {diff}")
298
+ assert diff <= 2e-2, f"Error, model absolute difference is >2e-2: {diff}"
299
+
300
+ # Save pytorch-model
301
+ print(f"Save TensorFlow model to {tf_dump_path}")
302
+ tf_model.save_weights(tf_dump_path, save_format="h5")
303
+
304
+
305
+ def convert_all_pt_checkpoints_to_tf(
306
+ args_model_type,
307
+ tf_dump_path,
308
+ model_shortcut_names_or_path=None,
309
+ config_shortcut_names_or_path=None,
310
+ compare_with_pt_model=False,
311
+ use_cached_models=False,
312
+ remove_cached_files=False,
313
+ only_convert_finetuned_models=False,
314
+ ):
315
+ if args_model_type is None:
316
+ model_types = list(MODEL_CLASSES.keys())
317
+ else:
318
+ model_types = [args_model_type]
319
+
320
+ for j, model_type in enumerate(model_types, start=1):
321
+ print("=" * 100)
322
+ print(f" Converting model type {j}/{len(model_types)}: {model_type}")
323
+ print("=" * 100)
324
+ if model_type not in MODEL_CLASSES:
325
+ raise ValueError(f"Unrecognized model type {model_type}, should be one of {list(MODEL_CLASSES.keys())}.")
326
+
327
+ config_class, model_class, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
328
+
329
+ if model_shortcut_names_or_path is None:
330
+ model_shortcut_names_or_path = list(aws_model_maps.keys())
331
+ if config_shortcut_names_or_path is None:
332
+ config_shortcut_names_or_path = model_shortcut_names_or_path
333
+
334
+ for i, (model_shortcut_name, config_shortcut_name) in enumerate(
335
+ zip(model_shortcut_names_or_path, config_shortcut_names_or_path), start=1
336
+ ):
337
+ print("-" * 100)
338
+ if "-squad" in model_shortcut_name or "-mrpc" in model_shortcut_name or "-mnli" in model_shortcut_name:
339
+ if not only_convert_finetuned_models:
340
+ print(f" Skipping finetuned checkpoint {model_shortcut_name}")
341
+ continue
342
+ model_type = model_shortcut_name
343
+ elif only_convert_finetuned_models:
344
+ print(f" Skipping not finetuned checkpoint {model_shortcut_name}")
345
+ continue
346
+ print(
347
+ f" Converting checkpoint {i}/{len(aws_config_map)}: {model_shortcut_name} - model_type {model_type}"
348
+ )
349
+ print("-" * 100)
350
+
351
+ if config_shortcut_name in aws_config_map:
352
+ config_file = cached_file(config_shortcut_name, CONFIG_NAME, force_download=not use_cached_models)
353
+ else:
354
+ config_file = config_shortcut_name
355
+
356
+ if model_shortcut_name in aws_model_maps:
357
+ model_file = cached_file(model_shortcut_name, WEIGHTS_NAME, force_download=not use_cached_models)
358
+ else:
359
+ model_file = model_shortcut_name
360
+
361
+ if os.path.isfile(model_shortcut_name):
362
+ model_shortcut_name = "converted_model"
363
+
364
+ convert_pt_checkpoint_to_tf(
365
+ model_type=model_type,
366
+ pytorch_checkpoint_path=model_file,
367
+ config_file=config_file,
368
+ tf_dump_path=os.path.join(tf_dump_path, model_shortcut_name + "-tf_model.h5"),
369
+ compare_with_pt_model=compare_with_pt_model,
370
+ )
371
+ if remove_cached_files:
372
+ os.remove(config_file)
373
+ os.remove(model_file)
374
+
375
+
376
+ if __name__ == "__main__":
377
+ parser = argparse.ArgumentParser()
378
+ # Required parameters
379
+ parser.add_argument(
380
+ "--tf_dump_path", default=None, type=str, required=True, help="Path to the output Tensorflow dump file."
381
+ )
382
+ parser.add_argument(
383
+ "--model_type",
384
+ default=None,
385
+ type=str,
386
+ help=(
387
+ f"Model type selected in the list of {list(MODEL_CLASSES.keys())}. If not given, will download and "
388
+ "convert all the models from AWS."
389
+ ),
390
+ )
391
+ parser.add_argument(
392
+ "--pytorch_checkpoint_path",
393
+ default=None,
394
+ type=str,
395
+ help=(
396
+ "Path to the PyTorch checkpoint path or shortcut name to download from AWS. "
397
+ "If not given, will download and convert all the checkpoints from AWS."
398
+ ),
399
+ )
400
+ parser.add_argument(
401
+ "--config_file",
402
+ default=None,
403
+ type=str,
404
+ help=(
405
+ "The config json file corresponding to the pre-trained model. \n"
406
+ "This specifies the model architecture. If not given and "
407
+ "--pytorch_checkpoint_path is not given or is a shortcut name "
408
+ "use the configuration associated to the shortcut name on the AWS"
409
+ ),
410
+ )
411
+ parser.add_argument(
412
+ "--compare_with_pt_model", action="store_true", help="Compare Tensorflow and PyTorch model predictions."
413
+ )
414
+ parser.add_argument(
415
+ "--use_cached_models",
416
+ action="store_true",
417
+ help="Use cached models if possible instead of updating to latest checkpoint versions.",
418
+ )
419
+ parser.add_argument(
420
+ "--remove_cached_files",
421
+ action="store_true",
422
+ help="Remove pytorch models after conversion (save memory when converting in batches).",
423
+ )
424
+ parser.add_argument("--only_convert_finetuned_models", action="store_true", help="Only convert finetuned models.")
425
+ args = parser.parse_args()
426
+
427
+ # if args.pytorch_checkpoint_path is not None:
428
+ # convert_pt_checkpoint_to_tf(args.model_type.lower(),
429
+ # args.pytorch_checkpoint_path,
430
+ # args.config_file if args.config_file is not None else args.pytorch_checkpoint_path,
431
+ # args.tf_dump_path,
432
+ # compare_with_pt_model=args.compare_with_pt_model,
433
+ # use_cached_models=args.use_cached_models)
434
+ # else:
435
+ convert_all_pt_checkpoints_to_tf(
436
+ args.model_type.lower() if args.model_type is not None else None,
437
+ args.tf_dump_path,
438
+ model_shortcut_names_or_path=[args.pytorch_checkpoint_path]
439
+ if args.pytorch_checkpoint_path is not None
440
+ else None,
441
+ config_shortcut_names_or_path=[args.config_file] if args.config_file is not None else None,
442
+ compare_with_pt_model=args.compare_with_pt_model,
443
+ use_cached_models=args.use_cached_models,
444
+ remove_cached_files=args.remove_cached_files,
445
+ only_convert_finetuned_models=args.only_convert_finetuned_models,
446
+ )
.venv/lib/python3.11/site-packages/transformers/convert_slow_tokenizer.py ADDED
@@ -0,0 +1,1642 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Utilities to convert slow tokenizers in their fast tokenizers counterparts.
17
+
18
+ All the conversions are grouped here to gather SentencePiece dependencies outside of the fast tokenizers files and
19
+ allow to make our dependency on SentencePiece optional.
20
+ """
21
+
22
+ import warnings
23
+ from typing import Dict, List, Tuple
24
+
25
+ from packaging import version
26
+ from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
27
+ from tokenizers.models import BPE, Unigram, WordPiece
28
+
29
+ from .utils import is_protobuf_available, is_sentencepiece_available, logging, requires_backends
30
+ from .utils.import_utils import PROTOBUF_IMPORT_ERROR
31
+
32
+
33
+ logger = logging.get_logger(__name__)
34
+
35
+
36
+ def import_protobuf(error_message=""):
37
+ if is_sentencepiece_available():
38
+ from sentencepiece import sentencepiece_model_pb2
39
+
40
+ return sentencepiece_model_pb2
41
+ if is_protobuf_available():
42
+ import google.protobuf
43
+
44
+ if version.parse(google.protobuf.__version__) < version.parse("4.0.0"):
45
+ from transformers.utils import sentencepiece_model_pb2
46
+ else:
47
+ from transformers.utils import sentencepiece_model_pb2_new as sentencepiece_model_pb2
48
+ return sentencepiece_model_pb2
49
+ else:
50
+ raise ImportError(PROTOBUF_IMPORT_ERROR.format(error_message))
51
+
52
+
53
+ def _get_prepend_scheme(add_prefix_space: bool, original_tokenizer) -> str:
54
+ if add_prefix_space:
55
+ prepend_scheme = "always"
56
+ if not getattr(original_tokenizer, "legacy", True):
57
+ prepend_scheme = "first"
58
+ else:
59
+ prepend_scheme = "never"
60
+ return prepend_scheme
61
+
62
+
63
def generate_merges(vocab, vocab_scores):
    """Reconstruct BPE merge rules from a vocabulary.

    Every way of splitting a piece into two in-vocab halves is a candidate
    merge. When `vocab_scores` is provided, merges are ordered by piece score
    (descending); otherwise by the vocabulary ranks (ascending).
    """
    use_scores = vocab_scores is not None
    score_source = dict(vocab_scores) if use_scores else vocab

    candidates = []
    for token, token_score in score_source.items():
        # All two-way splits of `token` whose halves both exist in the vocab.
        splits = [
            (token[:cut], token[cut:], token_score)
            for cut in range(1, len(token))
            if token[:cut] in vocab and token[cut:] in vocab
        ]
        # Deterministic order among splits of the same token: by vocab ranks.
        splits.sort(key=lambda parts: (vocab[parts[0]], vocab[parts[1]]))
        candidates.extend(splits)

    candidates.sort(key=lambda merge: (merge[2], len(merge[0]), len(merge[1])), reverse=use_scores)
    return [(left, right) for left, right, _ in candidates]
80
+
81
+
82
class SentencePieceExtractor:
    """
    Extractor implementation for SentencePiece trained models. https://github.com/google/sentencepiece

    Loads a trained SentencePiece model file and exposes its piece->id
    vocabulary together with BPE merge rules reconstructed from it.
    """

    def __init__(self, model: str):
        # `model` is the filesystem path to a trained SentencePiece model.
        requires_backends(self, "sentencepiece")
        from sentencepiece import SentencePieceProcessor

        self.sp = SentencePieceProcessor()
        self.sp.Load(model)

    def extract(self, vocab_scores=None) -> Tuple[Dict[str, int], List[Tuple]]:
        """
        By default will return vocab and merges with respect to their order, by sending `vocab_scores` we're going to
        order the merges with respect to the piece scores instead.
        """
        sp = self.sp
        # piece -> id mapping taken directly from the SentencePiece model.
        vocab = {sp.id_to_piece(index): index for index in range(sp.GetPieceSize())}

        merges = generate_merges(vocab, vocab_scores)

        return vocab, merges
105
+
106
+
107
class GemmaSentencePieceExtractor(SentencePieceExtractor):
    """SentencePiece extractor for Gemma models.

    Gemma's SentencePiece model has no literal tab piece; the byte-fallback
    piece "<0x09>" stands in for `\t`, so it is aliased into the vocab before
    generating merges.
    """

    def extract(self, vocab_scores=None) -> Tuple[Dict[str, int], List[Tuple]]:
        """
        By default will return vocab and merges with respect to their order, by sending `vocab_scores` we're going to
        order the merges with respect to the piece scores instead.
        """
        sp = self.sp
        vocab = {sp.id_to_piece(index): index for index in range(sp.GetPieceSize())}

        # There is a missing token in the vocab: "<0x09>" is the byte-fallback
        # piece for `\t`. The previous `vocab.get("<0x09>")` silently mapped
        # "\t" to None when the piece was absent, which only blew up much later
        # inside merge generation — fail fast with a clear message instead.
        if "\t" not in vocab:
            if "<0x09>" not in vocab:
                raise ValueError(
                    "Vocab contains neither a literal tab piece nor the byte-fallback piece <0x09>; "
                    "cannot map `\\t` for merge generation."
                )
            vocab["\t"] = vocab["<0x09>"]

        merges = generate_merges(vocab, vocab_scores)
        return vocab, merges
122
+
123
+
124
def check_number_comma(piece: str) -> bool:
    """Return False only for pieces ending in a digit immediately followed by a comma (e.g. "5,")."""
    if len(piece) < 2:
        return True
    return piece[-1] != "," or not piece[-2].isdigit()
126
+
127
+
128
class Converter:
    """Base class for slow-to-fast tokenizer converters."""

    def __init__(self, original_tokenizer):
        # Keep a handle on the slow tokenizer being converted.
        self.original_tokenizer = original_tokenizer

    def converted(self) -> Tokenizer:
        """Build and return the equivalent fast `Tokenizer`; subclasses must override."""
        raise NotImplementedError()
134
+
135
+
136
class BertConverter(Converter):
    """Convert a slow BERT WordPiece tokenizer into a fast `tokenizers.Tokenizer`."""

    def converted(self) -> Tokenizer:
        source = self.original_tokenizer
        fast = Tokenizer(WordPiece(source.vocab, unk_token=str(source.unk_token)))

        # Defaults used when the slow tokenizer carries no basic_tokenizer.
        handle_chinese = strip_accents = lowercase = False
        basic = getattr(source, "basic_tokenizer", None)
        if basic is not None:
            handle_chinese = basic.tokenize_chinese_chars
            strip_accents = basic.strip_accents
            lowercase = basic.do_lower_case

        fast.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=handle_chinese,
            strip_accents=strip_accents,
            lowercase=lowercase,
        )
        fast.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

        cls_token = str(source.cls_token)
        sep_token = str(source.sep_token)
        fast.post_processor = processors.TemplateProcessing(
            single=f"{cls_token}:0 $A:0 {sep_token}:0",
            pair=f"{cls_token}:0 $A:0 {sep_token}:0 $B:1 {sep_token}:1",
            special_tokens=[
                (cls_token, source.cls_token_id),
                (sep_token, source.sep_token_id),
            ],
        )
        fast.decoder = decoders.WordPiece(prefix="##")

        return fast
173
+
174
+
175
class SplinterConverter(Converter):
    """Convert a slow Splinter WordPiece tokenizer into a fast `tokenizers.Tokenizer`.

    Splinter inserts its question token and a literal "." into the pair
    template; which side they go on depends on the tokenizer's padding side.
    """

    def converted(self) -> Tokenizer:
        vocab = self.original_tokenizer.vocab
        tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))

        # Defaults used when the slow tokenizer carries no basic_tokenizer.
        tokenize_chinese_chars = False
        strip_accents = False
        do_lower_case = False
        if hasattr(self.original_tokenizer, "basic_tokenizer"):
            tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
            strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
            do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case

        tokenizer.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            lowercase=do_lower_case,
        )
        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

        cls = str(self.original_tokenizer.cls_token)
        sep = str(self.original_tokenizer.sep_token)
        question = str(self.original_tokenizer.question_token)
        dot = "."
        cls_token_id = self.original_tokenizer.cls_token_id
        sep_token_id = self.original_tokenizer.sep_token_id
        question_token_id = self.original_tokenizer.question_token_id
        dot_token_id = self.original_tokenizer.convert_tokens_to_ids(".")

        # Question token + "." are attached to the question side of the pair.
        if self.original_tokenizer.padding_side == "right":
            pair = f"{cls}:0 $A:0 {question} {dot} {sep}:0 $B:1 {sep}:1"
        else:
            pair = f"{cls}:0 $A:0 {sep}:0 $B:1 {question} {dot} {sep}:1"

        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls}:0 $A:0 {sep}:0",
            pair=pair,
            special_tokens=[
                (cls, cls_token_id),
                (sep, sep_token_id),
                (question, question_token_id),
                (dot, dot_token_id),
            ],
        )
        tokenizer.decoder = decoders.WordPiece(prefix="##")

        return tokenizer
223
+
224
+
225
class FunnelConverter(Converter):
    """Convert a slow Funnel WordPiece tokenizer; Funnel assigns [CLS] token_type_id 2."""

    def converted(self) -> Tokenizer:
        source = self.original_tokenizer
        fast = Tokenizer(WordPiece(source.vocab, unk_token=str(source.unk_token)))

        # Defaults used when the slow tokenizer carries no basic_tokenizer.
        handle_chinese = strip_accents = lowercase = False
        basic = getattr(source, "basic_tokenizer", None)
        if basic is not None:
            handle_chinese = basic.tokenize_chinese_chars
            strip_accents = basic.strip_accents
            lowercase = basic.do_lower_case

        fast.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=handle_chinese,
            strip_accents=strip_accents,
            lowercase=lowercase,
        )
        fast.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

        cls_token = str(source.cls_token)
        sep_token = str(source.sep_token)
        fast.post_processor = processors.TemplateProcessing(
            # token_type_id is 2 for Funnel transformer
            single=f"{cls_token}:2 $A:0 {sep_token}:0",
            pair=f"{cls_token}:2 $A:0 {sep_token}:0 $B:1 {sep_token}:1",
            special_tokens=[
                (cls_token, source.cls_token_id),
                (sep_token, source.sep_token_id),
            ],
        )
        fast.decoder = decoders.WordPiece(prefix="##")

        return fast
262
+
263
+
264
class MPNetConverter(Converter):
    """Convert a slow MPNet WordPiece tokenizer; MPNet separates pairs with two [SEP] tokens."""

    def converted(self) -> Tokenizer:
        source = self.original_tokenizer
        fast = Tokenizer(WordPiece(source.vocab, unk_token=str(source.unk_token)))

        # Defaults used when the slow tokenizer carries no basic_tokenizer.
        handle_chinese = strip_accents = lowercase = False
        basic = getattr(source, "basic_tokenizer", None)
        if basic is not None:
            handle_chinese = basic.tokenize_chinese_chars
            strip_accents = basic.strip_accents
            lowercase = basic.do_lower_case

        fast.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=handle_chinese,
            strip_accents=strip_accents,
            lowercase=lowercase,
        )
        fast.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

        cls_token = str(source.cls_token)
        sep_token = str(source.sep_token)
        fast.post_processor = processors.TemplateProcessing(
            single=f"{cls_token}:0 $A:0 {sep_token}:0",
            pair=f"{cls_token}:0 $A:0 {sep_token}:0 {sep_token}:0 $B:1 {sep_token}:1",  # MPNet uses two [SEP] tokens
            special_tokens=[
                (cls_token, source.cls_token_id),
                (sep_token, source.sep_token_id),
            ],
        )
        fast.decoder = decoders.WordPiece(prefix="##")

        return fast
301
+
302
+
303
class OpenAIGPTConverter(Converter):
    """Convert the slow OpenAI GPT BPE tokenizer (lowercasing, `</w>` end-of-word suffix)."""

    def converted(self) -> Tokenizer:
        source = self.original_tokenizer
        unk = str(source.unk_token)

        fast = Tokenizer(
            BPE(
                vocab=source.encoder,
                merges=list(source.bpe_ranks.keys()),
                dropout=None,
                unk_token=unk,
                end_of_word_suffix="</w>",
                fuse_unk=False,
            )
        )

        # Register the unk token as special when it exists in the vocab.
        if fast.token_to_id(unk) is not None:
            fast.add_special_tokens([unk])

        fast.normalizer = normalizers.BertNormalizer(lowercase=True)
        fast.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
        fast.decoder = decoders.BPEDecoder(suffix="</w>")

        return fast
328
+
329
+
330
class GPT2Converter(Converter):
    """Convert a slow GPT-2 byte-level BPE tokenizer into a fast `tokenizers.Tokenizer`.

    `vocab`/`merges` may be supplied explicitly so other converters can reuse
    this logic; otherwise they are read from the slow tokenizer.
    """

    def converted(self, vocab: Dict[str, int] = None, merges: List[Tuple[str, str]] = None) -> Tokenizer:
        if not vocab:
            vocab = self.original_tokenizer.encoder
        if not merges:
            merges = list(self.original_tokenizer.bpe_ranks)

        tokenizer = Tokenizer(
            BPE(
                vocab=vocab,
                merges=merges,
                dropout=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
            )
        )

        add_prefix_space = getattr(self.original_tokenizer, "add_prefix_space", False)
        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=add_prefix_space)
        tokenizer.decoder = decoders.ByteLevel()
        # Some GPT-2 derivatives prepend a BOS token to every sequence.
        if getattr(self.original_tokenizer, "add_bos_token", False):
            bos = self.original_tokenizer.bos_token
            bos_token_id = self.original_tokenizer.bos_token_id
            tokenizer.post_processor = processors.TemplateProcessing(
                single=f"{bos}:0 $A:0",
                pair=f"{bos}:0 $A:0 $B:1",
                special_tokens=[
                    (bos, bos_token_id),
                ],
            )
        else:
            # XXX trim_offsets=False actually means this post_processor doesn't
            # really do anything.
            tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
        return tokenizer
366
+
367
+
368
class HerbertConverter(Converter):
    """Convert a slow HerBERT BPE tokenizer into a fast `tokenizers.Tokenizer`."""

    def converted(self) -> Tokenizer:
        tokenizer_info_str = "#version:"
        token_suffix = "</w>"

        vocab = self.original_tokenizer.encoder
        merges = list(self.original_tokenizer.bpe_ranks.keys())
        # The first merge entry can be a "#version:" header line — drop it.
        if tokenizer_info_str in merges[0][0]:
            merges = merges[1:]

        tokenizer = Tokenizer(
            BPE(
                vocab,
                merges,
                dropout=None,
                unk_token=self.original_tokenizer.unk_token,
                end_of_word_suffix=token_suffix,
            )
        )

        tokenizer.normalizer = normalizers.BertNormalizer(lowercase=False, strip_accents=False)
        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
        tokenizer.decoder = decoders.BPEDecoder(suffix=token_suffix)
        tokenizer.post_processor = processors.BertProcessing(
            sep=(self.original_tokenizer.sep_token, self.original_tokenizer.sep_token_id),
            cls=(self.original_tokenizer.cls_token, self.original_tokenizer.cls_token_id),
        )

        return tokenizer
397
+
398
+
399
class Qwen2Converter(Converter):
    """Convert a slow Qwen2 byte-level BPE tokenizer into a fast `tokenizers.Tokenizer`.

    `vocab`/`merges` may be supplied directly for reuse by compatible
    tokenizers; otherwise they come from the slow tokenizer.
    """

    def converted(self, vocab: Dict[str, int] = None, merges: List[Tuple[str, str]] = None) -> Tokenizer:
        if not vocab:
            vocab = self.original_tokenizer.encoder
        if not merges:
            merges = list(self.original_tokenizer.bpe_ranks.keys())

        tokenizer = Tokenizer(
            BPE(
                vocab=vocab,
                merges=merges,
                dropout=None,
                unk_token=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
                byte_fallback=False,
            )
        )

        # Unicode-normalize input before pre-tokenization.
        tokenizer.normalizer = normalizers.NFC()

        tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
            [
                # GPT-2-style split pattern with case-insensitive contractions.
                pre_tokenizers.Split(
                    Regex(
                        r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
                    ),
                    behavior="isolated",
                    invert=False,
                ),
                # Splitting is already done above, so disable ByteLevel's own regex.
                pre_tokenizers.ByteLevel(
                    add_prefix_space=getattr(self.original_tokenizer, "add_prefix_space", False),
                    use_regex=False,
                ),
            ]
        )

        tokenizer.decoder = decoders.ByteLevel()
        tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)

        return tokenizer
441
+
442
+
443
class RobertaConverter(Converter):
    """Convert a slow RoBERTa byte-level BPE tokenizer into a fast `tokenizers.Tokenizer`."""

    def converted(self) -> Tokenizer:
        source = self.original_tokenizer

        fast = Tokenizer(
            BPE(
                vocab=source.encoder,
                merges=list(source.bpe_ranks.keys()),
                dropout=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
            )
        )

        fast.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=source.add_prefix_space)
        fast.decoder = decoders.ByteLevel()
        fast.post_processor = processors.RobertaProcessing(
            sep=(source.sep_token, source.sep_token_id),
            cls=(source.cls_token, source.cls_token_id),
            add_prefix_space=source.add_prefix_space,
            trim_offsets=True,  # True by default on Roberta (historical)
        )

        return fast
470
+
471
+
472
class RoFormerConverter(Converter):
    """Convert a slow RoFormer WordPiece tokenizer into a fast `tokenizers.Tokenizer`.

    RoFormer uses a custom Jieba-based pre-tokenizer, so the BERT normalizer's
    Chinese-character handling is disabled here.
    """

    def converted(self) -> Tokenizer:
        from .models.roformer.tokenization_utils import JiebaPreTokenizer

        vocab = self.original_tokenizer.vocab
        tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))

        # Defaults used when the slow tokenizer carries no basic_tokenizer.
        strip_accents = False
        do_lower_case = False
        if hasattr(self.original_tokenizer, "basic_tokenizer"):
            strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
            do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case

        tokenizer.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=False,  # Jieba performs Chinese segmentation instead
            strip_accents=strip_accents,
            lowercase=do_lower_case,
        )
        tokenizer.pre_tokenizer = pre_tokenizers.PreTokenizer.custom(JiebaPreTokenizer(vocab))

        cls = str(self.original_tokenizer.cls_token)
        sep = str(self.original_tokenizer.sep_token)
        cls_token_id = self.original_tokenizer.cls_token_id
        sep_token_id = self.original_tokenizer.sep_token_id

        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls}:0 $A:0 {sep}:0",
            pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
            special_tokens=[
                (cls, cls_token_id),
                (sep, sep_token_id),
            ],
        )
        tokenizer.decoder = decoders.WordPiece(prefix="##")

        return tokenizer
509
+
510
+
511
class DebertaConverter(Converter):
    """Convert a slow DeBERTa byte-level BPE tokenizer into a fast `tokenizers.Tokenizer`."""

    def converted(self) -> Tokenizer:
        source = self.original_tokenizer

        fast = Tokenizer(
            BPE(
                vocab=source.encoder,
                merges=list(source.bpe_ranks.keys()),
                dropout=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
            )
        )

        fast.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=source.add_prefix_space)
        fast.decoder = decoders.ByteLevel()
        fast.post_processor = processors.TemplateProcessing(
            single="[CLS]:0 $A:0 [SEP]:0",
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            special_tokens=[
                ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
                ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
            ],
        )

        return fast
540
+
541
+
542
class SpmConverter(Converter):
    """Base converter for SentencePiece-backed slow tokenizers.

    Parses the SentencePiece protobuf model file and rebuilds an equivalent
    fast tokenizer (Unigram or BPE) together with its normalizer,
    pre-tokenizer, decoder and optional post-processor.
    """

    # Whether the fast model replicates SentencePiece's byte fallback.
    handle_byte_fallback = False
    # Extractor class used to rebuild BPE merges from the .model file.
    SpmExtractor = SentencePieceExtractor
    # Pieces that must be registered as special tokens even when typed as user-defined.
    special_tokens = {}

    def __init__(self, *args):
        requires_backends(self, "protobuf")

        super().__init__(*args)

        # from .utils import sentencepiece_model_pb2 as model_pb2
        model_pb2 = import_protobuf()

        # Parse the serialized SentencePiece model into `self.proto`.
        m = model_pb2.ModelProto()
        with open(self.original_tokenizer.vocab_file, "rb") as f:
            m.ParseFromString(f.read())
        self.proto = m

        if self.proto.trainer_spec.byte_fallback and not self.handle_byte_fallback:
            warnings.warn(
                "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
                " which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
                " tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
                "unknown tokens into a sequence of byte tokens matching the original piece of text."
            )

    def vocab(self, proto):
        # (piece, score) pairs in model order; subclasses may rescore or extend.
        return [(piece.piece, piece.score) for piece in proto.pieces]

    def unk_id(self, proto):
        # Unknown-token id as recorded by the SentencePiece trainer; subclasses may override.
        return proto.trainer_spec.unk_id

    def tokenizer(self, proto):
        """Build the core model (Unigram for model_type 1, BPE for 2) and register added tokens."""
        model_type = proto.trainer_spec.model_type
        vocab_scores = self.vocab(proto)

        if model_type == 1:
            tokenizer = Tokenizer(
                Unigram(
                    vocab_scores,
                    unk_id=self.unk_id(proto),
                    byte_fallback=self.handle_byte_fallback,
                )
            )

        elif model_type == 2:
            # Rebuild merges from the model file, then map pieces to ranks.
            _, merges = self.SpmExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
            bpe_vocab = {word: i for i, (word, score) in enumerate(vocab_scores)}
            tokenizer = Tokenizer(
                BPE(
                    bpe_vocab,
                    merges,
                    unk_token=proto.trainer_spec.unk_piece,
                    fuse_unk=True,
                    byte_fallback=self.handle_byte_fallback,
                    dropout=None,
                )
            )

        else:
            raise Exception(
                "You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
            )

        # control tokens are special
        # user defined symbols are not
        # both user and control tokens are AddedTokens
        # Add user defined symbols (type == 4) from sentencepiece (https://github.com/google/sentencepiece/blob/6225e08edb2577757163b3f5dbba4c0b670ef445/src/sentencepiece_model.proto#L299C29-L299C33)
        spm_added_tokens = [
            (id, p.piece, p.type == 3 or p.piece in self.special_tokens)
            for id, p in enumerate(proto.pieces)
            if p.type in [3, 4]
        ]
        tokenizer.add_tokens(
            [
                AddedToken(token, normalized=False, special=special)
                for id, token, special in sorted(spm_added_tokens, key=lambda x: x[0])
            ]
        )

        return tokenizer

    def normalizer(self, proto):
        # Apply the model's precompiled charsmap (when present), then strip and
        # collapse runs of spaces into the metaspace marker.
        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        _normalizers = [
            normalizers.Strip(left=False, right=True),  # stripping is important
            normalizers.Replace(Regex(" {2,}"), "▁"),
        ]
        if not precompiled_charsmap:
            return normalizers.Sequence(_normalizers)
        else:
            return normalizers.Sequence([normalizers.Precompiled(precompiled_charsmap)] + _normalizers)

    def pre_tokenizer(self, replacement, add_prefix_space):
        prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
        return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)

    def post_processor(self):
        # No post-processing by default; subclasses return template processors.
        return None

    def decoder(self, replacement, add_prefix_space):
        prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
        return decoders.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)

    def converted(self) -> Tokenizer:
        """Assemble the complete fast tokenizer from the parsed proto."""
        tokenizer = self.tokenizer(self.proto)

        # Tokenizer assemble
        normalizer = self.normalizer(self.proto)
        if normalizer is not None:
            tokenizer.normalizer = normalizer

        replacement = "▁"
        add_prefix_space = True
        if hasattr(self.original_tokenizer, "add_prefix_space"):
            add_prefix_space = self.original_tokenizer.add_prefix_space

        pre_tokenizer = self.pre_tokenizer(replacement, add_prefix_space)
        if pre_tokenizer is not None:
            tokenizer.pre_tokenizer = pre_tokenizer

        tokenizer.decoder = self.decoder(replacement, add_prefix_space)
        post_processor = self.post_processor()
        if post_processor:
            tokenizer.post_processor = post_processor

        return tokenizer
669
+
670
+
671
class AlbertConverter(SpmConverter):
    """SentencePiece converter for ALBERT (quote normalization, optional lowercasing/accent stripping)."""

    def vocab(self, proto):
        # Penalize digit+comma pieces (e.g. "5,") so the fast tokenizer never prefers them.
        return [
            (piece.piece, piece.score) if check_number_comma(piece.piece) else (piece.piece, piece.score - 100)
            for piece in proto.pieces
        ]

    def normalizer(self, proto):
        steps = [
            normalizers.Replace("``", '"'),
            normalizers.Replace("''", '"'),
        ]
        if not self.original_tokenizer.keep_accents:
            steps += [normalizers.NFKD(), normalizers.StripAccents()]
        if self.original_tokenizer.do_lower_case:
            steps.append(normalizers.Lowercase())

        charsmap = proto.normalizer_spec.precompiled_charsmap
        if charsmap:
            steps.append(normalizers.Precompiled(charsmap))

        steps.append(normalizers.Replace(Regex(" {2,}"), " "))
        return normalizers.Sequence(steps)

    def post_processor(self):
        return processors.TemplateProcessing(
            single="[CLS]:0 $A:0 [SEP]:0",
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            special_tokens=[
                ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
                ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
            ],
        )
706
+
707
+
708
class BarthezConverter(SpmConverter):
    """Converter for BARThez SentencePiece tokenizers."""

    def unk_id(self, proto):
        # <unk> sits at a fixed position in the BARThez vocabulary.
        return 3

    def post_processor(self):
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="<s> $A </s>",
            pair="<s> $A </s> </s> $B </s>",
            special_tokens=[("<s>", to_id("<s>")), ("</s>", to_id("</s>"))],
        )
722
+
723
+
724
class CamembertConverter(SpmConverter):
    """Converter for CamemBERT SentencePiece tokenizers."""

    def vocab(self, proto):
        vocab = [
            ("<s>NOTUSED", 0.0),
            ("<pad>", 0.0),
            ("</s>NOTUSED", 0.0),
            ("<unk>", 0.0),
            # We down-grade the original SentencePiece by -100 to avoid using it and use our added token instead
            ("<unk>NOTUSED", -100),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[1:]]
        vocab.append(("<mask>", 0.0))
        return vocab

    def unk_id(self, proto):
        # See vocab: <unk> is at index 3.
        return 3

    def post_processor(self):
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="<s> $A </s>",
            pair="<s> $A </s> </s> $B </s>",
            special_tokens=[("<s>", to_id("<s>")), ("</s>", to_id("</s>"))],
        )
751
+
752
+
753
class DebertaV2Converter(SpmConverter):
    """Converter for DeBERTa-v2 SentencePiece tokenizers."""

    def pre_tokenizer(self, replacement, add_prefix_space):
        steps = []
        if self.original_tokenizer.split_by_punct:
            steps.append(pre_tokenizers.Punctuation(behavior="isolated"))
        scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
        steps.append(pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=scheme))
        return pre_tokenizers.Sequence(steps)

    def normalizer(self, proto):
        steps = []
        if self.original_tokenizer.do_lower_case:
            steps.append(normalizers.Lowercase())
        steps.append(normalizers.Strip())

        charsmap = proto.normalizer_spec.precompiled_charsmap
        if charsmap:
            steps.append(normalizers.Precompiled(charsmap))
        steps.append(normalizers.Replace(Regex(" {2,}"), " "))

        return normalizers.Sequence(steps)

    def post_processor(self):
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="[CLS]:0 $A:0 [SEP]:0",
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            special_tokens=[("[CLS]", to_id("[CLS]")), ("[SEP]", to_id("[SEP]"))],
        )
784
+
785
+
786
class MBartConverter(SpmConverter):
    """Converter for mBART SentencePiece tokenizers (25 language codes + <mask>)."""

    # Language codes appended after the SentencePiece vocabulary, in order.
    _LANG_CODES = [
        "ar_AR", "cs_CZ", "de_DE", "en_XX", "es_XX", "et_EE", "fi_FI", "fr_XX",
        "gu_IN", "hi_IN", "it_IT", "ja_XX", "kk_KZ", "ko_KR", "lt_LT", "lv_LV",
        "my_MM", "ne_NP", "nl_XX", "ro_RO", "ru_RU", "si_LK", "tr_TR", "vi_VN",
        "zh_CN",
    ]

    def vocab(self, proto):
        vocab = [
            ("<s>", 0.0),
            ("<pad>", 0.0),
            ("</s>", 0.0),
            ("<unk>", 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        vocab += [(code, 0.0) for code in self._LANG_CODES]
        vocab.append(("<mask>", 0.0))
        return vocab

    def unk_id(self, proto):
        return 3

    def post_processor(self):
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="$A </s> en_XX",
            pair="$A $B </s> en_XX",
            special_tokens=[("en_XX", to_id("en_XX")), ("</s>", to_id("</s>"))],
        )
837
+
838
+
839
class MBart50Converter(SpmConverter):
    """Converter for mBART-50 SentencePiece tokenizers (52 language codes + <mask>)."""

    # Language codes appended after the SentencePiece vocabulary, in order.
    _LANG_CODES = [
        "ar_AR", "cs_CZ", "de_DE", "en_XX", "es_XX", "et_EE", "fi_FI", "fr_XX",
        "gu_IN", "hi_IN", "it_IT", "ja_XX", "kk_KZ", "ko_KR", "lt_LT", "lv_LV",
        "my_MM", "ne_NP", "nl_XX", "ro_RO", "ru_RU", "si_LK", "tr_TR", "vi_VN",
        "zh_CN", "af_ZA", "az_AZ", "bn_IN", "fa_IR", "he_IL", "hr_HR", "id_ID",
        "ka_GE", "km_KH", "mk_MK", "ml_IN", "mn_MN", "mr_IN", "pl_PL", "ps_AF",
        "pt_XX", "sv_SE", "sw_KE", "ta_IN", "te_IN", "th_TH", "tl_XX", "uk_UA",
        "ur_PK", "xh_ZA", "gl_ES", "sl_SI",
    ]

    def vocab(self, proto):
        vocab = [
            ("<s>", 0.0),
            ("<pad>", 0.0),
            ("</s>", 0.0),
            ("<unk>", 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        vocab += [(code, 0.0) for code in self._LANG_CODES]
        vocab.append(("<mask>", 0.0))
        return vocab

    def unk_id(self, proto):
        return 3

    def post_processor(self):
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="en_XX $A </s>",
            pair="en_XX $A $B </s>",
            special_tokens=[("en_XX", to_id("en_XX")), ("</s>", to_id("</s>"))],
        )
864
+
865
+
866
class NllbConverter(SpmConverter):
    """Converter for NLLB SentencePiece tokenizers."""

    def vocab(self, proto):
        vocab = [
            ("<s>", 0.0),
            ("<pad>", 0.0),
            ("</s>", 0.0),
            ("<unk>", 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        return vocab

    def unk_id(self, proto):
        return 3

    def post_processor(self):
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="eng_Latn $A </s>",
            pair="eng_Latn $A $B </s>",
            special_tokens=[("eng_Latn", to_id("eng_Latn")), ("</s>", to_id("</s>"))],
        )
889
+
890
+
891
class SeamlessM4TConverter(SpmConverter):
    """Converter for SeamlessM4T SentencePiece tokenizers."""

    def vocab(self, proto):
        vocab = [
            ("<pad>", 0.0),
            ("<unk>", 0.0),
            ("<s>", 0.0),
            ("</s>", 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        return vocab

    def unk_id(self, proto):
        # Delegate to the slow tokenizer rather than hard-coding an index.
        return self.original_tokenizer.unk_token_id

    def post_processor(self):
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="__eng__ $A </s>",
            pair="__eng__ $A $B </s>",
            special_tokens=[("__eng__", to_id("__eng__")), ("</s>", to_id("</s>"))],
        )
914
+
915
+
916
class XLMRobertaConverter(SpmConverter):
    """Converter for XLM-RoBERTa SentencePiece tokenizers."""

    def vocab(self, proto):
        vocab = [
            ("<s>", 0.0),
            ("<pad>", 0.0),
            ("</s>", 0.0),
            ("<unk>", 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        vocab.append(("<mask>", 0.0))
        return vocab

    def unk_id(self, proto):
        return 3

    def post_processor(self):
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="<s> $A </s>",
            pair="<s> $A </s> </s> $B </s>",
            special_tokens=[("<s>", to_id("<s>")), ("</s>", to_id("</s>"))],
        )
941
+
942
+
943
class XLNetConverter(SpmConverter):
    """Converter for XLNet SentencePiece tokenizers."""

    def vocab(self, proto):
        # Down-rank pieces mixing digits and commas so they are never preferred.
        return [
            (p.piece, p.score if check_number_comma(p.piece) else p.score - 100)
            for p in proto.pieces
        ]

    def normalizer(self, proto):
        seq = [
            normalizers.Replace("``", '"'),
            normalizers.Replace("''", '"'),
        ]
        if not self.original_tokenizer.keep_accents:
            seq += [normalizers.NFKD(), normalizers.StripAccents()]
        if self.original_tokenizer.do_lower_case:
            seq.append(normalizers.Lowercase())

        charsmap = proto.normalizer_spec.precompiled_charsmap
        if charsmap:
            seq.append(normalizers.Precompiled(charsmap))

        seq.append(normalizers.Replace(Regex(" {2,}"), " "))
        return normalizers.Sequence(seq)

    def post_processor(self):
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="$A:0 <sep>:0 <cls>:2",
            pair="$A:0 <sep>:0 $B:1 <sep>:1 <cls>:2",
            special_tokens=[("<sep>", to_id("<sep>")), ("<cls>", to_id("<cls>"))],
        )
978
+
979
+
980
class ReformerConverter(SpmConverter):
    """Reformer uses the default SentencePiece conversion unchanged."""
982
+
983
+
984
class RemBertConverter(SpmConverter):
    """Converter for RemBERT SentencePiece tokenizers (inspired from AlbertConverter)."""

    def normalizer(self, proto):
        seq = [
            normalizers.Replace("``", '"'),
            normalizers.Replace("''", '"'),
            normalizers.Replace(Regex(" {2,}"), " "),
        ]
        if not self.original_tokenizer.keep_accents:
            seq += [normalizers.NFKD(), normalizers.StripAccents()]
        if self.original_tokenizer.do_lower_case:
            seq.append(normalizers.Lowercase())

        charsmap = proto.normalizer_spec.precompiled_charsmap
        if charsmap:
            seq.append(normalizers.Precompiled(charsmap))

        return normalizers.Sequence(seq)

    def post_processor(self):
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="[CLS]:0 $A:0 [SEP]:0",
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            special_tokens=[("[CLS]", to_id("[CLS]")), ("[SEP]", to_id("[SEP]"))],
        )
1014
+
1015
+
1016
class BertGenerationConverter(SpmConverter):
    """BertGeneration uses the default SentencePiece conversion unchanged."""
1018
+
1019
+
1020
class PegasusConverter(SpmConverter):
    """Converter for Pegasus SentencePiece tokenizers (offset-based special tokens)."""

    def vocab(self, proto):
        ot = self.original_tokenizer
        vocab = [
            (ot.pad_token, 0.0),
            (ot.eos_token, 0.0),
        ]

        if ot.mask_token_sent is not None:
            vocab.append((ot.mask_token_sent, 0.0))

        if ot.mask_token is not None and ot.mask_token_id < ot.offset:
            vocab.append((ot.mask_token, 0.0))

        # Filler <unk_i> tokens pad out the reserved offset region with a very low score.
        vocab += [(f"<unk_{i}>", -100.0) for i in range(2, ot.offset)]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[2:]]
        return vocab

    def unk_id(self, proto):
        # The SentencePiece unk id is shifted by the reserved-token offset.
        return proto.trainer_spec.unk_id + self.original_tokenizer.offset

    def pre_tokenizer(self, replacement, add_prefix_space):
        scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
        return pre_tokenizers.Sequence(
            [
                pre_tokenizers.WhitespaceSplit(),
                pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=scheme),
            ]
        )

    def post_processor(self):
        eos = self.original_tokenizer.eos_token
        return processors.TemplateProcessing(
            single=["$A", eos],
            pair=["$A", "$B", eos],
            special_tokens=[(eos, self.original_tokenizer.eos_token_id)],
        )
1058
+
1059
+
1060
class T5Converter(SpmConverter):
    """Converter for T5 SentencePiece tokenizers (appends sentinel <extra_id_*> tokens)."""

    def vocab(self, proto):
        num_extra_ids = self.original_tokenizer._extra_ids
        vocab = [(piece.piece, piece.score) for piece in proto.pieces]
        # Sentinels are appended in reverse so <extra_id_0> gets the highest id.
        vocab += [(f"<extra_id_{i}>", 0.0) for i in range(num_extra_ids - 1, -1, -1)]
        return vocab

    def post_processor(self):
        eos_id = self.original_tokenizer.convert_tokens_to_ids("</s>")
        return processors.TemplateProcessing(
            single=["$A", "</s>"],
            pair=["$A", "</s>", "$B", "</s>"],
            special_tokens=[("</s>", eos_id)],
        )
1075
+
1076
+
1077
class UdopConverter(SpmConverter):
    """Converter for UDOP SentencePiece tokenizers."""

    def post_processor(self):
        eos_id = self.original_tokenizer.convert_tokens_to_ids("</s>")
        return processors.TemplateProcessing(
            single=["$A", "</s>"],
            pair=["$A", "</s>", "$B", "</s>"],
            special_tokens=[("</s>", eos_id)],
        )
1086
+
1087
+
1088
class WhisperConverter(Converter):
    """Converter for Whisper byte-level BPE tokenizers."""

    def converted(self) -> Tokenizer:
        ot = self.original_tokenizer
        tokenizer = Tokenizer(
            BPE(
                vocab=ot.encoder,
                merges=list(ot.bpe_ranks.keys()),
                dropout=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
            )
        )

        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
        tokenizer.decoder = decoders.ByteLevel()

        # Whisper prepends task/language prefix tokens before the actual text.
        prefix_token_ids = ot.prefix_tokens
        prefixes = ot.convert_ids_to_tokens(prefix_token_ids)
        eos = ot.eos_token
        prefix_template = " ".join(f"{token}:0" for token in prefixes)
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{prefix_template} $A:0 {eos}:0",
            pair=f"{prefix_template} $A:0 $B:1 {eos}:1",
            special_tokens=[
                (eos, ot.eos_token_id),
                *zip(prefixes, prefix_token_ids),
            ],
        )

        return tokenizer
1122
+
1123
+
1124
class BigBirdConverter(SpmConverter):
    """Converter for BigBird SentencePiece tokenizers."""

    def post_processor(self):
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="[CLS]:0 $A:0 [SEP]:0",
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            special_tokens=[("[CLS]", to_id("[CLS]")), ("[SEP]", to_id("[SEP]"))],
        )
1134
+
1135
+
1136
class CLIPConverter(Converter):
    """Converter for CLIP BPE tokenizers."""

    def converted(self) -> Tokenizer:
        ot = self.original_tokenizer
        tokenizer = Tokenizer(
            BPE(
                vocab=ot.encoder,
                merges=list(ot.bpe_ranks.keys()),
                dropout=None,
                continuing_subword_prefix="",
                end_of_word_suffix="</w>",
                fuse_unk=False,
                unk_token=str(ot.unk_token),
            )
        )

        tokenizer.normalizer = normalizers.Sequence(
            [normalizers.NFC(), normalizers.Replace(Regex(r"\s+"), " "), normalizers.Lowercase()]
        )
        tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
            [
                pre_tokenizers.Split(
                    Regex(r"""'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+"""),
                    behavior="removed",
                    invert=True,
                ),
                pre_tokenizers.ByteLevel(add_prefix_space=False),
            ]
        )
        tokenizer.decoder = decoders.ByteLevel()

        # Hack to have a ByteLevel and TemplaceProcessor
        tokenizer.post_processor = processors.RobertaProcessing(
            sep=(ot.eos_token, ot.eos_token_id),
            cls=(ot.bos_token, ot.bos_token_id),
            add_prefix_space=False,
            trim_offsets=False,
        )
        return tokenizer
1177
+
1178
+
1179
class LayoutLMv2Converter(Converter):
    """Converter for LayoutLMv2 WordPiece tokenizers (BERT-style)."""

    def converted(self) -> Tokenizer:
        ot = self.original_tokenizer
        tokenizer = Tokenizer(WordPiece(ot.vocab, unk_token=str(ot.unk_token)))

        # Defaults used when no basic tokenizer is attached.
        tokenize_chinese_chars = False
        strip_accents = False
        do_lower_case = True
        basic = getattr(ot, "basic_tokenizer", None)
        if basic is not None:
            tokenize_chinese_chars = basic.tokenize_chinese_chars
            strip_accents = basic.strip_accents
            do_lower_case = basic.do_lower_case

        tokenizer.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            lowercase=do_lower_case,
        )
        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

        cls = str(ot.cls_token)
        sep = str(ot.sep_token)
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls}:0 $A:0 {sep}:0",
            pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
            special_tokens=[
                (cls, ot.cls_token_id),
                (sep, ot.sep_token_id),
            ],
        )
        tokenizer.decoder = decoders.WordPiece(prefix="##")

        return tokenizer
1216
+
1217
+
1218
class BlenderbotConverter(Converter):
    """Converter for Blenderbot byte-level BPE tokenizers."""

    def converted(self) -> Tokenizer:
        ot = self.original_tokenizer
        tokenizer = Tokenizer(
            BPE(
                vocab=ot.encoder,
                merges=list(ot.bpe_ranks.keys()),
                dropout=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
            )
        )

        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
        tokenizer.decoder = decoders.ByteLevel()
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"$A:0 {ot.eos_token}:0",
            special_tokens=[(ot.eos_token, ot.eos_token_id)],
        )

        return tokenizer
1245
+
1246
+
1247
class XGLMConverter(SpmConverter):
    """Converter for XGLM SentencePiece tokenizers."""

    def vocab(self, proto):
        vocab = [
            ("<s>", 0.0),
            ("<pad>", 0.0),
            ("</s>", 0.0),
            ("<unk>", 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        # XGLM reserves seven placeholder tokens at the end of the vocabulary.
        vocab += [(f"<madeupword{i}>", 0.0) for i in range(7)]
        return vocab

    def unk_id(self, proto):
        return 3

    def post_processor(self):
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="</s> $A",
            pair="</s> $A </s> </s> $B",
            special_tokens=[("<s>", to_id("<s>")), ("</s>", to_id("</s>"))],
        )
1272
+
1273
+
1274
class GemmaConverter(SpmConverter):
    """Converter for Gemma SentencePiece tokenizers.

    Relevant trainer settings of the original SentencePiece model:
        split_by_unicode_script: true
        split_by_number: true
        split_by_whitespace: true
        treat_whitespace_as_suffix: false
        allow_whitespace_only_pieces: true
        split_digits: true
        byte_fallback: true
    """
    # NOTE: the settings note above was previously a stray 4-quote string literal
    # (`""""..."""`) floating mid-class; it is now the class docstring.

    handle_byte_fallback = True
    SpmExtractor = GemmaSentencePieceExtractor
    # start and end of turn tokens must be marked as special
    special_tokens = {"<start_of_turn>", "<end_of_turn>"}

    def normalizer(self, proto):
        # Spaces are represented with the SentencePiece underline character.
        return normalizers.Replace(" ", "▁")

    def vocab(self, proto):
        vocab = [
            (self.original_tokenizer.pad_token, 0.0),
            (self.original_tokenizer.eos_token, 0.0),
            (self.original_tokenizer.bos_token, 0.0),
        ]
        for piece in proto.pieces[3:]:
            # Gemma maps the raw tab byte-fallback piece to an actual tab character.
            token = "\t" if piece.piece == "<0x09>" else piece.piece
            vocab.append((token, piece.score))
        return vocab

    def pre_tokenizer(self, replacement, add_prefix_space):
        return pre_tokenizers.Split(" ", "merged_with_previous")

    def unk_id(self, proto):
        return 3

    def decoder(self, replacement, add_prefix_space):
        return decoders.Sequence(
            [
                decoders.Replace("▁", " "),
                decoders.ByteFallback(),
                decoders.Fuse(),
            ]
        )
1322
+
1323
+
1324
class LlamaConverter(SpmConverter):
    """Converter for Llama SentencePiece tokenizers (with byte fallback)."""

    handle_byte_fallback = True

    def vocab(self, proto):
        # The first three ids come from the slow tokenizer (unk/bos/eos may be customized).
        id_to_token = self.original_tokenizer.convert_ids_to_tokens
        vocab = [(id_to_token(i), 0.0) for i in range(3)]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        return vocab

    def unk_id(self, proto):
        return 0

    def decoder(self, replacement, add_prefix_space):
        steps = [
            decoders.Replace("▁", " "),
            decoders.ByteFallback(),
            decoders.Fuse(),
        ]
        if add_prefix_space:
            steps.append(decoders.Strip(content=" ", left=1))
        return decoders.Sequence(steps)

    def normalizer(self, proto):
        if getattr(self.original_tokenizer, "legacy", True):
            steps = []
            if getattr(self.original_tokenizer, "add_prefix_space", True):
                steps.append(normalizers.Prepend(prepend="▁"))
            steps.append(normalizers.Replace(pattern=" ", content="▁"))
            return normalizers.Sequence(steps)
        return None  # non-legacy, no normalizer

    def pre_tokenizer(self, replacement, add_prefix_space):
        if not getattr(self.original_tokenizer, "legacy", True):  # non-legacy, we need a replace
            prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
            return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme, split=False)
        return None

    def post_processor(self):
        # the processor is defined in the LlamaTokenizerFast class.
        return None
1368
+
1369
+
1370
class MarkupLMConverter(Converter):
    """Converter for MarkupLM byte-level BPE tokenizers."""

    def converted(self) -> Tokenizer:
        ot = self.original_tokenizer
        tokenizer = Tokenizer(
            BPE(
                vocab=ot.encoder,
                merges=list(ot.bpe_ranks.keys()),
                dropout=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
                unk_token=ot.unk_token,
            )
        )

        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
        tokenizer.decoder = decoders.ByteLevel()

        cls = str(ot.cls_token)
        sep = str(ot.sep_token)
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls} $A {sep}",
            pair=f"{cls} $A {sep} $B {sep}",
            special_tokens=[
                (cls, ot.cls_token_id),
                (sep, ot.sep_token_id),
            ],
        )

        return tokenizer
1406
+
1407
+
1408
class MoshiConverter(SpmConverter):
    """Converter for Moshi SentencePiece tokenizers, built directly from a vocab file."""

    handle_byte_fallback = True

    def __init__(self, vocab_file, model_max_length=None, **kwargs):
        requires_backends(self, "protobuf")

        Converter.__init__(self, vocab_file)

        # Parse the SentencePiece model proto directly from the vocab file.
        model_pb2 = import_protobuf()
        proto = model_pb2.ModelProto()
        with open(vocab_file, "rb") as f:
            proto.ParseFromString(f.read())
        self.proto = proto

    def normalizer(self, proto):
        charsmap = proto.normalizer_spec.precompiled_charsmap
        base = [normalizers.Replace(" ", "▁")]
        if not charsmap:
            return normalizers.Sequence(base)
        return normalizers.Sequence([normalizers.Precompiled(charsmap)] + base)

    def decoder(self, replacement, add_prefix_space):
        steps = [
            decoders.Replace("▁", " "),
            decoders.ByteFallback(),
            decoders.Fuse(),
        ]
        if add_prefix_space:
            steps.append(decoders.Strip(content=" ", left=1))
        return decoders.Sequence(steps)

    def pre_tokenizer(self, replacement, add_prefix_space):
        return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme="first", split=False)
1447
+
1448
+
1449
+ # Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
1450
def bytes_to_unicode():
    """
    Return a mapping from every utf-8 byte (0-255) to a unicode character, avoiding
    whitespace/control characters the bpe code barfs on.

    The reversible bpe codes work on unicode strings, so a large set of printable
    characters is needed to avoid UNKs; printable bytes map to themselves and the
    remaining bytes are shifted into code points above 255.
    """
    # Printable byte ranges that can safely map to themselves.
    codepoints = dict.fromkeys(
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    for b in codepoints:
        codepoints[b] = b
    # Remaining bytes get fresh code points starting at 256 so nothing collides.
    next_cp = 2**8
    for b in range(2**8):
        if b not in codepoints:
            codepoints[b] = next_cp
            next_cp += 1
    return {b: chr(cp) for b, cp in codepoints.items()}
1472
+
1473
+
1474
class TikTokenConverter:
    """
    A general tiktoken converter.
    """

    def __init__(
        self,
        vocab_file=None,
        pattern=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
        add_prefix_space=False,
        additional_special_tokens=None,
        *args,
        **kwargs,
    ):
        super().__init__(*args)
        self.vocab_file = vocab_file
        self.pattern = pattern
        self.add_prefix_space = add_prefix_space
        self.additional_special_tokens = additional_special_tokens

    def extract_vocab_merges_from_model(self, tiktoken_url: str):
        """Load a tiktoken BPE file and rebuild the (vocab, merges) pair `tokenizers` expects."""
        try:
            from tiktoken.load import load_tiktoken_bpe
        except ImportError as e:
            # Narrowed from a bare `except Exception` — only the import can fail here.
            raise ValueError(
                "`tiktoken` is required to read a `tiktoken` file. Install it with `pip install tiktoken`."
            ) from e

        bpe_ranks = load_tiktoken_bpe(tiktoken_url)
        byte_encoder = bytes_to_unicode()

        def token_bytes_to_string(b):
            return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])

        merges = []
        vocab = {}
        for token, rank in bpe_ranks.items():
            vocab[token_bytes_to_string(token)] = rank
            if len(token) == 1:
                continue
            # Recover the merge that produced this token: any split whose two halves are
            # themselves tokens is a candidate, ordered by the halves' ranks.
            local = []
            for index in range(1, len(token)):
                piece_l, piece_r = token[:index], token[index:]
                if piece_l in bpe_ranks and piece_r in bpe_ranks and (piece_l + piece_r) in bpe_ranks:
                    local.append((piece_l, piece_r, rank))
            local = sorted(local, key=lambda x: (bpe_ranks[x[0]], bpe_ranks[x[1]]), reverse=False)
            merges.extend(local)
        merges = sorted(merges, key=lambda val: val[2], reverse=False)
        merges = [(token_bytes_to_string(val[0]), token_bytes_to_string(val[1])) for val in merges]
        return vocab, merges

    def tokenizer(self):
        """Build the bare BPE tokenizer model from the tiktoken vocab file."""
        vocab_scores, merges = self.extract_vocab_merges_from_model(self.vocab_file)
        tokenizer = Tokenizer(BPE(vocab_scores, merges, fuse_unk=False))
        if hasattr(tokenizer.model, "ignore_merges"):
            tokenizer.model.ignore_merges = True
        return tokenizer

    def converted(self) -> Tokenizer:
        """Build the full fast tokenizer (pre-tokenizer, decoder, post-processor)."""
        tokenizer = self.tokenizer()
        tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
            [
                pre_tokenizers.Split(Regex(self.pattern), behavior="isolated", invert=False),
                pre_tokenizers.ByteLevel(add_prefix_space=self.add_prefix_space, use_regex=False),
            ]
        )
        tokenizer.decoder = decoders.ByteLevel()
        # Fix: `additional_special_tokens` defaults to None, and
        # `Tokenizer.add_special_tokens` does not accept None.
        if self.additional_special_tokens is not None:
            tokenizer.add_special_tokens(self.additional_special_tokens)

        tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)

        return tokenizer
1546
+
1547
+
1548
# Maps a slow tokenizer class name to the converter able to build its fast equivalent.
# Several tokenizers share a converter (e.g. all BERT-style WordPiece tokenizers).
SLOW_TO_FAST_CONVERTERS = {
    "AlbertTokenizer": AlbertConverter,
    "BartTokenizer": RobertaConverter,
    "BarthezTokenizer": BarthezConverter,
    "BertTokenizer": BertConverter,
    "BigBirdTokenizer": BigBirdConverter,
    "BlenderbotTokenizer": BlenderbotConverter,
    "CamembertTokenizer": CamembertConverter,
    "CLIPTokenizer": CLIPConverter,
    "CodeGenTokenizer": GPT2Converter,
    "ConvBertTokenizer": BertConverter,
    "DebertaTokenizer": DebertaConverter,
    "DebertaV2Tokenizer": DebertaV2Converter,
    "DistilBertTokenizer": BertConverter,
    "DPRReaderTokenizer": BertConverter,
    "DPRQuestionEncoderTokenizer": BertConverter,
    "DPRContextEncoderTokenizer": BertConverter,
    "ElectraTokenizer": BertConverter,
    "FNetTokenizer": AlbertConverter,
    "FunnelTokenizer": FunnelConverter,
    "GPT2Tokenizer": GPT2Converter,
    "HerbertTokenizer": HerbertConverter,
    "LayoutLMTokenizer": BertConverter,
    "LayoutLMv2Tokenizer": BertConverter,
    "LayoutLMv3Tokenizer": RobertaConverter,
    "LayoutXLMTokenizer": XLMRobertaConverter,
    "LongformerTokenizer": RobertaConverter,
    "LEDTokenizer": RobertaConverter,
    "LxmertTokenizer": BertConverter,
    "MarkupLMTokenizer": MarkupLMConverter,
    "MBartTokenizer": MBartConverter,
    "MBart50Tokenizer": MBart50Converter,
    "MPNetTokenizer": MPNetConverter,
    "MobileBertTokenizer": BertConverter,
    "MvpTokenizer": RobertaConverter,
    "NllbTokenizer": NllbConverter,
    "OpenAIGPTTokenizer": OpenAIGPTConverter,
    "PegasusTokenizer": PegasusConverter,
    "Qwen2Tokenizer": Qwen2Converter,
    "RealmTokenizer": BertConverter,
    "ReformerTokenizer": ReformerConverter,
    "RemBertTokenizer": RemBertConverter,
    "RetriBertTokenizer": BertConverter,
    "RobertaTokenizer": RobertaConverter,
    "RoFormerTokenizer": RoFormerConverter,
    "SeamlessM4TTokenizer": SeamlessM4TConverter,
    "SqueezeBertTokenizer": BertConverter,
    "T5Tokenizer": T5Converter,
    "UdopTokenizer": UdopConverter,
    "WhisperTokenizer": WhisperConverter,
    "XLMRobertaTokenizer": XLMRobertaConverter,
    "XLNetTokenizer": XLNetConverter,
    "SplinterTokenizer": SplinterConverter,
    "XGLMTokenizer": XGLMConverter,
    "LlamaTokenizer": LlamaConverter,
    "CodeLlamaTokenizer": LlamaConverter,
    "GemmaTokenizer": GemmaConverter,
    "Phi3Tokenizer": LlamaConverter,
}
1607
+
1608
+
1609
def convert_slow_tokenizer(transformer_tokenizer, from_tiktoken=False) -> Tokenizer:
    """
    Utilities to convert a slow tokenizer instance in a fast tokenizer instance.

    Args:
        transformer_tokenizer ([`~tokenization_utils_base.PreTrainedTokenizer`]):
            Instance of a slow tokenizer to convert in the backend tokenizer for
            [`~tokenization_utils_base.PreTrainedTokenizerFast`].
        from_tiktoken (bool, optional): Whether to use the `tiktoken` library to convert the tokenizer instead of
            sentencepiece. Defaults to False.

    Return:
        A instance of [`~tokenizers.Tokenizer`] to be used as the backend tokenizer of a
        [`~tokenization_utils_base.PreTrainedTokenizerFast`]
    """
    tokenizer_class_name = transformer_tokenizer.__class__.__name__
    if tokenizer_class_name in SLOW_TO_FAST_CONVERTERS and not from_tiktoken:
        converter_class = SLOW_TO_FAST_CONVERTERS[tokenizer_class_name]
        return converter_class(transformer_tokenizer).converted()

    try:
        logger.info("Converting from Tiktoken")
        return TikTokenConverter(
            vocab_file=transformer_tokenizer.vocab_file,
            additional_special_tokens=transformer_tokenizer.additional_special_tokens,
        ).converted()
    except Exception as e:
        # Fixed: the original message ran two sentences together without a space and
        # used f-strings with no placeholders; also chain the original cause.
        raise ValueError(
            "Converting from Tiktoken failed, if a converter for SentencePiece is available, provide a model path "
            "with a SentencePiece tokenizer.model file. "
            f"Currently available slow->fast convertors: {list(SLOW_TO_FAST_CONVERTERS.keys())}"
        ) from e
.venv/lib/python3.11/site-packages/transformers/convert_slow_tokenizers_checkpoints_to_fast.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert slow tokenizers checkpoints in fast (serialization format of the `tokenizers` library)"""
16
+
17
+ import argparse
18
+ import os
19
+
20
+ import transformers
21
+
22
+ from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS
23
+ from .utils import logging
24
+
25
+
26
+ logging.set_verbosity_info()
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
# Map each slow tokenizer class name (the keys of SLOW_TO_FAST_CONVERTERS) to its fast
# counterpart class, resolved dynamically from the top-level `transformers` namespace.
TOKENIZER_CLASSES = {
    # Phi3 uses Llama tokenizer
    name: getattr(transformers, "LlamaTokenizerFast" if name == "Phi3Tokenizer" else name + "Fast")
    for name in SLOW_TO_FAST_CONVERTERS
}
36
+
37
+
38
def convert_slow_checkpoint_to_fast(tokenizer_name, checkpoint_name, dump_path, force_download):
    """Convert slow tokenizer checkpoint(s) to the fast `tokenizer.json` serialization and save them.

    Args:
        tokenizer_name: Slow tokenizer class name (a key of `TOKENIZER_CLASSES`), or `None`
            to convert every supported tokenizer class.
        checkpoint_name: A single checkpoint identifier to convert, or `None` to convert all
            canonical checkpoints known for each tokenizer class.
        dump_path: Root output directory for the generated `tokenizer.json` files.
        force_download: Whether to re-download checkpoints instead of using the local cache.

    Raises:
        ValueError: if `tokenizer_name` is given but is not a key of `TOKENIZER_CLASSES`.
    """
    if tokenizer_name is not None and tokenizer_name not in TOKENIZER_CLASSES:
        raise ValueError(f"Unrecognized tokenizer name, should be one of {list(TOKENIZER_CLASSES.keys())}.")

    if tokenizer_name is None:
        tokenizer_names = TOKENIZER_CLASSES
    else:
        tokenizer_names = {tokenizer_name: getattr(transformers, tokenizer_name + "Fast")}

    logger.info(f"Loading tokenizer classes: {tokenizer_names}")

    for tokenizer_name in tokenizer_names:
        tokenizer_class = TOKENIZER_CLASSES[tokenizer_name]

        add_prefix = True
        if checkpoint_name is None:
            # Convert every canonical checkpoint registered for this tokenizer class.
            checkpoint_names = list(tokenizer_class.max_model_input_sizes.keys())
        else:
            checkpoint_names = [checkpoint_name]

        logger.info(f"For tokenizer {tokenizer_class.__class__.__name__} loading checkpoints: {checkpoint_names}")

        for checkpoint in checkpoint_names:
            logger.info(f"Loading {tokenizer_class.__class__.__name__} {checkpoint}")

            # Load tokenizer
            tokenizer = tokenizer_class.from_pretrained(checkpoint, force_download=force_download)

            # Save fast tokenizer
            logger.info(f"Save fast tokenizer to {dump_path} with prefix {checkpoint} add_prefix {add_prefix}")

            # For organization names we create sub-directories
            if "/" in checkpoint:
                checkpoint_directory, checkpoint_prefix_name = checkpoint.split("/")
                dump_path_full = os.path.join(dump_path, checkpoint_directory)
            elif add_prefix:
                checkpoint_prefix_name = checkpoint
                dump_path_full = dump_path
            else:
                checkpoint_prefix_name = None
                dump_path_full = dump_path

            logger.info(f"=> {dump_path_full} with prefix {checkpoint_prefix_name}, add_prefix {add_prefix}")

            # When the vocab file URL for this checkpoint nests files under "<checkpoint>/",
            # the checkpoint name acts as a directory: save into a sub-directory instead of
            # using the checkpoint name as a filename prefix.
            if checkpoint in list(tokenizer.pretrained_vocab_files_map.values())[0]:
                file_path = list(tokenizer.pretrained_vocab_files_map.values())[0][checkpoint]
                next_char = file_path.split(checkpoint)[-1][0]
                if next_char == "/":
                    dump_path_full = os.path.join(dump_path_full, checkpoint_prefix_name)
                    checkpoint_prefix_name = None

                logger.info(f"=> {dump_path_full} with prefix {checkpoint_prefix_name}, add_prefix {add_prefix}")

            file_names = tokenizer.save_pretrained(
                dump_path_full, legacy_format=False, filename_prefix=checkpoint_prefix_name
            )
            logger.info(f"=> File names {file_names}")

            # Keep only the fast `tokenizer.json`; remove the slow-format files also written out.
            for file_name in file_names:
                if not file_name.endswith("tokenizer.json"):
                    os.remove(file_name)
                    logger.info(f"=> removing {file_name}")
100
+
101
+
102
+ if __name__ == "__main__":
103
+ parser = argparse.ArgumentParser()
104
+ # Required parameters
105
+ parser.add_argument(
106
+ "--dump_path", default=None, type=str, required=True, help="Path to output generated fast tokenizer files."
107
+ )
108
+ parser.add_argument(
109
+ "--tokenizer_name",
110
+ default=None,
111
+ type=str,
112
+ help=(
113
+ f"Optional tokenizer type selected in the list of {list(TOKENIZER_CLASSES.keys())}. If not given, will "
114
+ "download and convert all the checkpoints from AWS."
115
+ ),
116
+ )
117
+ parser.add_argument(
118
+ "--checkpoint_name",
119
+ default=None,
120
+ type=str,
121
+ help="Optional checkpoint name. If not given, will download and convert the canonical checkpoints from AWS.",
122
+ )
123
+ parser.add_argument(
124
+ "--force_download",
125
+ action="store_true",
126
+ help="Re-download checkpoints.",
127
+ )
128
+ args = parser.parse_args()
129
+
130
+ convert_slow_checkpoint_to_fast(args.tokenizer_name, args.checkpoint_name, args.dump_path, args.force_download)
.venv/lib/python3.11/site-packages/transformers/convert_tf_hub_seq_to_seq_bert_to_pytorch.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2020 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert Seq2Seq TF Hub checkpoint."""
16
+
17
+ import argparse
18
+
19
+ from . import (
20
+ BertConfig,
21
+ BertGenerationConfig,
22
+ BertGenerationDecoder,
23
+ BertGenerationEncoder,
24
+ load_tf_weights_in_bert_generation,
25
+ logging,
26
+ )
27
+
28
+
29
+ logging.set_verbosity_info()
30
+
31
+
32
def convert_tf_checkpoint_to_pytorch(tf_hub_path, pytorch_dump_path, is_encoder_named_decoder, vocab_size, is_encoder):
    """Convert a seq2seq TF Hub BERT checkpoint into a BertGeneration PyTorch model.

    Args:
        tf_hub_path: Path of the TensorFlow Hub checkpoint to load.
        pytorch_dump_path: Directory where the PyTorch model and its config are saved.
        is_encoder_named_decoder: Whether decoder variables are stored under an encoder
            scope in the TF checkpoint and need renaming on load.
        vocab_size: Vocabulary size to put in the generated config.
        is_encoder: If True, build a `BertGenerationEncoder`; otherwise a `BertGenerationDecoder`.
    """
    # Start from the canonical bert-large-cased config, then drop `type_vocab_size`,
    # which `BertGenerationConfig` does not accept.
    base_config = BertConfig.from_pretrained(
        "google-bert/bert-large-cased",
        vocab_size=vocab_size,
        max_position_embeddings=512,
        is_decoder=True,
        add_cross_attention=True,
    )
    config_kwargs = base_config.to_dict()
    del config_kwargs["type_vocab_size"]
    config = BertGenerationConfig(**config_kwargs)

    model = BertGenerationEncoder(config) if is_encoder else BertGenerationDecoder(config)
    print(f"Building PyTorch model from configuration: {config}")

    # Populate the freshly initialised model with the TF checkpoint weights.
    load_tf_weights_in_bert_generation(
        model,
        tf_hub_path,
        model_class="bert",
        is_encoder_named_decoder=is_encoder_named_decoder,
        is_encoder=is_encoder,
    )

    # Save pytorch-model
    print(f"Save PyTorch model and config to {pytorch_dump_path}")
    model.save_pretrained(pytorch_dump_path)
63
+
64
+ if __name__ == "__main__":
65
+ parser = argparse.ArgumentParser()
66
+ # Required parameters
67
+ parser.add_argument(
68
+ "--tf_hub_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
69
+ )
70
+ parser.add_argument(
71
+ "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
72
+ )
73
+ parser.add_argument(
74
+ "--is_encoder_named_decoder",
75
+ action="store_true",
76
+ help="If decoder has to be renamed to encoder in PyTorch model.",
77
+ )
78
+ parser.add_argument("--is_encoder", action="store_true", help="If model is an encoder.")
79
+ parser.add_argument("--vocab_size", default=50358, type=int, help="Vocab size of model")
80
+ args = parser.parse_args()
81
+ convert_tf_checkpoint_to_pytorch(
82
+ args.tf_hub_path,
83
+ args.pytorch_dump_path,
84
+ args.is_encoder_named_decoder,
85
+ args.vocab_size,
86
+ is_encoder=args.is_encoder,
87
+ )
.venv/lib/python3.11/site-packages/transformers/debug_utils.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import collections
16
+
17
+ from .utils import ExplicitEnum, is_torch_available, logging
18
+
19
+
20
+ if is_torch_available():
21
+ import torch
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
class DebugUnderflowOverflow:
    """
    This debug class helps detect and understand where the model starts getting very large or very small, and more
    importantly `nan` or `inf` weight and activation elements.

    There are 2 working modes:

    1. Underflow/overflow detection (default)
    2. Specific batch absolute min/max tracing without detection

    Mode 1: Underflow/overflow detection

    To activate the underflow/overflow detection, initialize the object with the model :

    ```python
    debug_overflow = DebugUnderflowOverflow(model)
    ```

    then run the training as normal and if `nan` or `inf` gets detected in at least one of the weight, input or output
    elements this module will throw an exception and will print `max_frames_to_save` frames that lead to this event,
    each frame reporting

    1. the fully qualified module name plus the class name whose `forward` was run
    2. the absolute min and max value of all elements for each module weights, and the inputs and output

    Since an overflow usually builds up over preceding frames, the frames just before the `inf`/`nan` event are the
    most informative — e.g. an activation whose absolute max approaches fp16's top limit of 64K.

    The tracking is done in a forward hook, which gets invoked immediately after `forward` has completed.

    By default the last 21 frames are printed. You can change the default to adjust for your needs. For example :

    ```python
    debug_overflow = DebugUnderflowOverflow(model, max_frames_to_save=100)
    ```

    To validate that you have set up this debugging feature correctly, and you intend to use it in a training that
    may take hours to complete, first run it with normal tracing enabled for one of a few batches as explained in
    the next section.

    Mode 2. Specific batch absolute min/max tracing without detection

    The second work mode is per-batch tracing with the underflow/overflow detection feature turned off.

    Let's say you want to watch the absolute min and max values for all the ingredients of each `forward` call of a
    given batch, and only do that for batches 1 and 3. Then you instantiate this class as :

    ```python
    debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3])
    ```

    And now full batches 1 and 3 will be traced using the same format as explained above. Batches are 0-indexed.

    This is helpful if you know that the program starts misbehaving after a certain batch number, so you can
    fast-forward right to that area.

    Early stopping:

    You can also specify the batch number after which to stop the training, with :

    ```python
    debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3], abort_after_batch_num=3)
    ```

    This feature is mainly useful in the tracing mode, but you can use it for any mode.

    **Performance**:

    As this module measures absolute `min`/`max` of each weight of the model on every forward it'll slow the training
    down. Therefore remember to turn it off once the debugging needs have been met.

    Args:
        model (`nn.Module`):
            The model to debug.
        max_frames_to_save (`int`, *optional*, defaults to 21):
            How many frames back to record
        trace_batch_nums(`List[int]`, *optional*, defaults to `[]`):
            Which batch numbers to trace (turns detection off)
        abort_after_batch_num (`int`, *optional*):
            Whether to abort after a certain batch number has finished
    """

    def __init__(self, model, max_frames_to_save=21, trace_batch_nums=None, abort_after_batch_num=None):
        # `trace_batch_nums` defaults to None rather than a mutable `[]` default argument:
        # a literal list default would be a single object shared across all instances.
        self.model = model
        self.trace_batch_nums = [] if trace_batch_nums is None else trace_batch_nums
        self.abort_after_batch_num = abort_after_batch_num

        # keep a LIFO buffer of frames to dump as soon as inf/nan is encountered to give context to the problem emergence
        self.frames = collections.deque([], max_frames_to_save)
        self.frame = []
        self.batch_number = 0
        self.total_calls = 0
        self.detected_overflow = False
        self.prefix = " "  # indent used when rendering frame header lines

        self.analyse_model()

        self.register_forward_hook()

    def save_frame(self, frame=None):
        # Finalize the frame currently being built (optionally appending one last line)
        # and push it onto the bounded deque of recent frames.
        if frame is not None:
            self.expand_frame(frame)
        self.frames.append("\n".join(self.frame))
        self.frame = []  # start a new frame

    def expand_frame(self, line):
        # Append one report line to the frame currently being built.
        self.frame.append(line)

    def trace_frames(self):
        # Print and discard all buffered frames (used in tracing mode).
        print("\n".join(self.frames))
        self.frames = []

    def reset_saved_frames(self):
        self.frames = []

    def dump_saved_frames(self):
        # Emit the buffered frames that led up to the inf/nan event, then clear the buffer.
        print(f"\nDetected inf/nan during batch_number={self.batch_number}")
        print(f"Last {len(self.frames)} forward frames:")
        print(f"{'abs min':8} {'abs max':8} metadata")
        print("\n".join(self.frames))
        print("\n\n")
        self.frames = []

    def analyse_model(self):
        # extract the fully qualified module names, to be able to report at run time. e.g.:
        # encoder.block.2.layer.0.SelfAttention.o
        #
        # for shared weights only the first shared module name will be registered
        self.module_names = {m: name for name, m in self.model.named_modules()}

    def analyse_variable(self, var, ctx):
        # Record abs min/max for a tensor and flag inf/nan; non-tensor inputs/outputs are
        # reported as such so the frame layout stays aligned.
        if torch.is_tensor(var):
            self.expand_frame(get_abs_min_max(var, ctx))
            if detect_overflow(var, ctx):
                self.detected_overflow = True
        elif var is None:
            self.expand_frame(f"{'None':>17} {ctx}")
        else:
            self.expand_frame(f"{'not a tensor':>17} {ctx}")

    def batch_start_frame(self):
        self.expand_frame(f"\n\n{self.prefix} *** Starting batch number={self.batch_number} ***")
        self.expand_frame(f"{'abs min':8} {'abs max':8} metadata")

    def batch_end_frame(self):
        self.expand_frame(f"{self.prefix} *** Finished batch number={self.batch_number-1} ***\n\n")

    def create_frame(self, module, input, output):
        # One frame = module identification line + stats for its params, inputs and outputs.
        self.expand_frame(f"{self.prefix} {self.module_names[module]} {module.__class__.__name__}")

        # params
        for name, p in module.named_parameters(recurse=False):
            self.analyse_variable(p, name)

        # inputs
        if isinstance(input, tuple):
            for i, x in enumerate(input):
                self.analyse_variable(x, f"input[{i}]")
        else:
            self.analyse_variable(input, "input")

        # outputs
        if isinstance(output, tuple):
            for i, x in enumerate(output):
                # possibly a tuple of tuples
                if isinstance(x, tuple):
                    for j, y in enumerate(x):
                        self.analyse_variable(y, f"output[{i}][{j}]")
                else:
                    self.analyse_variable(x, f"output[{i}]")
        else:
            self.analyse_variable(output, "output")

        self.save_frame()

    def register_forward_hook(self):
        self.model.apply(self._register_forward_hook)

    def _register_forward_hook(self, module):
        module.register_forward_hook(self.forward_hook)

    def forward_hook(self, module, input, output):
        # - input is a tuple of packed inputs (could be non-Tensors)
        # - output could be a Tensor or a tuple of Tensors and non-Tensors

        last_frame_of_batch = False

        trace_mode = self.batch_number in self.trace_batch_nums
        if trace_mode:
            self.reset_saved_frames()

        if self.total_calls == 0:
            self.batch_start_frame()
        self.total_calls += 1

        # count batch numbers - the very first forward hook of the batch will be called when the
        # batch completes - i.e. it gets called very last - we know this batch has finished
        if module == self.model:
            self.batch_number += 1
            last_frame_of_batch = True

        self.create_frame(module, input, output)

        if trace_mode:
            self.trace_frames()

        if last_frame_of_batch:
            self.batch_start_frame()

        if self.detected_overflow and not trace_mode:
            self.dump_saved_frames()

            # now we can abort, as it's pointless to continue running
            raise ValueError(
                "DebugUnderflowOverflow: inf/nan detected, aborting as there is no point running further. "
                "Please scroll up above this traceback to see the activation values prior to this event."
            )

        # abort after certain batch if requested to do so
        if self.abort_after_batch_num is not None and self.batch_number > self.abort_after_batch_num:
            raise ValueError(
                f"DebugUnderflowOverflow: aborting after {self.batch_number} batches due to"
                f" `abort_after_batch_num={self.abort_after_batch_num}` arg"
            )
290
+ )
291
+
292
+
293
def get_abs_min_max(var, ctx):
    """Return a report line: absolute min and max of tensor `var` followed by the `ctx` label."""
    magnitudes = var.abs()
    return f"{magnitudes.min():8.2e} {magnitudes.max():8.2e} {ctx}"
296
+
297
+
298
def detect_overflow(var, ctx):
    """
    Report whether the tensor contains any `nan` or `inf` entries.

    This is useful for detecting overflows/underflows and best to call right after the function that did some math
    that modified the tensor in question.

    This function contains a few other helper features that you can enable and tweak directly if you want to track
    various other things.

    Args:
        var: the tensor variable to check
        ctx: the message to print as a context

    Return:
        `True` if `inf` or `nan` was detected, `False` otherwise
    """
    found_nan = torch.isnan(var).any().item()
    if found_nan:
        print(f"{ctx} has nans")
    found_inf = torch.isinf(var).any().item()
    if found_inf:
        print(f"{ctx} has infs")

    # if needed to monitor large elements can enable the following (disabled by default)
    if 0:  # and detected:
        for threshold in (100, 1000, 10000):
            big = var[torch.ge(var.abs(), threshold)]
            if big.numel() > 0:
                print(f"{ctx}: n{threshold}={big.numel()}")

    if 0:
        print(f"min={var.min():9.2e} max={var.max():9.2e}")

    if 0:
        print(f"min={var.min():9.2e} max={var.max():9.2e} var={var.var():9.2e} mean={var.mean():9.2e} ({ctx})")

    return bool(found_nan or found_inf)
342
+
343
+
344
class DebugOption(ExplicitEnum):
    """Named debug modes selectable by their string value."""

    # NOTE(review): presumably consumed by a trainer-level `debug` argument — confirm against caller.
    UNDERFLOW_OVERFLOW = "underflow_overflow"  # activate DebugUnderflowOverflow detection
    TPU_METRICS_DEBUG = "tpu_metrics_debug"  # TPU metrics debugging mode
.venv/lib/python3.11/site-packages/transformers/dependency_versions_check.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .dependency_versions_table import deps
16
+ from .utils.versions import require_version, require_version_core
17
+
18
+
19
+ # define which module versions we always want to check at run time
20
+ # (usually the ones defined in `install_requires` in setup.py)
21
+ #
22
+ # order specific notes:
23
+ # - tqdm must be checked before tokenizers
24
+
25
# Packages whose installed versions are validated against the pins in `deps` every time
# this module is imported. Optional packages (tokenizers, accelerate) are only checked
# when they are actually installed.
pkgs_to_check_at_runtime = [
    "python",
    "tqdm",
    "regex",
    "requests",
    "packaging",
    "filelock",
    "numpy",
    "tokenizers",
    "huggingface-hub",
    "safetensors",
    "accelerate",
    "pyyaml",
]

for pkg in pkgs_to_check_at_runtime:
    if pkg in deps:
        if pkg == "tokenizers":
            # must be loaded here, or else tqdm check may fail
            from .utils import is_tokenizers_available

            if not is_tokenizers_available():
                continue  # not required, check version only if installed
        elif pkg == "accelerate":
            # must be loaded here, or else tqdm check may fail
            from .utils import is_accelerate_available

            # Maybe switch to is_torch_available in the future here so that Accelerate is hard dep of
            # Transformers with PyTorch
            if not is_accelerate_available():
                continue  # not required, check version only if installed

        require_version_core(deps[pkg])
    else:
        raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
60
+
61
+
62
def dep_version_check(pkg, hint=None):
    """Check that the installed version of `pkg` satisfies the pin recorded in `deps`.

    Args:
        pkg: Key into the `deps` version table (a `KeyError` is raised for unknown keys).
        hint: Optional extra text appended to the error message on a version mismatch.
    """
    require_version(deps[pkg], hint)
.venv/lib/python3.11/site-packages/transformers/dependency_versions_table.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # THIS FILE HAS BEEN AUTOGENERATED. To update:
2
+ # 1. modify the `_deps` dict in setup.py
3
+ # 2. run `make deps_table_update``
4
+ deps = {
5
+ "Pillow": "Pillow>=10.0.1,<=15.0",
6
+ "accelerate": "accelerate>=0.26.0",
7
+ "av": "av==9.2.0",
8
+ "beautifulsoup4": "beautifulsoup4",
9
+ "blobfile": "blobfile",
10
+ "codecarbon": "codecarbon>=2.8.1",
11
+ "cookiecutter": "cookiecutter==1.7.3",
12
+ "dataclasses": "dataclasses",
13
+ "datasets": "datasets!=2.5.0",
14
+ "deepspeed": "deepspeed>=0.9.3",
15
+ "diffusers": "diffusers",
16
+ "dill": "dill<0.3.5",
17
+ "evaluate": "evaluate>=0.2.0",
18
+ "faiss-cpu": "faiss-cpu",
19
+ "fastapi": "fastapi",
20
+ "filelock": "filelock",
21
+ "flax": "flax>=0.4.1,<=0.7.0",
22
+ "fsspec": "fsspec<2023.10.0",
23
+ "ftfy": "ftfy",
24
+ "fugashi": "fugashi>=1.0",
25
+ "GitPython": "GitPython<3.1.19",
26
+ "hf-doc-builder": "hf-doc-builder>=0.3.0",
27
+ "huggingface-hub": "huggingface-hub>=0.24.0,<1.0",
28
+ "importlib_metadata": "importlib_metadata",
29
+ "ipadic": "ipadic>=1.0.0,<2.0",
30
+ "isort": "isort>=5.5.4",
31
+ "jax": "jax>=0.4.1,<=0.4.13",
32
+ "jaxlib": "jaxlib>=0.4.1,<=0.4.13",
33
+ "jieba": "jieba",
34
+ "jinja2": "jinja2>=3.1.0",
35
+ "kenlm": "kenlm",
36
+ "keras": "keras>2.9,<2.16",
37
+ "keras-nlp": "keras-nlp>=0.3.1,<0.14.0",
38
+ "librosa": "librosa",
39
+ "nltk": "nltk<=3.8.1",
40
+ "natten": "natten>=0.14.6,<0.15.0",
41
+ "numpy": "numpy>=1.17",
42
+ "onnxconverter-common": "onnxconverter-common",
43
+ "onnxruntime-tools": "onnxruntime-tools>=1.4.2",
44
+ "onnxruntime": "onnxruntime>=1.4.0",
45
+ "opencv-python": "opencv-python",
46
+ "optimum-benchmark": "optimum-benchmark>=0.3.0",
47
+ "optuna": "optuna",
48
+ "optax": "optax>=0.0.8,<=0.1.4",
49
+ "packaging": "packaging>=20.0",
50
+ "parameterized": "parameterized",
51
+ "phonemizer": "phonemizer",
52
+ "protobuf": "protobuf",
53
+ "psutil": "psutil",
54
+ "pyyaml": "pyyaml>=5.1",
55
+ "pydantic": "pydantic",
56
+ "pytest": "pytest>=7.2.0,<8.0.0",
57
+ "pytest-asyncio": "pytest-asyncio",
58
+ "pytest-timeout": "pytest-timeout",
59
+ "pytest-xdist": "pytest-xdist",
60
+ "python": "python>=3.9.0",
61
+ "ray[tune]": "ray[tune]>=2.7.0",
62
+ "regex": "regex!=2019.12.17",
63
+ "requests": "requests",
64
+ "rhoknp": "rhoknp>=1.1.0,<1.3.1",
65
+ "rjieba": "rjieba",
66
+ "rouge-score": "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1",
67
+ "ruff": "ruff==0.5.1",
68
+ "sacrebleu": "sacrebleu>=1.4.12,<2.0.0",
69
+ "sacremoses": "sacremoses",
70
+ "safetensors": "safetensors>=0.4.1",
71
+ "sagemaker": "sagemaker>=2.31.0",
72
+ "schedulefree": "schedulefree>=1.2.6",
73
+ "scikit-learn": "scikit-learn",
74
+ "scipy": "scipy<1.13.0",
75
+ "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
76
+ "sigopt": "sigopt",
77
+ "starlette": "starlette",
78
+ "sudachipy": "sudachipy>=0.6.6",
79
+ "sudachidict_core": "sudachidict_core>=20220729",
80
+ "tensorboard": "tensorboard",
81
+ "tensorflow-cpu": "tensorflow-cpu>2.9,<2.16",
82
+ "tensorflow": "tensorflow>2.9,<2.16",
83
+ "tensorflow-text": "tensorflow-text<2.16",
84
+ "tensorflow-probability": "tensorflow-probability<0.24",
85
+ "tf2onnx": "tf2onnx",
86
+ "timeout-decorator": "timeout-decorator",
87
+ "tiktoken": "tiktoken",
88
+ "timm": "timm<=1.0.11",
89
+ "tokenizers": "tokenizers>=0.21,<0.22",
90
+ "torch": "torch>=2.0",
91
+ "torchaudio": "torchaudio",
92
+ "torchvision": "torchvision",
93
+ "pyctcdecode": "pyctcdecode>=0.4.0",
94
+ "tqdm": "tqdm>=4.27",
95
+ "unidic": "unidic>=1.0.2",
96
+ "unidic_lite": "unidic_lite>=1.0.7",
97
+ "urllib3": "urllib3<2.0.0",
98
+ "uvicorn": "uvicorn",
99
+ "pytest-rich": "pytest-rich",
100
+ "libcst": "libcst",
101
+ "rich": "rich",
102
+ }
.venv/lib/python3.11/site-packages/transformers/dynamic_module_utils.py ADDED
@@ -0,0 +1,685 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Utilities to dynamically load objects from the Hub."""
16
+
17
+ import filecmp
18
+ import hashlib
19
+ import importlib
20
+ import importlib.util
21
+ import os
22
+ import re
23
+ import shutil
24
+ import signal
25
+ import sys
26
+ import threading
27
+ import typing
28
+ import warnings
29
+ from pathlib import Path
30
+ from types import ModuleType
31
+ from typing import Any, Dict, List, Optional, Union
32
+
33
+ from huggingface_hub import try_to_load_from_cache
34
+
35
+ from .utils import (
36
+ HF_MODULES_CACHE,
37
+ TRANSFORMERS_DYNAMIC_MODULE_NAME,
38
+ cached_file,
39
+ extract_commit_hash,
40
+ is_offline_mode,
41
+ logging,
42
+ )
43
+
44
+
45
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
46
+ _HF_REMOTE_CODE_LOCK = threading.Lock()
47
+
48
+
49
def init_hf_modules():
    """
    Creates the cache directory for modules with an init, and adds it to the Python path.
    """
    # Nothing to do when this process already initialised the cache: the directory
    # being on `sys.path` is the marker that the work below has been done.
    if HF_MODULES_CACHE in sys.path:
        return

    sys.path.append(HF_MODULES_CACHE)
    os.makedirs(HF_MODULES_CACHE, exist_ok=True)
    marker = Path(HF_MODULES_CACHE) / "__init__.py"
    if not marker.exists():
        marker.touch()
        # Invalidate import caches so the freshly created package is discoverable.
        importlib.invalidate_caches()
63
+
64
+
65
def create_dynamic_module(name: Union[str, os.PathLike]) -> None:
    """
    Create a dynamic module (a Python package directory) in the cache directory for modules.

    Args:
        name (`str` or `os.PathLike`):
            The name of the dynamic module to create.
    """
    init_hf_modules()
    module_dir = (Path(HF_MODULES_CACHE) / name).resolve()
    # Materialize missing parent packages first, recursively.
    if not module_dir.parent.exists():
        create_dynamic_module(module_dir.parent)
    os.makedirs(module_dir, exist_ok=True)
    package_marker = module_dir / "__init__.py"
    if not package_marker.exists():
        package_marker.touch()
        # It is extremely important to invalidate the cache when we change stuff in those modules, or users end up
        # with errors about modules that do not exist. Same for all other `invalidate_caches` in this file.
        importlib.invalidate_caches()
85
+
86
+
87
def get_relative_imports(module_file: Union[str, os.PathLike]) -> List[str]:
    """
    Get the list of modules that are relatively imported in a module file.

    Args:
        module_file (`str` or `os.PathLike`): The module file to inspect.

    Returns:
        `List[str]`: The list of relative imports in the module (deduplicated, in no particular order).
    """
    source = Path(module_file).read_text(encoding="utf-8")

    # `import .xxx` form
    found = re.findall(r"^\s*import\s+\.(\S+)\s*$", source, flags=re.MULTILINE)
    # `from .xxx import yyy` form
    found += re.findall(r"^\s*from\s+\.(\S+)\s+import", source, flags=re.MULTILINE)
    # Deduplicate via a set before returning.
    return list(set(found))
106
+
107
+
108
def get_relative_import_files(module_file: Union[str, os.PathLike]) -> List[str]:
    """
    Get the list of all files that are needed for a given module. Note that this function recurses through the relative
    imports (if a imports b and b imports c, it will return module files for b and c).

    Args:
        module_file (`str` or `os.PathLike`): The module file to inspect.

    Returns:
        `List[str]`: The list of all relative imports a given module needs (recursively), which will give us the list
        of module files a given module needs.
    """
    no_change = False
    files_to_check = [module_file]
    all_relative_imports = []

    # Fixed-point iteration: keep resolving relative imports until no new file shows up.
    while not no_change:
        new_imports = []
        for f in files_to_check:
            new_imports.extend(get_relative_imports(f))

        # NOTE(review): every import is resolved relative to the *original* module's directory, which assumes a
        # flat layout of the custom-code files — TODO confirm against Hub repo conventions.
        module_path = Path(module_file).parent
        # Compare full file names (including the ".py" suffix) against what was already collected. The previous
        # code filtered suffix-less names against the collected ".py" names, which never matched, so circular
        # relative imports (a imports b, b imports a) made this loop run forever.
        new_import_files = [f"{module_path / m}.py" for m in new_imports]
        files_to_check = [f for f in new_import_files if f not in all_relative_imports]

        no_change = len(files_to_check) == 0
        all_relative_imports.extend(files_to_check)

    return all_relative_imports
139
+
140
+
141
def get_imports(filename: Union[str, os.PathLike]) -> List[str]:
    """
    Extracts all the libraries (not relative imports this time) that are imported in a file.

    Args:
        filename (`str` or `os.PathLike`): The module file to inspect.

    Returns:
        `List[str]`: The list of all packages required to use the input module.
    """
    source = Path(filename).read_text(encoding="utf-8")

    # Strip try/except blocks so custom code may guard optional imports without tripping this check.
    source = re.sub(r"\s*try\s*:.*?except.*?:", "", source, flags=re.DOTALL)

    # Strip imports guarded by is_flash_attn_*_available() so CPU-only environments can load the module.
    source = re.sub(
        r"if is_flash_attn[a-zA-Z0-9_]+available\(\):\s*(from flash_attn\s*.*\s*)+", "", source, flags=re.MULTILINE
    )

    # `import xxx` form
    found = re.findall(r"^\s*import\s+(\S+)\s*$", source, flags=re.MULTILINE)
    # `from xxx import yyy` form
    found += re.findall(r"^\s*from\s+(\S+)\s+import", source, flags=re.MULTILINE)
    # Drop relative imports and keep only the top-level package name of each absolute one.
    top_level = [entry.split(".")[0] for entry in found if not entry.startswith(".")]
    return list(set(top_level))
169
+
170
+
171
def check_imports(filename: Union[str, os.PathLike]) -> List[str]:
    """
    Check if the current Python environment contains all the libraries that are imported in a file. Will raise if a
    library is missing.

    Args:
        filename (`str` or `os.PathLike`): The module file to check.

    Returns:
        `List[str]`: The list of relative imports in the file.
    """
    missing_packages = []
    for package in get_imports(filename):
        try:
            importlib.import_module(package)
        except ImportError as exception:
            logger.warning(f"Encountered exception while importing {package}: {exception}")
            # Some packages can fail with an ImportError because of a dependency issue.
            # Only treat "No module named" as a missing package so such errors are not hidden.
            # See https://github.com/huggingface/transformers/issues/33604
            if "No module named" in str(exception):
                missing_packages.append(package)
            else:
                raise

    if missing_packages:
        raise ImportError(
            "This modeling file requires the following packages that were not found in your environment: "
            f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
        )

    return get_relative_imports(filename)
204
+
205
+
206
def get_class_in_module(
    class_name: str,
    module_path: Union[str, os.PathLike],
    *,
    force_reload: bool = False,
) -> typing.Type:
    """
    Import a module on the cache directory for modules and extract a class from it.

    Args:
        class_name (`str`): The name of the class to import.
        module_path (`str` or `os.PathLike`): The path to the module to import, relative to `HF_MODULES_CACHE`.
        force_reload (`bool`, *optional*, defaults to `False`):
            Whether to reload the dynamic module from file if it already exists in `sys.modules`.
            Otherwise, the module is only reloaded if the file has changed.

    Returns:
        `typing.Type`: The class looked for.
    """
    # Derive a dotted module name from the relative file path (strip ".py", turn separators into dots).
    name = os.path.normpath(module_path)
    if name.endswith(".py"):
        name = name[:-3]
    name = name.replace(os.path.sep, ".")
    module_file: Path = Path(HF_MODULES_CACHE) / module_path
    # Serialize loading of remote code across threads to avoid racing on sys.modules.
    with _HF_REMOTE_CODE_LOCK:
        if force_reload:
            sys.modules.pop(name, None)
            importlib.invalidate_caches()
        cached_module: Optional[ModuleType] = sys.modules.get(name)
        module_spec = importlib.util.spec_from_file_location(name, location=module_file)

        # Hash the module file and all its relative imports to check if we need to reload it
        module_files: List[Path] = [module_file] + sorted(map(Path, get_relative_import_files(module_file)))
        module_hash: str = hashlib.sha256(b"".join(bytes(f) + f.read_bytes() for f in module_files)).hexdigest()

        module: ModuleType
        if cached_module is None:
            module = importlib.util.module_from_spec(module_spec)
            # insert it into sys.modules before any loading begins, so self-referential imports resolve
            sys.modules[name] = module
        else:
            module = cached_module
        # reload in both cases, unless the module is already imported and the hash hits
        if getattr(module, "__transformers_module_hash__", "") != module_hash:
            module_spec.loader.exec_module(module)
            module.__transformers_module_hash__ = module_hash
        return getattr(module, class_name)
253
+
254
+
255
def get_cached_module_file(
    pretrained_model_name_or_path: Union[str, os.PathLike],
    module_file: str,
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: Optional[bool] = None,
    proxies: Optional[Dict[str, str]] = None,
    token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    repo_type: Optional[str] = None,
    _commit_hash: Optional[str] = None,
    **deprecated_kwargs,
) -> str:
    """
    Downloads a module from a local folder or a distant repo and returns its path inside the cached
    Transformers module.

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
              huggingface.co.
            - a path to a *directory* containing a configuration file saved using the
              [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.

        module_file (`str`):
            The name of the module file containing the class to look for.
        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force to (re-)download the configuration files and override the cached versions if they
            exist.
        resume_download:
            Deprecated and ignored. All downloads are now resumed by default when possible.
            Will be removed in v5 of Transformers.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `huggingface-cli login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the tokenizer configuration from local files.
        repo_type (`str`, *optional*):
            Specify the repo type (useful when downloading from a space for instance).

    <Tip>

    Passing `token=True` is required when you want to use a private model.

    </Tip>

    Returns:
        `str`: The path to the module inside the cache, relative to `HF_MODULES_CACHE`.
    """
    # Honor the deprecated `use_auth_token` alias for `token` until it is removed in v5.
    use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
    if use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
            FutureWarning,
        )
        if token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        token = use_auth_token

    if is_offline_mode() and not local_files_only:
        logger.info("Offline mode: forcing local_files_only=True")
        local_files_only = True

    # Download and cache module_file from the repo `pretrained_model_name_or_path` or grab it if it's a local file.
    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
    is_local = os.path.isdir(pretrained_model_name_or_path)
    if is_local:
        submodule = os.path.basename(pretrained_model_name_or_path)
    else:
        submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
        # Peek at the current cache entry so we can detect below whether a *new* version was downloaded.
        cached_module = try_to_load_from_cache(
            pretrained_model_name_or_path, module_file, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type
        )

    new_files = []
    try:
        # Load from URL or cache if already cached
        resolved_module_file = cached_file(
            pretrained_model_name_or_path,
            module_file,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            repo_type=repo_type,
            _commit_hash=_commit_hash,
        )
        # A path different from the cached one means a fresh download happened; remember it for the warning below.
        if not is_local and cached_module != resolved_module_file:
            new_files.append(module_file)

    except EnvironmentError:
        logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
        raise

    # Check we have all the requirements in our environment
    modules_needed = check_imports(resolved_module_file)

    # Now we move the module inside our cached dynamic modules.
    full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
    create_dynamic_module(full_submodule)
    submodule_path = Path(HF_MODULES_CACHE) / full_submodule
    if submodule == os.path.basename(pretrained_model_name_or_path):
        # We copy local files to avoid putting too many folders in sys.path. This copy is done when the file is new or
        # has changed since last copy.
        if not (submodule_path / module_file).exists() or not filecmp.cmp(
            resolved_module_file, str(submodule_path / module_file)
        ):
            shutil.copy(resolved_module_file, submodule_path / module_file)
            importlib.invalidate_caches()
        for module_needed in modules_needed:
            module_needed = f"{module_needed}.py"
            module_needed_file = os.path.join(pretrained_model_name_or_path, module_needed)
            if not (submodule_path / module_needed).exists() or not filecmp.cmp(
                module_needed_file, str(submodule_path / module_needed)
            ):
                shutil.copy(module_needed_file, submodule_path / module_needed)
                importlib.invalidate_caches()
    else:
        # Get the commit hash
        commit_hash = extract_commit_hash(resolved_module_file, _commit_hash)

        # The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the
        # benefit of versioning.
        submodule_path = submodule_path / commit_hash
        full_submodule = full_submodule + os.path.sep + commit_hash
        create_dynamic_module(full_submodule)

        if not (submodule_path / module_file).exists():
            shutil.copy(resolved_module_file, submodule_path / module_file)
            importlib.invalidate_caches()
        # Make sure we also have every file with relative imports, fetching each one recursively.
        for module_needed in modules_needed:
            if not (submodule_path / f"{module_needed}.py").exists():
                get_cached_module_file(
                    pretrained_model_name_or_path,
                    f"{module_needed}.py",
                    cache_dir=cache_dir,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                    token=token,
                    revision=revision,
                    local_files_only=local_files_only,
                    _commit_hash=commit_hash,
                )
                new_files.append(f"{module_needed}.py")

    # Warn the user about newly downloaded code files when no revision was pinned.
    if len(new_files) > 0 and revision is None:
        new_files = "\n".join([f"- {f}" for f in new_files])
        repo_type_str = "" if repo_type is None else f"{repo_type}s/"
        url = f"https://huggingface.co/{repo_type_str}{pretrained_model_name_or_path}"
        logger.warning(
            f"A new version of the following files was downloaded from {url}:\n{new_files}"
            "\n. Make sure to double-check they do not contain any added malicious code. To avoid downloading new "
            "versions of the code file, you can pin a revision."
        )

    return os.path.join(full_submodule, module_file)
429
+
430
+
431
def get_class_from_dynamic_module(
    class_reference: str,
    pretrained_model_name_or_path: Union[str, os.PathLike],
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: Optional[bool] = None,
    proxies: Optional[Dict[str, str]] = None,
    token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    repo_type: Optional[str] = None,
    code_revision: Optional[str] = None,
    **kwargs,
) -> typing.Type:
    """
    Extracts a class from a module file, present in the local folder or repository of a model.

    <Tip warning={true}>

    Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should
    therefore only be called on trusted repos.

    </Tip>

    Args:
        class_reference (`str`):
            The full name of the class to load, in the form `module.ClassName`, optionally prefixed with the repo
            hosting the code as `repo_id--module.ClassName`.
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
              huggingface.co.
            - a path to a *directory* containing a configuration file saved using the
              [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.

            This is used when `class_reference` does not specify another repo.
        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force to (re-)download the configuration files and override the cached versions if they
            exist.
        resume_download:
            Deprecated and ignored. All downloads are now resumed by default when possible.
            Will be removed in v5 of Transformers.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        token (`str` or `bool`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `huggingface-cli login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the tokenizer configuration from local files.
        repo_type (`str`, *optional*):
            Specify the repo type (useful when downloading from a space for instance).
        code_revision (`str`, *optional*, defaults to `"main"`):
            The specific revision to use for the code on the Hub, if the code lives in a different repository than the
            rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for
            storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.

    <Tip>

    Passing `token=True` is required when you want to use a private model.

    </Tip>

    Returns:
        `typing.Type`: The class, dynamically imported from the module.

    Examples:

    ```python
    # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this
    # module.
    cls = get_class_from_dynamic_module("modeling.MyBertModel", "sgugger/my-bert-model")

    # Download module `modeling.py` from a given repo and cache then extract the class `MyBertModel` from this
    # module.
    cls = get_class_from_dynamic_module("sgugger/my-bert-model--modeling.MyBertModel", "sgugger/another-bert-model")
    ```"""
    # Honor the deprecated `use_auth_token` alias for `token` until it is removed in v5.
    use_auth_token = kwargs.pop("use_auth_token", None)
    if use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
            FutureWarning,
        )
        if token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        token = use_auth_token

    # Catch the name of the repo if it's specified in `class_reference`
    if "--" in class_reference:
        repo_id, class_reference = class_reference.split("--")
    else:
        repo_id = pretrained_model_name_or_path
    module_file, class_name = class_reference.split(".")

    # When the code comes from the model repo itself, reuse the model revision for the code.
    if code_revision is None and pretrained_model_name_or_path == repo_id:
        code_revision = revision
    # And lastly we get the class inside our newly created module
    final_module = get_cached_module_file(
        repo_id,
        module_file + ".py",
        cache_dir=cache_dir,
        force_download=force_download,
        resume_download=resume_download,
        proxies=proxies,
        token=token,
        revision=code_revision,
        local_files_only=local_files_only,
        repo_type=repo_type,
    )
    return get_class_in_module(class_name, final_module, force_reload=force_download)
554
+
555
+
556
def custom_object_save(obj: Any, folder: Union[str, os.PathLike], config: Optional[Dict] = None) -> List[str]:
    """
    Save the modeling files corresponding to a custom model/configuration/tokenizer etc. in a given folder. Optionally
    adds the proper fields in a config.

    Args:
        obj (`Any`): The object for which to save the module files.
        folder (`str` or `os.PathLike`): The folder where to save.
        config (`PretrainedConfig` or dictionary, `optional`):
            A config in which to register the auto_map corresponding to this custom object.

    Returns:
        `List[str]`: The list of files saved, or `None` (after logging a warning) if `obj` is defined in `__main__`
        and its code therefore cannot be saved.
    """
    if obj.__module__ == "__main__":
        logger.warning(
            f"We can't save the code defining {obj} in {folder} as it's been defined in __main__. You should put "
            "this code in a separate module so we can include it in the saved folder and make it easier to share via "
            "the Hub."
        )
        return

    def _set_auto_map_in_config(_config):
        # Register "<module>.<ClassName>" (or a (slow, fast) pair for tokenizers) under obj._auto_class
        # in the given config, which may be a dict or a PretrainedConfig-like object.
        module_name = obj.__class__.__module__
        last_module = module_name.split(".")[-1]
        full_name = f"{last_module}.{obj.__class__.__name__}"
        # Special handling for tokenizers
        if "Tokenizer" in full_name:
            slow_tokenizer_class = None
            fast_tokenizer_class = None
            if obj.__class__.__name__.endswith("Fast"):
                # Fast tokenizer: we have the fast tokenizer class and we may have the slow one as an attribute.
                fast_tokenizer_class = f"{last_module}.{obj.__class__.__name__}"
                if getattr(obj, "slow_tokenizer_class", None) is not None:
                    slow_tokenizer = getattr(obj, "slow_tokenizer_class")
                    slow_tok_module_name = slow_tokenizer.__module__
                    last_slow_tok_module = slow_tok_module_name.split(".")[-1]
                    slow_tokenizer_class = f"{last_slow_tok_module}.{slow_tokenizer.__name__}"
            else:
                # Slow tokenizer: no way to have the fast class
                slow_tokenizer_class = f"{last_module}.{obj.__class__.__name__}"

            full_name = (slow_tokenizer_class, fast_tokenizer_class)

        if isinstance(_config, dict):
            auto_map = _config.get("auto_map", {})
            auto_map[obj._auto_class] = full_name
            _config["auto_map"] = auto_map
        elif getattr(_config, "auto_map", None) is not None:
            _config.auto_map[obj._auto_class] = full_name
        else:
            _config.auto_map = {obj._auto_class: full_name}

    # Add object class to the config auto_map
    if isinstance(config, (list, tuple)):
        for cfg in config:
            _set_auto_map_in_config(cfg)
    elif config is not None:
        _set_auto_map_in_config(config)

    result = []
    # Copy module file to the output folder.
    object_file = sys.modules[obj.__module__].__file__
    dest_file = Path(folder) / (Path(object_file).name)
    shutil.copy(object_file, dest_file)
    result.append(dest_file)

    # Gather all relative imports recursively and make sure they are copied as well.
    for needed_file in get_relative_import_files(object_file):
        dest_file = Path(folder) / (Path(needed_file).name)
        shutil.copy(needed_file, dest_file)
        result.append(dest_file)

    return result
630
+
631
+
632
def _raise_timeout_error(signum, frame):
    # Matches the `signal` handler signature (signum, frame); also called directly with (None, None)
    # when the interactive prompt is disabled (timeout of 0).
    raise ValueError(
        "Loading this model requires you to execute custom code contained in the model repository on your local "
        "machine. Please set the option `trust_remote_code=True` to permit loading of this model."
    )


# Number of seconds to wait for an answer to the interactive trust_remote_code prompt before timing out.
TIME_OUT_REMOTE_CODE = 15
640
+
641
+
642
def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code):
    """
    Resolve the effective `trust_remote_code` value, prompting the user interactively (with a timeout) when it was
    left unset and only remote code is available.

    Args:
        trust_remote_code (`bool` or `None`): The user-provided setting; `None` means "undecided".
        model_name (`str`): Repo name used in prompt and error messages.
        has_local_code (`bool`): Whether a local (in-library) implementation exists.
        has_remote_code (`bool`): Whether the repo ships custom code.

    Returns:
        `bool`: The resolved `trust_remote_code` value.

    Raises:
        ValueError: If remote code is required but not trusted, or the prompt times out / is unsupported.
    """
    if trust_remote_code is None:
        if has_local_code:
            # Prefer the safe local implementation when one exists.
            trust_remote_code = False
        elif has_remote_code and TIME_OUT_REMOTE_CODE > 0:
            prev_sig_handler = None
            try:
                # Arm SIGALRM so an unanswered prompt raises after TIME_OUT_REMOTE_CODE seconds.
                prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error)
                signal.alarm(TIME_OUT_REMOTE_CODE)
                while trust_remote_code is None:
                    answer = input(
                        f"The repository for {model_name} contains custom code which must be executed to correctly "
                        f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
                        f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
                        f"Do you wish to run the custom code? [y/N] "
                    )
                    if answer.lower() in ["yes", "y", "1"]:
                        trust_remote_code = True
                    elif answer.lower() in ["no", "n", "0", ""]:
                        trust_remote_code = False
                signal.alarm(0)
            except Exception:
                # OS which does not support signal.SIGALRM
                raise ValueError(
                    f"The repository for {model_name} contains custom code which must be executed to correctly "
                    f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
                    f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
                )
            finally:
                # Restore whatever handler was installed before and disarm the alarm.
                if prev_sig_handler is not None:
                    signal.signal(signal.SIGALRM, prev_sig_handler)
                    signal.alarm(0)
        elif has_remote_code:
            # For the CI which puts the timeout at 0
            _raise_timeout_error(None, None)

    if has_remote_code and not has_local_code and not trust_remote_code:
        raise ValueError(
            f"Loading {model_name} requires you to execute the configuration file in that"
            " repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
            " set the option `trust_remote_code=True` to remove this error."
        )

    return trust_remote_code
.venv/lib/python3.11/site-packages/transformers/feature_extraction_sequence_utils.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Sequence feature extraction class for common feature extractors to preprocess sequences.
17
+ """
18
+
19
+ from typing import Dict, List, Optional, Union
20
+
21
+ import numpy as np
22
+
23
+ from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin
24
+ from .utils import PaddingStrategy, TensorType, is_tf_tensor, is_torch_tensor, logging, to_numpy
25
+
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+
30
+ class SequenceFeatureExtractor(FeatureExtractionMixin):
31
+ """
32
+ This is a general feature extraction class for speech recognition.
33
+
34
+ Args:
35
+ feature_size (`int`):
36
+ The feature dimension of the extracted features.
37
+ sampling_rate (`int`):
38
+ The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
39
+ padding_value (`float`):
40
+ The value that is used to fill the padding values / vectors.
41
+ """
42
+
43
    def __init__(self, feature_size: int, sampling_rate: int, padding_value: float, **kwargs):
        # Feature dimension of each extracted feature vector.
        self.feature_size = feature_size
        # Sampling rate (Hz) at which audio files should be digitalized.
        self.sampling_rate = sampling_rate
        # Value used to fill padded positions / vectors.
        self.padding_value = padding_value

        # Optional behavior flags, popped before forwarding the remaining kwargs to the mixin:
        # side on which padding is applied (defaults to "right"), and whether `pad` returns an
        # attention mask by default.
        self.padding_side = kwargs.pop("padding_side", "right")
        self.return_attention_mask = kwargs.pop("return_attention_mask", True)

        super().__init__(**kwargs)
52
+
53
+ def pad(
54
+ self,
55
+ processed_features: Union[
56
+ BatchFeature,
57
+ List[BatchFeature],
58
+ Dict[str, BatchFeature],
59
+ Dict[str, List[BatchFeature]],
60
+ List[Dict[str, BatchFeature]],
61
+ ],
62
+ padding: Union[bool, str, PaddingStrategy] = True,
63
+ max_length: Optional[int] = None,
64
+ truncation: bool = False,
65
+ pad_to_multiple_of: Optional[int] = None,
66
+ return_attention_mask: Optional[bool] = None,
67
+ return_tensors: Optional[Union[str, TensorType]] = None,
68
+ ) -> BatchFeature:
69
+ """
70
+ Pad input values / input vectors or a batch of input values / input vectors up to predefined length or to the
71
+ max sequence length in the batch.
72
+
73
+ Padding side (left/right) padding values are defined at the feature extractor level (with `self.padding_side`,
74
+ `self.padding_value`)
75
+
76
+ <Tip>
77
+
78
+ If the `processed_features` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the
79
+ result will use the same type unless you provide a different tensor type with `return_tensors`. In the case of
80
+ PyTorch tensors, you will lose the specific device of your tensors however.
81
+
82
+ </Tip>
83
+
84
+ Args:
85
+ processed_features ([`BatchFeature`], list of [`BatchFeature`], `Dict[str, List[float]]`, `Dict[str, List[List[float]]` or `List[Dict[str, List[float]]]`):
86
+ Processed inputs. Can represent one input ([`BatchFeature`] or `Dict[str, List[float]]`) or a batch of
87
+ input values / vectors (list of [`BatchFeature`], *Dict[str, List[List[float]]]* or *List[Dict[str,
88
+ List[float]]]*) so you can use this method during preprocessing as well as in a PyTorch Dataloader
89
+ collate function.
90
+
91
+ Instead of `List[float]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors),
92
+ see the note above for the return type.
93
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
94
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
95
+ index) among:
96
+
97
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
98
+ sequence if provided).
99
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
100
+ acceptable input length for the model if that argument is not provided.
101
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
102
+ lengths).
103
+ max_length (`int`, *optional*):
104
+ Maximum length of the returned list and optionally padding length (see above).
105
+ truncation (`bool`):
106
+ Activates truncation to cut input sequences longer than `max_length` to `max_length`.
107
+ pad_to_multiple_of (`int`, *optional*):
108
+ If set will pad the sequence to a multiple of the provided value.
109
+
110
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
111
+ `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
112
+ return_attention_mask (`bool`, *optional*):
113
+ Whether to return the attention mask. If left to the default, will return the attention mask according
114
+ to the specific feature_extractor's default.
115
+
116
+ [What are attention masks?](../glossary#attention-mask)
117
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
118
+ If set, will return tensors instead of list of python integers. Acceptable values are:
119
+
120
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
121
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
122
+ - `'np'`: Return Numpy `np.ndarray` objects.
123
+ """
124
+ # If we have a list of dicts, let's convert it in a dict of lists
125
+ # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
126
+ if isinstance(processed_features, (list, tuple)) and isinstance(processed_features[0], (dict, BatchFeature)):
127
+ processed_features = {
128
+ key: [example[key] for example in processed_features] for key in processed_features[0].keys()
129
+ }
130
+
131
+ # The model's main input name, usually `input_values`, has be passed for padding
132
+ if self.model_input_names[0] not in processed_features:
133
+ raise ValueError(
134
+ "You should supply an instance of `transformers.BatchFeature` or list of `transformers.BatchFeature`"
135
+ f" to this method that includes {self.model_input_names[0]}, but you provided"
136
+ f" {list(processed_features.keys())}"
137
+ )
138
+
139
+ required_input = processed_features[self.model_input_names[0]]
140
+ return_attention_mask = (
141
+ return_attention_mask if return_attention_mask is not None else self.return_attention_mask
142
+ )
143
+
144
+ if len(required_input) == 0:
145
+ if return_attention_mask:
146
+ processed_features["attention_mask"] = []
147
+ return processed_features
148
+
149
+ # If we have PyTorch/TF tensors or lists as inputs, we cast them as Numpy arrays
150
+ # and rebuild them afterwards if no return_tensors is specified
151
+ # Note that we lose the specific device the tensor may be on for PyTorch
152
+
153
+ first_element = required_input[0]
154
+ if isinstance(first_element, (list, tuple)):
155
+ # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element.
156
+ index = 0
157
+ while len(required_input[index]) == 0:
158
+ index += 1
159
+ if index < len(required_input):
160
+ first_element = required_input[index][0]
161
+
162
+ if return_tensors is None:
163
+ if is_tf_tensor(first_element):
164
+ return_tensors = "tf"
165
+ elif is_torch_tensor(first_element):
166
+ return_tensors = "pt"
167
+ elif isinstance(first_element, (int, float, list, tuple, np.ndarray)):
168
+ return_tensors = "np"
169
+ else:
170
+ raise ValueError(
171
+ f"type of {first_element} unknown: {type(first_element)}. "
172
+ "Should be one of a python, numpy, pytorch or tensorflow object."
173
+ )
174
+
175
+ for key, value in processed_features.items():
176
+ if isinstance(value[0], (int, float)):
177
+ processed_features[key] = to_numpy(value)
178
+ else:
179
+ processed_features[key] = [to_numpy(v) for v in value]
180
+
181
+ # Convert padding_strategy in PaddingStrategy
182
+ padding_strategy = self._get_padding_strategies(padding=padding, max_length=max_length)
183
+
184
+ required_input = processed_features[self.model_input_names[0]]
185
+
186
+ batch_size = len(required_input)
187
+ if not all(len(v) == batch_size for v in processed_features.values()):
188
+ raise ValueError("Some items in the output dictionary have a different batch size than others.")
189
+
190
+ truncated_inputs = []
191
+ for i in range(batch_size):
192
+ inputs = {k: v[i] for k, v in processed_features.items()}
193
+ # truncation
194
+ inputs_slice = self._truncate(
195
+ inputs,
196
+ max_length=max_length,
197
+ pad_to_multiple_of=pad_to_multiple_of,
198
+ truncation=truncation,
199
+ )
200
+ truncated_inputs.append(inputs_slice)
201
+
202
+ if padding_strategy == PaddingStrategy.LONGEST:
203
+ # make sure that `max_length` cannot be longer than the longest truncated length
204
+ max_length = max(len(input_slice[self.model_input_names[0]]) for input_slice in truncated_inputs)
205
+ padding_strategy = PaddingStrategy.MAX_LENGTH
206
+
207
+ batch_outputs = {}
208
+ for i in range(batch_size):
209
+ # padding
210
+ outputs = self._pad(
211
+ truncated_inputs[i],
212
+ max_length=max_length,
213
+ padding_strategy=padding_strategy,
214
+ pad_to_multiple_of=pad_to_multiple_of,
215
+ return_attention_mask=return_attention_mask,
216
+ )
217
+
218
+ for key, value in outputs.items():
219
+ if key not in batch_outputs:
220
+ batch_outputs[key] = []
221
+ if value.dtype is np.dtype(np.float64):
222
+ value = value.astype(np.float32)
223
+ batch_outputs[key].append(value)
224
+
225
+ return BatchFeature(batch_outputs, tensor_type=return_tensors)
226
+
227
    def _pad(
        self,
        processed_features: Union[Dict[str, np.ndarray], BatchFeature],
        max_length: Optional[int] = None,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
    ) -> dict:
        """
        Pad inputs (on left/right and up to predefined length or max length in the batch)

        Args:
            processed_features (`Union[Dict[str, np.ndarray], BatchFeature]`):
                Dictionary of input values (`np.ndarray[float]`) / input vectors (`List[np.ndarray[float]]`) or batch
                of inputs values (`List[np.ndarray[int]]`) / input vectors (`List[np.ndarray[int]]`)
            max_length (`int`, *optional*):
                Maximum length of the returned list and optionally padding length (see below)
            padding_strategy (`PaddingStrategy`, *optional*, default to `PaddingStrategy.DO_NOT_PAD`):
                PaddingStrategy to use for padding.

                - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
                - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
                - PaddingStrategy.DO_NOT_PAD: Do not pad
                The feature_extractor padding sides are defined in self.padding_side:

                - 'left': pads on the left of the sequences
                - 'right': pads on the right of the sequences
            pad_to_multiple_of (`int`, *optional*):
                Integer if set will pad the sequence to a multiple of the provided value. This is especially useful to
                enable the use of Tensor Core on NVIDIA hardware with compute capability `>= 7.5` (Volta), or on TPUs
                which benefit from having sequence lengths be a multiple of 128.
            return_attention_mask (`bool`, *optional*):
                Set to False to avoid returning attention mask (default: set to model specifics)
        """
        # The model's main input (e.g. `input_values`) determines the current sequence length.
        required_input = processed_features[self.model_input_names[0]]

        # LONGEST on a single example degenerates to "pad to its own length", i.e. no padding.
        if padding_strategy == PaddingStrategy.LONGEST:
            max_length = len(required_input)

        # Round `max_length` up to the next multiple of `pad_to_multiple_of` when needed.
        if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of

        needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) < max_length

        # Create the attention mask BEFORE padding so it marks only the real (unpadded) frames;
        # the padded region is then appended/prepended with zeros via np.pad below.
        if return_attention_mask and "attention_mask" not in processed_features:
            processed_features["attention_mask"] = np.ones(len(required_input), dtype=np.int32)

        if needs_to_be_padded:
            difference = max_length - len(required_input)
            if self.padding_side == "right":
                if return_attention_mask:
                    # np.pad defaults to constant 0, which is the correct mask value for padding.
                    processed_features["attention_mask"] = np.pad(
                        processed_features["attention_mask"], (0, difference)
                    )
                # 2-D inputs (num_frames, feature_size) are padded along the time axis only.
                padding_shape = ((0, difference), (0, 0)) if self.feature_size > 1 else (0, difference)
                processed_features[self.model_input_names[0]] = np.pad(
                    required_input, padding_shape, "constant", constant_values=self.padding_value
                )
            elif self.padding_side == "left":
                if return_attention_mask:
                    processed_features["attention_mask"] = np.pad(
                        processed_features["attention_mask"], (difference, 0)
                    )
                padding_shape = ((difference, 0), (0, 0)) if self.feature_size > 1 else (difference, 0)
                processed_features[self.model_input_names[0]] = np.pad(
                    required_input, padding_shape, "constant", constant_values=self.padding_value
                )
            else:
                raise ValueError("Invalid padding strategy:" + str(self.padding_side))

        return processed_features
298
+
299
+ def _truncate(
300
+ self,
301
+ processed_features: Union[Dict[str, np.ndarray], BatchFeature],
302
+ max_length: Optional[int] = None,
303
+ pad_to_multiple_of: Optional[int] = None,
304
+ truncation: Optional[bool] = None,
305
+ ):
306
+ """
307
+ Truncate inputs to predefined length or max length in the batch
308
+
309
+ Args:
310
+ processed_features(`Union[Dict[str, np.ndarray], BatchFeature]`):
311
+ Dictionary of input values (`np.ndarray[float]`) / input vectors (`List[np.ndarray[float]]`) or batch
312
+ of inputs values (`List[np.ndarray[int]]`) / input vectors (`List[np.ndarray[int]]`)
313
+ max_length (`int`, *optional*):
314
+ maximum length of the returned list and optionally padding length (see below)
315
+ pad_to_multiple_of (`int`, *optional*) :
316
+ Integer if set will pad the sequence to a multiple of the provided value. This is especially useful to
317
+ enable the use of Tensor Core on NVIDIA hardware with compute capability `>= 7.5` (Volta), or on TPUs
318
+ which benefit from having sequence lengths be a multiple of 128.
319
+ truncation (`bool`, *optional*):
320
+ Activates truncation to cut input sequences longer than `max_length` to `max_length`.
321
+ """
322
+ if not truncation:
323
+ return processed_features
324
+ elif truncation and max_length is None:
325
+ raise ValueError("When setting ``truncation=True``, make sure that ``max_length`` is defined.")
326
+
327
+ required_input = processed_features[self.model_input_names[0]]
328
+
329
+ # find `max_length` that fits `pad_to_multiple_of`
330
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
331
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
332
+
333
+ needs_to_be_truncated = len(required_input) > max_length
334
+
335
+ if needs_to_be_truncated:
336
+ processed_features[self.model_input_names[0]] = processed_features[self.model_input_names[0]][:max_length]
337
+ if "attention_mask" in processed_features:
338
+ processed_features["attention_mask"] = processed_features["attention_mask"][:max_length]
339
+
340
+ return processed_features
341
+
342
+ def _get_padding_strategies(self, padding=False, max_length=None):
343
+ """
344
+ Find the correct padding strategy
345
+ """
346
+
347
+ # Get padding strategy
348
+ if padding is not False:
349
+ if padding is True:
350
+ padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch
351
+ elif not isinstance(padding, PaddingStrategy):
352
+ padding_strategy = PaddingStrategy(padding)
353
+ elif isinstance(padding, PaddingStrategy):
354
+ padding_strategy = padding
355
+ else:
356
+ padding_strategy = PaddingStrategy.DO_NOT_PAD
357
+
358
+ # Set max length if needed
359
+ if max_length is None:
360
+ if padding_strategy == PaddingStrategy.MAX_LENGTH:
361
+ raise ValueError(
362
+ f"When setting ``padding={PaddingStrategy.MAX_LENGTH}``, make sure that max_length is defined"
363
+ )
364
+
365
+ # Test if we have a padding value
366
+ if padding_strategy != PaddingStrategy.DO_NOT_PAD and (self.padding_value is None):
367
+ raise ValueError(
368
+ "Asking to pad but the feature_extractor does not have a padding value. Please select a value to use"
369
+ " as `padding_value`. For example: `feature_extractor.padding_value = 0.0`."
370
+ )
371
+
372
+ return padding_strategy
.venv/lib/python3.11/site-packages/transformers/hf_argparser.py ADDED
@@ -0,0 +1,437 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import dataclasses
16
+ import json
17
+ import os
18
+ import sys
19
+ import types
20
+ from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, ArgumentTypeError
21
+ from copy import copy
22
+ from enum import Enum
23
+ from inspect import isclass
24
+ from pathlib import Path
25
+ from typing import Any, Callable, Dict, Iterable, List, Literal, NewType, Optional, Tuple, Union, get_type_hints
26
+
27
+ import yaml
28
+
29
+
30
+ DataClass = NewType("DataClass", Any)
31
+ DataClassType = NewType("DataClassType", Any)
32
+
33
+
34
+ # From https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
35
def string_to_bool(v):
    """
    Interpret a command-line value as a boolean.

    Booleans pass through unchanged; strings are matched case-insensitively against the usual
    truthy ("yes", "true", "t", "y", "1") and falsy ("no", "false", "f", "n", "0") spellings.
    Anything else raises `ArgumentTypeError` so argparse reports a clean usage error.
    """
    if isinstance(v, bool):
        return v
    normalized = v.lower()
    if normalized in ("yes", "true", "t", "y", "1"):
        return True
    if normalized in ("no", "false", "f", "n", "0"):
        return False
    raise ArgumentTypeError(
        f"Truthy value expected: got {v} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)."
    )
46
+
47
+
48
def make_choice_type_function(choices: list) -> Callable[[str], Any]:
    """
    Creates a mapping function from each choices string representation to the actual value. Used to support multiple
    value types for a single argument.

    Args:
        choices (list): List of choices.

    Returns:
        Callable[[str], Any]: Mapping function from string representation to actual value for each choice.
    """
    string_repr_to_value = {str(choice): choice for choice in choices}

    def convert(arg: str) -> Any:
        # Unknown strings are returned untouched so argparse itself can flag an invalid choice.
        return string_repr_to_value.get(arg, arg)

    return convert
61
+
62
+
63
def HfArg(
    *,
    aliases: Union[str, List[str]] = None,
    help: str = None,
    default: Any = dataclasses.MISSING,
    default_factory: Callable[[], Any] = dataclasses.MISSING,
    metadata: dict = None,
    **kwargs,
) -> dataclasses.Field:
    """Argument helper enabling a concise syntax to create dataclass fields for parsing with `HfArgumentParser`.

    Example comparing the use of `HfArg` and `dataclasses.field`:
    ```
    @dataclass
    class Args:
        regular_arg: str = dataclasses.field(default="Huggingface", metadata={"aliases": ["--example", "-e"], "help": "This syntax could be better!"})
        hf_arg: str = HfArg(default="Huggingface", aliases=["--example", "-e"], help="What a nice syntax!")
    ```

    Args:
        aliases (Union[str, List[str]], optional):
            Single string or list of strings of aliases to pass on to argparse, e.g. `aliases=["--example", "-e"]`.
            Defaults to None.
        help (str, optional): Help string to pass on to argparse that can be displayed with --help. Defaults to None.
        default (Any, optional):
            Default value for the argument. If neither default nor default_factory is specified, the argument is
            required. Defaults to dataclasses.MISSING.
        default_factory (Callable[[], Any], optional):
            The default_factory is a 0-argument function called to initialize a field's value. It is useful to provide
            default values for mutable types, e.g. lists: `default_factory=list`. Mutually exclusive with `default=`.
            Defaults to dataclasses.MISSING.
        metadata (dict, optional): Further metadata to pass on to `dataclasses.field`. Defaults to None.

    Returns:
        Field: A `dataclasses.Field` with the desired properties.
    """
    # Never use a dict as the parameter default: it would be shared across calls. Build (or reuse)
    # the mapping here instead, then fold the convenience kwargs into it.
    field_metadata = {} if metadata is None else metadata
    if aliases is not None:
        field_metadata["aliases"] = aliases
    if help is not None:
        field_metadata["help"] = help

    return dataclasses.field(metadata=field_metadata, default=default, default_factory=default_factory, **kwargs)
108
+
109
+
110
class HfArgumentParser(ArgumentParser):
    """
    This subclass of `argparse.ArgumentParser` uses type hints on dataclasses to generate arguments.

    The class is designed to play well with the native argparse. In particular, you can add more (non-dataclass backed)
    arguments to the parser after initialization and you'll get the output back after parsing as an additional
    namespace. Optional: To create sub argument groups use the `_argument_group_name` attribute in the dataclass.
    """

    # Dataclass types registered at construction; `parse_*` returns one instance per type, in order.
    dataclass_types: Iterable[DataClassType]

    def __init__(self, dataclass_types: Union[DataClassType, Iterable[DataClassType]], **kwargs):
        """
        Args:
            dataclass_types:
                Dataclass type, or list of dataclass types for which we will "fill" instances with the parsed args.
            kwargs (`Dict[str, Any]`, *optional*):
                Passed to `argparse.ArgumentParser()` in the regular way.
        """
        # To make the default appear when using --help
        if "formatter_class" not in kwargs:
            kwargs["formatter_class"] = ArgumentDefaultsHelpFormatter
        super().__init__(**kwargs)
        # Accept a single dataclass as a convenience; normalize to a list.
        if dataclasses.is_dataclass(dataclass_types):
            dataclass_types = [dataclass_types]
        self.dataclass_types = list(dataclass_types)
        for dtype in self.dataclass_types:
            self._add_dataclass_arguments(dtype)

    @staticmethod
    def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field):
        """Register one dataclass field as an argparse argument on `parser`.

        Translates the field's (already resolved) type hint into argparse `type`, `choices`,
        `nargs`, `default`, and `required` settings.
        """
        # Long-option strings are conventionally separated by hyphens rather
        # than underscores, e.g., "--long-format" rather than "--long_format".
        # Argparse converts hyphens to underscores so that the destination
        # string is a valid attribute name. Hf_argparser should do the same.
        long_options = [f"--{field.name}"]
        if "_" in field.name:
            long_options.append(f"--{field.name.replace('_', '-')}")

        kwargs = field.metadata.copy()
        # field.metadata is not used at all by Data Classes,
        # it is provided as a third-party extension mechanism.
        if isinstance(field.type, str):
            # A string here means `_add_dataclass_arguments` failed to resolve the annotation.
            raise RuntimeError(
                "Unresolved type detected, which should have been done with the help of "
                "`typing.get_type_hints` method by default"
            )

        aliases = kwargs.pop("aliases", [])
        if isinstance(aliases, str):
            aliases = [aliases]

        origin_type = getattr(field.type, "__origin__", field.type)
        # Handle `Union[...]` / PEP 604 `X | Y` annotations (the latter only on Pythons with types.UnionType).
        if origin_type is Union or (hasattr(types, "UnionType") and isinstance(origin_type, types.UnionType)):
            # Only `Optional[X]` or a union containing `str` is representable as a single argparse type.
            if str not in field.type.__args__ and (
                len(field.type.__args__) != 2 or type(None) not in field.type.__args__
            ):
                raise ValueError(
                    "Only `Union[X, NoneType]` (i.e., `Optional[X]`) is allowed for `Union` because"
                    " the argument parser only supports one type per argument."
                    f" Problem encountered in field '{field.name}'."
                )
            if type(None) not in field.type.__args__:
                # filter `str` in Union
                field.type = field.type.__args__[0] if field.type.__args__[1] is str else field.type.__args__[1]
                origin_type = getattr(field.type, "__origin__", field.type)
            elif bool not in field.type.__args__:
                # filter `NoneType` in Union (except for `Union[bool, NoneType]`)
                field.type = (
                    field.type.__args__[0] if isinstance(None, field.type.__args__[1]) else field.type.__args__[1]
                )
                origin_type = getattr(field.type, "__origin__", field.type)

        # A variable to store kwargs for a boolean field, if needed
        # so that we can init a `no_*` complement argument (see below)
        bool_kwargs = {}
        if origin_type is Literal or (isinstance(field.type, type) and issubclass(field.type, Enum)):
            if origin_type is Literal:
                kwargs["choices"] = field.type.__args__
            else:
                kwargs["choices"] = [x.value for x in field.type]

            # Map the string given on the command line back to the actual choice value.
            kwargs["type"] = make_choice_type_function(kwargs["choices"])

            if field.default is not dataclasses.MISSING:
                kwargs["default"] = field.default
            else:
                kwargs["required"] = True
        elif field.type is bool or field.type == Optional[bool]:
            # Copy the correct kwargs to use to instantiate a `no_*` complement argument below.
            # We do not initialize it here because the `no_*` alternative must be instantiated after the real argument
            bool_kwargs = copy(kwargs)

            # Hack because type=bool in argparse does not behave as we want.
            kwargs["type"] = string_to_bool
            if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
                # Default value is False if we have no default when of type bool.
                default = False if field.default is dataclasses.MISSING else field.default
                # This is the value that will get picked if we don't include --{field.name} in any way
                kwargs["default"] = default
                # This tells argparse we accept 0 or 1 value after --{field.name}
                kwargs["nargs"] = "?"
                # This is the value that will get picked if we do --{field.name} (without value)
                kwargs["const"] = True
        elif isclass(origin_type) and issubclass(origin_type, list):
            # `List[X]` becomes a repeatable argument: element type X, one or more values.
            kwargs["type"] = field.type.__args__[0]
            kwargs["nargs"] = "+"
            if field.default_factory is not dataclasses.MISSING:
                kwargs["default"] = field.default_factory()
            elif field.default is dataclasses.MISSING:
                kwargs["required"] = True
        else:
            # Plain scalar type: argparse can call it directly on the string value.
            kwargs["type"] = field.type
            if field.default is not dataclasses.MISSING:
                kwargs["default"] = field.default
            elif field.default_factory is not dataclasses.MISSING:
                kwargs["default"] = field.default_factory()
            else:
                kwargs["required"] = True
        parser.add_argument(*long_options, *aliases, **kwargs)

        # Add a complement `no_*` argument for a boolean field AFTER the initial field has already been added.
        # Order is important for arguments with the same destination!
        # We use a copy of earlier kwargs because the original kwargs have changed a lot before reaching down
        # here and we do not need those changes/additional keys.
        if field.default is True and (field.type is bool or field.type == Optional[bool]):
            bool_kwargs["default"] = False
            parser.add_argument(
                f"--no_{field.name}",
                f"--no-{field.name.replace('_', '-')}",
                action="store_false",
                dest=field.name,
                **bool_kwargs,
            )

    def _add_dataclass_arguments(self, dtype: DataClassType):
        """Resolve `dtype`'s type hints and register every `init=True` field as an argument."""
        # Fields can opt into their own argparse group via the `_argument_group_name` class attribute.
        if hasattr(dtype, "_argument_group_name"):
            parser = self.add_argument_group(dtype._argument_group_name)
        else:
            parser = self

        try:
            # Resolve string annotations (PEP 563) into real types up front.
            type_hints: Dict[str, type] = get_type_hints(dtype)
        except NameError:
            raise RuntimeError(
                f"Type resolution failed for {dtype}. Try declaring the class in global scope or "
                "removing line of `from __future__ import annotations` which opts in Postponed "
                "Evaluation of Annotations (PEP 563)"
            )
        except TypeError as ex:
            # Remove this block when we drop Python 3.9 support
            if sys.version_info[:2] < (3, 10) and "unsupported operand type(s) for |" in str(ex):
                python_version = ".".join(map(str, sys.version_info[:3]))
                raise RuntimeError(
                    f"Type resolution failed for {dtype} on Python {python_version}. Try removing "
                    "line of `from __future__ import annotations` which opts in union types as "
                    "`X | Y` (PEP 604) via Postponed Evaluation of Annotations (PEP 563). To "
                    "support Python versions that lower than 3.10, you need to use "
                    "`typing.Union[X, Y]` instead of `X | Y` and `typing.Optional[X]` instead of "
                    "`X | None`."
                ) from ex
            raise

        for field in dataclasses.fields(dtype):
            # `init=False` fields cannot be set through the constructor, so never expose them as arguments.
            if not field.init:
                continue
            field.type = type_hints[field.name]
            self._parse_dataclass_field(parser, field)

    def parse_args_into_dataclasses(
        self,
        args=None,
        return_remaining_strings=False,
        look_for_args_file=True,
        args_filename=None,
        args_file_flag=None,
    ) -> Tuple[DataClass, ...]:
        """
        Parse command-line args into instances of the specified dataclass types.

        This relies on argparse's `ArgumentParser.parse_known_args`. See the doc at:
        docs.python.org/3.7/library/argparse.html#argparse.ArgumentParser.parse_args

        Args:
            args:
                List of strings to parse. The default is taken from sys.argv. (same as argparse.ArgumentParser)
            return_remaining_strings:
                If true, also return a list of remaining argument strings.
            look_for_args_file:
                If true, will look for a ".args" file with the same base name as the entry point script for this
                process, and will append its potential content to the command line args.
            args_filename:
                If not None, will use this file instead of the ".args" file specified in the previous argument.
            args_file_flag:
                If not None, will look for a file in the command-line args specified with this flag. The flag can be
                specified multiple times and precedence is determined by the order (last one wins).

        Returns:
            Tuple consisting of:

            - the dataclass instances in the same order as they were passed to the initializer
            - if applicable, an additional namespace for more (non-dataclass backed) arguments added to the parser
              after initialization.
            - The potential list of remaining argument strings. (same as argparse.ArgumentParser.parse_known_args)
        """

        if args_file_flag or args_filename or (look_for_args_file and len(sys.argv)):
            args_files = []

            if args_filename:
                args_files.append(Path(args_filename))
            elif look_for_args_file and len(sys.argv):
                # Default: "<script>.args" next to the entry-point script.
                args_files.append(Path(sys.argv[0]).with_suffix(".args"))

            # args files specified via command line flag should overwrite default args files so we add them last
            if args_file_flag:
                # Create special parser just to extract the args_file_flag values
                args_file_parser = ArgumentParser()
                args_file_parser.add_argument(args_file_flag, type=str, action="append")

                # Use only remaining args for further parsing (remove the args_file_flag)
                cfg, args = args_file_parser.parse_known_args(args=args)
                cmd_args_file_paths = vars(cfg).get(args_file_flag.lstrip("-"), None)

                if cmd_args_file_paths:
                    args_files.extend([Path(p) for p in cmd_args_file_paths])

            file_args = []
            for args_file in args_files:
                # Missing args files are silently skipped — they are optional by design.
                if args_file.exists():
                    file_args += args_file.read_text().split()

            # in case of duplicate arguments the last one has precedence
            # args specified via the command line should overwrite args from files, so we add them last
            args = file_args + args if args is not None else file_args + sys.argv[1:]
        namespace, remaining_args = self.parse_known_args(args=args)
        outputs = []
        for dtype in self.dataclass_types:
            keys = {f.name for f in dataclasses.fields(dtype) if f.init}
            inputs = {k: v for k, v in vars(namespace).items() if k in keys}
            # Remove consumed attributes so leftover keys can be detected below.
            for k in keys:
                delattr(namespace, k)
            obj = dtype(**inputs)
            outputs.append(obj)
        if len(namespace.__dict__) > 0:
            # additional namespace for arguments added to the parser after initialization.
            outputs.append(namespace)
        if return_remaining_strings:
            return (*outputs, remaining_args)
        else:
            if remaining_args:
                raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}")

            return (*outputs,)

    def parse_dict(self, args: Dict[str, Any], allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
        """
        Alternative helper method that does not use `argparse` at all, instead uses a dict and populating the dataclass
        types.

        Args:
            args (`dict`):
                dict containing config values
            allow_extra_keys (`bool`, *optional*, defaults to `False`):
                Defaults to False. If False, will raise an exception if the dict contains keys that are not parsed.

        Returns:
            Tuple consisting of:

            - the dataclass instances in the same order as they were passed to the initializer.
        """
        # Track which input keys were never consumed by any dataclass.
        unused_keys = set(args.keys())
        outputs = []
        for dtype in self.dataclass_types:
            keys = {f.name for f in dataclasses.fields(dtype) if f.init}
            inputs = {k: v for k, v in args.items() if k in keys}
            unused_keys.difference_update(inputs.keys())
            obj = dtype(**inputs)
            outputs.append(obj)
        if not allow_extra_keys and unused_keys:
            raise ValueError(f"Some keys are not used by the HfArgumentParser: {sorted(unused_keys)}")
        return tuple(outputs)

    def parse_json_file(
        self, json_file: Union[str, os.PathLike], allow_extra_keys: bool = False
    ) -> Tuple[DataClass, ...]:
        """
        Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the
        dataclass types.

        Args:
            json_file (`str` or `os.PathLike`):
                File name of the json file to parse
            allow_extra_keys (`bool`, *optional*, defaults to `False`):
                Defaults to False. If False, will raise an exception if the json file contains keys that are not
                parsed.

        Returns:
            Tuple consisting of:

            - the dataclass instances in the same order as they were passed to the initializer.
        """
        with open(Path(json_file), encoding="utf-8") as open_json_file:
            data = json.loads(open_json_file.read())
        outputs = self.parse_dict(data, allow_extra_keys=allow_extra_keys)
        return tuple(outputs)

    def parse_yaml_file(
        self, yaml_file: Union[str, os.PathLike], allow_extra_keys: bool = False
    ) -> Tuple[DataClass, ...]:
        """
        Alternative helper method that does not use `argparse` at all, instead loading a yaml file and populating the
        dataclass types.

        Args:
            yaml_file (`str` or `os.PathLike`):
                File name of the yaml file to parse
            allow_extra_keys (`bool`, *optional*, defaults to `False`):
                Defaults to False. If False, will raise an exception if the yaml file contains keys that are not
                parsed.

        Returns:
            Tuple consisting of:

            - the dataclass instances in the same order as they were passed to the initializer.
        """
        # `safe_load` is used deliberately: config files are data, never arbitrary YAML-tagged objects.
        outputs = self.parse_dict(yaml.safe_load(Path(yaml_file).read_text()), allow_extra_keys=allow_extra_keys)
        return tuple(outputs)
.venv/lib/python3.11/site-packages/transformers/hyperparameter_search.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023-present the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from .integrations import (
17
+ is_optuna_available,
18
+ is_ray_tune_available,
19
+ is_sigopt_available,
20
+ is_wandb_available,
21
+ run_hp_search_optuna,
22
+ run_hp_search_ray,
23
+ run_hp_search_sigopt,
24
+ run_hp_search_wandb,
25
+ )
26
+ from .trainer_utils import (
27
+ HPSearchBackend,
28
+ default_hp_space_optuna,
29
+ default_hp_space_ray,
30
+ default_hp_space_sigopt,
31
+ default_hp_space_wandb,
32
+ )
33
+ from .utils import logging
34
+
35
+
36
+ logger = logging.get_logger(__name__)
37
+
38
+
39
+ class HyperParamSearchBackendBase:
40
+ name: str
41
+ pip_package: str = None
42
+
43
+ @staticmethod
44
+ def is_available():
45
+ raise NotImplementedError
46
+
47
+ def run(self, trainer, n_trials: int, direction: str, **kwargs):
48
+ raise NotImplementedError
49
+
50
+ def default_hp_space(self, trial):
51
+ raise NotImplementedError
52
+
53
+ def ensure_available(self):
54
+ if not self.is_available():
55
+ raise RuntimeError(
56
+ f"You picked the {self.name} backend, but it is not installed. Run {self.pip_install()}."
57
+ )
58
+
59
+ @classmethod
60
+ def pip_install(cls):
61
+ return f"`pip install {cls.pip_package or cls.name}`"
62
+
63
+
64
+ class OptunaBackend(HyperParamSearchBackendBase):
65
+ name = "optuna"
66
+
67
+ @staticmethod
68
+ def is_available():
69
+ return is_optuna_available()
70
+
71
+ def run(self, trainer, n_trials: int, direction: str, **kwargs):
72
+ return run_hp_search_optuna(trainer, n_trials, direction, **kwargs)
73
+
74
+ def default_hp_space(self, trial):
75
+ return default_hp_space_optuna(trial)
76
+
77
+
78
+ class RayTuneBackend(HyperParamSearchBackendBase):
79
+ name = "ray"
80
+ pip_package = "'ray[tune]'"
81
+
82
+ @staticmethod
83
+ def is_available():
84
+ return is_ray_tune_available()
85
+
86
+ def run(self, trainer, n_trials: int, direction: str, **kwargs):
87
+ return run_hp_search_ray(trainer, n_trials, direction, **kwargs)
88
+
89
+ def default_hp_space(self, trial):
90
+ return default_hp_space_ray(trial)
91
+
92
+
93
+ class SigOptBackend(HyperParamSearchBackendBase):
94
+ name = "sigopt"
95
+
96
+ @staticmethod
97
+ def is_available():
98
+ return is_sigopt_available()
99
+
100
+ def run(self, trainer, n_trials: int, direction: str, **kwargs):
101
+ return run_hp_search_sigopt(trainer, n_trials, direction, **kwargs)
102
+
103
+ def default_hp_space(self, trial):
104
+ return default_hp_space_sigopt(trial)
105
+
106
+
107
+ class WandbBackend(HyperParamSearchBackendBase):
108
+ name = "wandb"
109
+
110
+ @staticmethod
111
+ def is_available():
112
+ return is_wandb_available()
113
+
114
+ def run(self, trainer, n_trials: int, direction: str, **kwargs):
115
+ return run_hp_search_wandb(trainer, n_trials, direction, **kwargs)
116
+
117
+ def default_hp_space(self, trial):
118
+ return default_hp_space_wandb(trial)
119
+
120
+
121
+ ALL_HYPERPARAMETER_SEARCH_BACKENDS = {
122
+ HPSearchBackend(backend.name): backend for backend in [OptunaBackend, RayTuneBackend, SigOptBackend, WandbBackend]
123
+ }
124
+
125
+
126
+ def default_hp_search_backend() -> str:
127
+ available_backends = [backend for backend in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values() if backend.is_available()]
128
+ if len(available_backends) > 0:
129
+ name = available_backends[0].name
130
+ if len(available_backends) > 1:
131
+ logger.info(
132
+ f"{len(available_backends)} hyperparameter search backends available. Using {name} as the default."
133
+ )
134
+ return name
135
+ raise RuntimeError(
136
+ "No hyperparameter search backend available.\n"
137
+ + "\n".join(
138
+ f" - To install {backend.name} run {backend.pip_install()}"
139
+ for backend in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values()
140
+ )
141
+ )
.venv/lib/python3.11/site-packages/transformers/image_processing_base.py ADDED
@@ -0,0 +1,559 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2020 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import copy
18
+ import json
19
+ import os
20
+ import warnings
21
+ from io import BytesIO
22
+ from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
23
+
24
+ import numpy as np
25
+ import requests
26
+
27
+ from .dynamic_module_utils import custom_object_save
28
+ from .feature_extraction_utils import BatchFeature as BaseBatchFeature
29
+ from .utils import (
30
+ IMAGE_PROCESSOR_NAME,
31
+ PushToHubMixin,
32
+ add_model_info_to_auto_map,
33
+ add_model_info_to_custom_pipelines,
34
+ cached_file,
35
+ copy_func,
36
+ download_url,
37
+ is_offline_mode,
38
+ is_remote_url,
39
+ is_vision_available,
40
+ logging,
41
+ )
42
+
43
+
44
+ if is_vision_available():
45
+ from PIL import Image
46
+
47
+
48
+ ImageProcessorType = TypeVar("ImageProcessorType", bound="ImageProcessingMixin")
49
+
50
+
51
+ logger = logging.get_logger(__name__)
52
+
53
+
54
+ # TODO: Move BatchFeature to be imported by both image_processing_utils and image_processing_utils
55
+ # We override the class string here, but logic is the same.
56
+ class BatchFeature(BaseBatchFeature):
57
+ r"""
58
+ Holds the output of the image processor specific `__call__` methods.
59
+
60
+ This class is derived from a python dictionary and can be used as a dictionary.
61
+
62
+ Args:
63
+ data (`dict`):
64
+ Dictionary of lists/arrays/tensors returned by the __call__ method ('pixel_values', etc.).
65
+ tensor_type (`Union[None, str, TensorType]`, *optional*):
66
+ You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at
67
+ initialization.
68
+ """
69
+
70
+
71
+ # TODO: (Amy) - factor out the common parts of this and the feature extractor
72
+ class ImageProcessingMixin(PushToHubMixin):
73
+ """
74
+ This is an image processor mixin used to provide saving/loading functionality for sequential and image feature
75
+ extractors.
76
+ """
77
+
78
+ _auto_class = None
79
+
80
+ def __init__(self, **kwargs):
81
+ """Set elements of `kwargs` as attributes."""
82
+ # This key was saved while we still used `XXXFeatureExtractor` for image processing. Now we use
83
+ # `XXXImageProcessor`, this attribute and its value are misleading.
84
+ kwargs.pop("feature_extractor_type", None)
85
+ # Pop "processor_class" as it should be saved as private attribute
86
+ self._processor_class = kwargs.pop("processor_class", None)
87
+ # Additional attributes without default values
88
+ for key, value in kwargs.items():
89
+ try:
90
+ setattr(self, key, value)
91
+ except AttributeError as err:
92
+ logger.error(f"Can't set {key} with value {value} for {self}")
93
+ raise err
94
+
95
+ def _set_processor_class(self, processor_class: str):
96
+ """Sets processor class as an attribute."""
97
+ self._processor_class = processor_class
98
+
99
+ @classmethod
100
+ def from_pretrained(
101
+ cls: Type[ImageProcessorType],
102
+ pretrained_model_name_or_path: Union[str, os.PathLike],
103
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
104
+ force_download: bool = False,
105
+ local_files_only: bool = False,
106
+ token: Optional[Union[str, bool]] = None,
107
+ revision: str = "main",
108
+ **kwargs,
109
+ ) -> ImageProcessorType:
110
+ r"""
111
+ Instantiate a type of [`~image_processing_utils.ImageProcessingMixin`] from an image processor.
112
+
113
+ Args:
114
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
115
+ This can be either:
116
+
117
+ - a string, the *model id* of a pretrained image_processor hosted inside a model repo on
118
+ huggingface.co.
119
+ - a path to a *directory* containing a image processor file saved using the
120
+ [`~image_processing_utils.ImageProcessingMixin.save_pretrained`] method, e.g.,
121
+ `./my_model_directory/`.
122
+ - a path or url to a saved image processor JSON *file*, e.g.,
123
+ `./my_model_directory/preprocessor_config.json`.
124
+ cache_dir (`str` or `os.PathLike`, *optional*):
125
+ Path to a directory in which a downloaded pretrained model image processor should be cached if the
126
+ standard cache should not be used.
127
+ force_download (`bool`, *optional*, defaults to `False`):
128
+ Whether or not to force to (re-)download the image processor files and override the cached versions if
129
+ they exist.
130
+ resume_download:
131
+ Deprecated and ignored. All downloads are now resumed by default when possible.
132
+ Will be removed in v5 of Transformers.
133
+ proxies (`Dict[str, str]`, *optional*):
134
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
135
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
136
+ token (`str` or `bool`, *optional*):
137
+ The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
138
+ the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
139
+ revision (`str`, *optional*, defaults to `"main"`):
140
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
141
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
142
+ identifier allowed by git.
143
+
144
+
145
+ <Tip>
146
+
147
+ To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
148
+
149
+ </Tip>
150
+
151
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
152
+ If `False`, then this function returns just the final image processor object. If `True`, then this
153
+ functions returns a `Tuple(image_processor, unused_kwargs)` where *unused_kwargs* is a dictionary
154
+ consisting of the key/value pairs whose keys are not image processor attributes: i.e., the part of
155
+ `kwargs` which has not been used to update `image_processor` and is otherwise ignored.
156
+ subfolder (`str`, *optional*, defaults to `""`):
157
+ In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
158
+ specify the folder name here.
159
+ kwargs (`Dict[str, Any]`, *optional*):
160
+ The values in kwargs of any keys which are image processor attributes will be used to override the
161
+ loaded values. Behavior concerning key/value pairs whose keys are *not* image processor attributes is
162
+ controlled by the `return_unused_kwargs` keyword parameter.
163
+
164
+ Returns:
165
+ A image processor of type [`~image_processing_utils.ImageProcessingMixin`].
166
+
167
+ Examples:
168
+
169
+ ```python
170
+ # We can't instantiate directly the base class *ImageProcessingMixin* so let's show the examples on a
171
+ # derived class: *CLIPImageProcessor*
172
+ image_processor = CLIPImageProcessor.from_pretrained(
173
+ "openai/clip-vit-base-patch32"
174
+ ) # Download image_processing_config from huggingface.co and cache.
175
+ image_processor = CLIPImageProcessor.from_pretrained(
176
+ "./test/saved_model/"
177
+ ) # E.g. image processor (or model) was saved using *save_pretrained('./test/saved_model/')*
178
+ image_processor = CLIPImageProcessor.from_pretrained("./test/saved_model/preprocessor_config.json")
179
+ image_processor = CLIPImageProcessor.from_pretrained(
180
+ "openai/clip-vit-base-patch32", do_normalize=False, foo=False
181
+ )
182
+ assert image_processor.do_normalize is False
183
+ image_processor, unused_kwargs = CLIPImageProcessor.from_pretrained(
184
+ "openai/clip-vit-base-patch32", do_normalize=False, foo=False, return_unused_kwargs=True
185
+ )
186
+ assert image_processor.do_normalize is False
187
+ assert unused_kwargs == {"foo": False}
188
+ ```"""
189
+ kwargs["cache_dir"] = cache_dir
190
+ kwargs["force_download"] = force_download
191
+ kwargs["local_files_only"] = local_files_only
192
+ kwargs["revision"] = revision
193
+
194
+ use_auth_token = kwargs.pop("use_auth_token", None)
195
+ if use_auth_token is not None:
196
+ warnings.warn(
197
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
198
+ FutureWarning,
199
+ )
200
+ if token is not None:
201
+ raise ValueError(
202
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
203
+ )
204
+ token = use_auth_token
205
+
206
+ if token is not None:
207
+ kwargs["token"] = token
208
+
209
+ image_processor_dict, kwargs = cls.get_image_processor_dict(pretrained_model_name_or_path, **kwargs)
210
+
211
+ return cls.from_dict(image_processor_dict, **kwargs)
212
+
213
+ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
214
+ """
215
+ Save an image processor object to the directory `save_directory`, so that it can be re-loaded using the
216
+ [`~image_processing_utils.ImageProcessingMixin.from_pretrained`] class method.
217
+
218
+ Args:
219
+ save_directory (`str` or `os.PathLike`):
220
+ Directory where the image processor JSON file will be saved (will be created if it does not exist).
221
+ push_to_hub (`bool`, *optional*, defaults to `False`):
222
+ Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
223
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
224
+ namespace).
225
+ kwargs (`Dict[str, Any]`, *optional*):
226
+ Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
227
+ """
228
+ use_auth_token = kwargs.pop("use_auth_token", None)
229
+
230
+ if use_auth_token is not None:
231
+ warnings.warn(
232
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
233
+ FutureWarning,
234
+ )
235
+ if kwargs.get("token", None) is not None:
236
+ raise ValueError(
237
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
238
+ )
239
+ kwargs["token"] = use_auth_token
240
+
241
+ if os.path.isfile(save_directory):
242
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
243
+
244
+ os.makedirs(save_directory, exist_ok=True)
245
+
246
+ if push_to_hub:
247
+ commit_message = kwargs.pop("commit_message", None)
248
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
249
+ repo_id = self._create_repo(repo_id, **kwargs)
250
+ files_timestamps = self._get_files_timestamps(save_directory)
251
+
252
+ # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
253
+ # loaded from the Hub.
254
+ if self._auto_class is not None:
255
+ custom_object_save(self, save_directory, config=self)
256
+
257
+ # If we save using the predefined names, we can load using `from_pretrained`
258
+ output_image_processor_file = os.path.join(save_directory, IMAGE_PROCESSOR_NAME)
259
+
260
+ self.to_json_file(output_image_processor_file)
261
+ logger.info(f"Image processor saved in {output_image_processor_file}")
262
+
263
+ if push_to_hub:
264
+ self._upload_modified_files(
265
+ save_directory,
266
+ repo_id,
267
+ files_timestamps,
268
+ commit_message=commit_message,
269
+ token=kwargs.get("token"),
270
+ )
271
+
272
+ return [output_image_processor_file]
273
+
274
+ @classmethod
275
+ def get_image_processor_dict(
276
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
277
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
278
+ """
279
+ From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
280
+ image processor of type [`~image_processor_utils.ImageProcessingMixin`] using `from_dict`.
281
+
282
+ Parameters:
283
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
284
+ The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
285
+ subfolder (`str`, *optional*, defaults to `""`):
286
+ In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
287
+ specify the folder name here.
288
+ image_processor_filename (`str`, *optional*, defaults to `"config.json"`):
289
+ The name of the file in the model directory to use for the image processor config.
290
+
291
+ Returns:
292
+ `Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the image processor object.
293
+ """
294
+ cache_dir = kwargs.pop("cache_dir", None)
295
+ force_download = kwargs.pop("force_download", False)
296
+ resume_download = kwargs.pop("resume_download", None)
297
+ proxies = kwargs.pop("proxies", None)
298
+ token = kwargs.pop("token", None)
299
+ use_auth_token = kwargs.pop("use_auth_token", None)
300
+ local_files_only = kwargs.pop("local_files_only", False)
301
+ revision = kwargs.pop("revision", None)
302
+ subfolder = kwargs.pop("subfolder", "")
303
+ image_processor_filename = kwargs.pop("image_processor_filename", IMAGE_PROCESSOR_NAME)
304
+
305
+ from_pipeline = kwargs.pop("_from_pipeline", None)
306
+ from_auto_class = kwargs.pop("_from_auto", False)
307
+
308
+ if use_auth_token is not None:
309
+ warnings.warn(
310
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
311
+ FutureWarning,
312
+ )
313
+ if token is not None:
314
+ raise ValueError(
315
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
316
+ )
317
+ token = use_auth_token
318
+
319
+ user_agent = {"file_type": "image processor", "from_auto_class": from_auto_class}
320
+ if from_pipeline is not None:
321
+ user_agent["using_pipeline"] = from_pipeline
322
+
323
+ if is_offline_mode() and not local_files_only:
324
+ logger.info("Offline mode: forcing local_files_only=True")
325
+ local_files_only = True
326
+
327
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
328
+ is_local = os.path.isdir(pretrained_model_name_or_path)
329
+ if os.path.isdir(pretrained_model_name_or_path):
330
+ image_processor_file = os.path.join(pretrained_model_name_or_path, image_processor_filename)
331
+ if os.path.isfile(pretrained_model_name_or_path):
332
+ resolved_image_processor_file = pretrained_model_name_or_path
333
+ is_local = True
334
+ elif is_remote_url(pretrained_model_name_or_path):
335
+ image_processor_file = pretrained_model_name_or_path
336
+ resolved_image_processor_file = download_url(pretrained_model_name_or_path)
337
+ else:
338
+ image_processor_file = image_processor_filename
339
+ try:
340
+ # Load from local folder or from cache or download from model Hub and cache
341
+ resolved_image_processor_file = cached_file(
342
+ pretrained_model_name_or_path,
343
+ image_processor_file,
344
+ cache_dir=cache_dir,
345
+ force_download=force_download,
346
+ proxies=proxies,
347
+ resume_download=resume_download,
348
+ local_files_only=local_files_only,
349
+ token=token,
350
+ user_agent=user_agent,
351
+ revision=revision,
352
+ subfolder=subfolder,
353
+ )
354
+ except EnvironmentError:
355
+ # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
356
+ # the original exception.
357
+ raise
358
+ except Exception:
359
+ # For any other exception, we throw a generic error.
360
+ raise EnvironmentError(
361
+ f"Can't load image processor for '{pretrained_model_name_or_path}'. If you were trying to load"
362
+ " it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
363
+ f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
364
+ f" directory containing a {image_processor_filename} file"
365
+ )
366
+
367
+ try:
368
+ # Load image_processor dict
369
+ with open(resolved_image_processor_file, "r", encoding="utf-8") as reader:
370
+ text = reader.read()
371
+ image_processor_dict = json.loads(text)
372
+
373
+ except json.JSONDecodeError:
374
+ raise EnvironmentError(
375
+ f"It looks like the config file at '{resolved_image_processor_file}' is not a valid JSON file."
376
+ )
377
+
378
+ if is_local:
379
+ logger.info(f"loading configuration file {resolved_image_processor_file}")
380
+ else:
381
+ logger.info(
382
+ f"loading configuration file {image_processor_file} from cache at {resolved_image_processor_file}"
383
+ )
384
+ if "auto_map" in image_processor_dict:
385
+ image_processor_dict["auto_map"] = add_model_info_to_auto_map(
386
+ image_processor_dict["auto_map"], pretrained_model_name_or_path
387
+ )
388
+ if "custom_pipelines" in image_processor_dict:
389
+ image_processor_dict["custom_pipelines"] = add_model_info_to_custom_pipelines(
390
+ image_processor_dict["custom_pipelines"], pretrained_model_name_or_path
391
+ )
392
+
393
+ return image_processor_dict, kwargs
394
+
395
+ @classmethod
396
+ def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
397
+ """
398
+ Instantiates a type of [`~image_processing_utils.ImageProcessingMixin`] from a Python dictionary of parameters.
399
+
400
+ Args:
401
+ image_processor_dict (`Dict[str, Any]`):
402
+ Dictionary that will be used to instantiate the image processor object. Such a dictionary can be
403
+ retrieved from a pretrained checkpoint by leveraging the
404
+ [`~image_processing_utils.ImageProcessingMixin.to_dict`] method.
405
+ kwargs (`Dict[str, Any]`):
406
+ Additional parameters from which to initialize the image processor object.
407
+
408
+ Returns:
409
+ [`~image_processing_utils.ImageProcessingMixin`]: The image processor object instantiated from those
410
+ parameters.
411
+ """
412
+ image_processor_dict = image_processor_dict.copy()
413
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
414
+
415
+ # The `size` parameter is a dict and was previously an int or tuple in feature extractors.
416
+ # We set `size` here directly to the `image_processor_dict` so that it is converted to the appropriate
417
+ # dict within the image processor and isn't overwritten if `size` is passed in as a kwarg.
418
+ if "size" in kwargs and "size" in image_processor_dict:
419
+ image_processor_dict["size"] = kwargs.pop("size")
420
+ if "crop_size" in kwargs and "crop_size" in image_processor_dict:
421
+ image_processor_dict["crop_size"] = kwargs.pop("crop_size")
422
+
423
+ image_processor = cls(**image_processor_dict)
424
+
425
+ # Update image_processor with kwargs if needed
426
+ to_remove = []
427
+ for key, value in kwargs.items():
428
+ if hasattr(image_processor, key):
429
+ setattr(image_processor, key, value)
430
+ to_remove.append(key)
431
+ for key in to_remove:
432
+ kwargs.pop(key, None)
433
+
434
+ logger.info(f"Image processor {image_processor}")
435
+ if return_unused_kwargs:
436
+ return image_processor, kwargs
437
+ else:
438
+ return image_processor
439
+
440
+ def to_dict(self) -> Dict[str, Any]:
441
+ """
442
+ Serializes this instance to a Python dictionary.
443
+
444
+ Returns:
445
+ `Dict[str, Any]`: Dictionary of all the attributes that make up this image processor instance.
446
+ """
447
+ output = copy.deepcopy(self.__dict__)
448
+ output["image_processor_type"] = self.__class__.__name__
449
+
450
+ return output
451
+
452
+ @classmethod
453
+ def from_json_file(cls, json_file: Union[str, os.PathLike]):
454
+ """
455
+ Instantiates a image processor of type [`~image_processing_utils.ImageProcessingMixin`] from the path to a JSON
456
+ file of parameters.
457
+
458
+ Args:
459
+ json_file (`str` or `os.PathLike`):
460
+ Path to the JSON file containing the parameters.
461
+
462
+ Returns:
463
+ A image processor of type [`~image_processing_utils.ImageProcessingMixin`]: The image_processor object
464
+ instantiated from that JSON file.
465
+ """
466
+ with open(json_file, "r", encoding="utf-8") as reader:
467
+ text = reader.read()
468
+ image_processor_dict = json.loads(text)
469
+ return cls(**image_processor_dict)
470
+
471
+ def to_json_string(self) -> str:
472
+ """
473
+ Serializes this instance to a JSON string.
474
+
475
+ Returns:
476
+ `str`: String containing all the attributes that make up this feature_extractor instance in JSON format.
477
+ """
478
+ dictionary = self.to_dict()
479
+
480
+ for key, value in dictionary.items():
481
+ if isinstance(value, np.ndarray):
482
+ dictionary[key] = value.tolist()
483
+
484
+ # make sure private name "_processor_class" is correctly
485
+ # saved as "processor_class"
486
+ _processor_class = dictionary.pop("_processor_class", None)
487
+ if _processor_class is not None:
488
+ dictionary["processor_class"] = _processor_class
489
+
490
+ return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
491
+
492
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
493
+ """
494
+ Save this instance to a JSON file.
495
+
496
+ Args:
497
+ json_file_path (`str` or `os.PathLike`):
498
+ Path to the JSON file in which this image_processor instance's parameters will be saved.
499
+ """
500
+ with open(json_file_path, "w", encoding="utf-8") as writer:
501
+ writer.write(self.to_json_string())
502
+
503
+ def __repr__(self):
504
+ return f"{self.__class__.__name__} {self.to_json_string()}"
505
+
506
+ @classmethod
507
+ def register_for_auto_class(cls, auto_class="AutoImageProcessor"):
508
+ """
509
+ Register this class with a given auto class. This should only be used for custom image processors as the ones
510
+ in the library are already mapped with `AutoImageProcessor `.
511
+
512
+ <Tip warning={true}>
513
+
514
+ This API is experimental and may have some slight breaking changes in the next releases.
515
+
516
+ </Tip>
517
+
518
+ Args:
519
+ auto_class (`str` or `type`, *optional*, defaults to `"AutoImageProcessor "`):
520
+ The auto class to register this new image processor with.
521
+ """
522
+ if not isinstance(auto_class, str):
523
+ auto_class = auto_class.__name__
524
+
525
+ import transformers.models.auto as auto_module
526
+
527
+ if not hasattr(auto_module, auto_class):
528
+ raise ValueError(f"{auto_class} is not a valid auto class.")
529
+
530
+ cls._auto_class = auto_class
531
+
532
+ def fetch_images(self, image_url_or_urls: Union[str, List[str]]):
533
+ """
534
+ Convert a single or a list of urls into the corresponding `PIL.Image` objects.
535
+
536
+ If a single url is passed, the return value will be a single object. If a list is passed a list of objects is
537
+ returned.
538
+ """
539
+ headers = {
540
+ "User-Agent": (
541
+ "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0"
542
+ " Safari/537.36"
543
+ )
544
+ }
545
+ if isinstance(image_url_or_urls, list):
546
+ return [self.fetch_images(x) for x in image_url_or_urls]
547
+ elif isinstance(image_url_or_urls, str):
548
+ response = requests.get(image_url_or_urls, stream=True, headers=headers)
549
+ response.raise_for_status()
550
+ return Image.open(BytesIO(response.content))
551
+ else:
552
+ raise TypeError(f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}")
553
+
554
+
555
+ ImageProcessingMixin.push_to_hub = copy_func(ImageProcessingMixin.push_to_hub)
556
+ if ImageProcessingMixin.push_to_hub.__doc__ is not None:
557
+ ImageProcessingMixin.push_to_hub.__doc__ = ImageProcessingMixin.push_to_hub.__doc__.format(
558
+ object="image processor", object_class="AutoImageProcessor", object_files="image processor file"
559
+ )
.venv/lib/python3.11/site-packages/transformers/image_processing_utils.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Dict, Iterable, Optional, Union
17
+
18
+ import numpy as np
19
+
20
+ from .image_processing_base import BatchFeature, ImageProcessingMixin
21
+ from .image_transforms import center_crop, normalize, rescale
22
+ from .image_utils import ChannelDimension
23
+ from .utils import logging
24
+
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ INIT_SERVICE_KWARGS = [
30
+ "processor_class",
31
+ "image_processor_type",
32
+ ]
33
+
34
+
35
class BaseImageProcessor(ImageProcessingMixin):
    """
    Base class for image processors.

    Routes ``__call__`` to :meth:`preprocess` (which subclasses must implement) and exposes
    thin instance-level wrappers around the shared ``rescale`` / ``normalize`` /
    ``center_crop`` transform functions.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __call__(self, images, **kwargs) -> BatchFeature:
        """Preprocess an image or a batch of images."""
        return self.preprocess(images, **kwargs)

    def preprocess(self, images, **kwargs) -> BatchFeature:
        raise NotImplementedError("Each image processor must implement its own preprocess method")

    def rescale(
        self,
        image: np.ndarray,
        scale: float,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Rescale an image by a scale factor: ``image = image * scale``.

        Args:
            image (`np.ndarray`):
                Image to rescale.
            scale (`float`):
                The scaling factor to rescale pixel values by.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the output image. If unset, the format of the
                input image is used. One of `"channels_first"`/`ChannelDimension.FIRST` or
                `"channels_last"`/`ChannelDimension.LAST`.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image; inferred when unset.

        Returns:
            `np.ndarray`: The rescaled image.
        """
        return rescale(image, scale=scale, data_format=data_format, input_data_format=input_data_format, **kwargs)

    def normalize(
        self,
        image: np.ndarray,
        mean: Union[float, Iterable[float]],
        std: Union[float, Iterable[float]],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Normalize an image: ``image = (image - mean) / std``.

        Args:
            image (`np.ndarray`):
                Image to normalize.
            mean (`float` or `Iterable[float]`):
                Image mean to use for normalization.
            std (`float` or `Iterable[float]`):
                Image standard deviation to use for normalization.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the output image. If unset, the format of the
                input image is used.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image; inferred when unset.

        Returns:
            `np.ndarray`: The normalized image.
        """
        return normalize(
            image, mean=mean, std=std, data_format=data_format, input_data_format=input_data_format, **kwargs
        )

    def center_crop(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Center crop an image to ``(size["height"], size["width"])``. If the input size is
        smaller than `crop_size` along any edge, the image is padded with 0's and then
        center cropped.

        Args:
            image (`np.ndarray`):
                Image to center crop.
            size (`Dict[str, int]`):
                Size of the output image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the output image. If unset, the format of the
                input image is used.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image; inferred when unset.
        """
        # Normalise legacy int/tuple inputs into the canonical dict form first.
        size = get_size_dict(size)
        if "height" not in size or "width" not in size:
            raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")
        return center_crop(
            image,
            size=(size["height"], size["width"]),
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )

    def to_dict(self):
        # Drop the internal bookkeeping attribute so it never leaks into saved configs.
        encoder_dict = super().to_dict()
        encoder_dict.pop("_valid_processor_keys", None)
        return encoder_dict
158
+
159
+
160
# Every key combination a `size` dictionary is allowed to carry.
VALID_SIZE_DICT_KEYS = (
    {"height", "width"},
    {"shortest_edge"},
    {"shortest_edge", "longest_edge"},
    {"longest_edge"},
    {"max_height", "max_width"},
)


def is_valid_size_dict(size_dict):
    """Return True if `size_dict` is a dict whose keys form one of the allowed combinations."""
    if not isinstance(size_dict, dict):
        return False
    return any(set(size_dict) == allowed_keys for allowed_keys in VALID_SIZE_DICT_KEYS)
178
+
179
+
180
def convert_to_size_dict(
    size, max_size: Optional[int] = None, default_to_square: bool = True, height_width_order: bool = True
):
    """
    Convert a legacy `size` value (int, tuple/list, or None) into a canonical size dictionary.

    Raises a `ValueError` for combinations that cannot be converted unambiguously.
    """
    if isinstance(size, int):
        if default_to_square:
            # An int with default_to_square means a square output; max_size is ambiguous here.
            if max_size is not None:
                raise ValueError("Cannot specify both size as an int, with default_to_square=True and max_size")
            return {"height": size, "width": size}
        # Otherwise the int is the target length of the shortest edge after resizing.
        size_dict = {"shortest_edge": size}
        if max_size is not None:
            size_dict["longest_edge"] = max_size
        return size_dict

    if isinstance(size, (tuple, list)):
        # A pair is either (height, width) or (width, height) depending on the flag.
        if height_width_order:
            return {"height": size[0], "width": size[1]}
        return {"height": size[1], "width": size[0]}

    if size is None and max_size is not None:
        if default_to_square:
            raise ValueError("Cannot specify both default_to_square=True and max_size")
        return {"longest_edge": max_size}

    raise ValueError(f"Could not convert size input to size dict: {size}")
206
+
207
+
208
def get_size_dict(
    size: Union[int, Iterable[int], Dict[str, int]] = None,
    max_size: Optional[int] = None,
    height_width_order: bool = True,
    default_to_square: bool = True,
    param_name="size",
) -> dict:
    """
    Converts the old size parameter in the config into the new dict expected in the config. This is to ensure backwards
    compatibility with the old image processor configs and removes ambiguity over whether the tuple is in (height,
    width) or (width, height) format.

    - If `size` is tuple, it is converted to `{"height": size[0], "width": size[1]}` or `{"height": size[1], "width":
      size[0]}` if `height_width_order` is `False`.
    - If `size` is an int, and `default_to_square` is `True`, it is converted to `{"height": size, "width": size}`.
    - If `size` is an int and `default_to_square` is False, it is converted to `{"shortest_edge": size}`. If `max_size`
      is set, it is added to the dict as `{"longest_edge": max_size}`.

    Args:
        size (`Union[int, Iterable[int], Dict[str, int]]`, *optional*):
            The `size` parameter to be cast into a size dictionary.
        max_size (`Optional[int]`, *optional*):
            The `max_size` parameter to be cast into a size dictionary.
        height_width_order (`bool`, *optional*, defaults to `True`):
            If `size` is a tuple, whether it's in (height, width) or (width, height) order.
        default_to_square (`bool`, *optional*, defaults to `True`):
            If `size` is an int, whether to default to a square image or not.

    Returns:
        `dict`: A size dictionary whose keys are one of `VALID_SIZE_DICT_KEYS`.

    Raises:
        ValueError: If the resulting dictionary does not match any valid key combination.
    """
    if not isinstance(size, dict):
        size_dict = convert_to_size_dict(size, max_size, default_to_square, height_width_order)
        # Fixed typo in the log message: "a dictionary on of" -> "a dictionary of one of".
        logger.info(
            f"{param_name} should be a dictionary of one of the following set of keys: {VALID_SIZE_DICT_KEYS},"
            f" got {size}. Converted to {size_dict}.",
        )
    else:
        size_dict = size

    if not is_valid_size_dict(size_dict):
        raise ValueError(
            f"{param_name} must have one of the following set of keys: {VALID_SIZE_DICT_KEYS}, got {size_dict.keys()}"
        )
    return size_dict
250
+
251
+
252
def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple:
    """
    Selects the best resolution from a list of possible resolutions based on the original size.

    Each candidate is evaluated by scaling the original image (aspect ratio preserved) to fit
    inside it: the candidate maximising the effective (non-upscaled) resolution wins, with ties
    broken by the smallest wasted area.

    Args:
        original_size (tuple):
            The original size of the image in the format (height, width).
        possible_resolutions (list):
            A list of possible resolutions in the format [(height1, width1), (height2, width2), ...].

    Returns:
        tuple: The best fit resolution in the format (height, width).
    """
    original_height, original_width = original_size
    original_area = original_width * original_height

    best_fit = None
    best_effective = 0
    best_wasted = float("inf")

    for candidate_height, candidate_width in possible_resolutions:
        # Largest scale at which the original still fits inside the candidate.
        scale = min(candidate_width / original_width, candidate_height / original_height)
        scaled_width = int(original_width * scale)
        scaled_height = int(original_height * scale)

        # Upscaling cannot add information, so cap the effective area at the original's.
        effective = min(scaled_width * scaled_height, original_area)
        wasted = candidate_width * candidate_height - effective

        if effective > best_effective or (effective == best_effective and wasted < best_wasted):
            best_effective = effective
            best_wasted = wasted
            best_fit = (candidate_height, candidate_width)

    return best_fit
.venv/lib/python3.11/site-packages/transformers/image_processing_utils_fast.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import functools
17
+ from dataclasses import dataclass
18
+ from typing import Any, Iterable, List, Optional, Tuple
19
+
20
+ from .image_processing_utils import BaseImageProcessor
21
+ from .utils.import_utils import is_torch_available, is_torchvision_available
22
+
23
+
24
+ if is_torchvision_available():
25
+ from torchvision.transforms import Compose
26
+
27
+ if is_torch_available():
28
+ import torch
29
+
30
+
31
@dataclass(frozen=True)
class SizeDict:
    """
    Hashable dictionary to store image size information.

    Frozen so instances can be used as cache keys (e.g. for `lru_cache`). All fields default
    to `None`; annotations corrected from `int = None` to `Optional[int]`.
    """

    height: Optional[int] = None
    width: Optional[int] = None
    longest_edge: Optional[int] = None
    shortest_edge: Optional[int] = None
    max_height: Optional[int] = None
    max_width: Optional[int] = None

    def __getitem__(self, key):
        # Dict-style access so a SizeDict can stand in where a plain size dict is expected.
        if hasattr(self, key):
            return getattr(self, key)
        raise KeyError(f"Key {key} not found in SizeDict.")
48
+
49
+
50
class BaseImageProcessorFast(BaseImageProcessor):
    """Base class for 'fast' image processors that compile their settings into a
    torchvision transform pipeline."""

    # Names of the settings `_build_transforms` accepts; subclasses must define this.
    _transform_params = None

    def _build_transforms(self, **kwargs) -> "Compose":
        """
        Given the input settings e.g. do_resize, build the image transforms.
        """
        raise NotImplementedError

    def _validate_params(self, **kwargs) -> None:
        # Reject any setting that `_build_transforms` does not recognise.
        for param_name, param_value in kwargs.items():
            if param_name not in self._transform_params:
                raise ValueError(f"Invalid transform parameter {param_name}={param_value}.")

    @functools.lru_cache(maxsize=1)
    def get_transforms(self, **kwargs) -> "Compose":
        # Cache only the most recent settings so repeated calls with identical kwargs
        # reuse the built pipeline.
        # NOTE(review): lru_cache on an instance method keeps `self` alive for the cache's
        # lifetime (ruff B019) — confirm this is acceptable here.
        self._validate_params(**kwargs)
        return self._build_transforms(**kwargs)

    def to_dict(self):
        # `_transform_params` is internal only; keep it out of serialized configs.
        encoder_dict = super().to_dict()
        encoder_dict.pop("_transform_params", None)
        return encoder_dict
73
+
74
+
75
def get_image_size_for_max_height_width(
    image_size: Tuple[int, int],
    max_height: int,
    max_width: int,
) -> Tuple[int, int]:
    """
    Computes the output image size given the input image and the maximum allowed height and width. Keep aspect ratio.
    Important, even if image_height < max_height and image_width < max_width, the image will be resized
    to at least one of the edges be equal to max_height or max_width.

    For example:
        - input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50)
        - input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400)

    Args:
        image_size (`Tuple[int, int]`):
            The (height, width) of the image to resize.
        max_height (`int`):
            The maximum allowed height.
        max_width (`int`):
            The maximum allowed width.
    """
    height, width = image_size
    # One scale factor fits both constraints while preserving aspect ratio.
    scale = min(max_height / height, max_width / width)
    return int(height * scale), int(width * scale)
104
+
105
+
106
def safe_squeeze(tensor: "torch.Tensor", axis: Optional[int] = None) -> "torch.Tensor":
    """
    Squeezes a tensor, but only if the axis specified has dim 1.
    """
    if axis is None:
        return tensor.squeeze()

    try:
        squeezed = tensor.squeeze(axis=axis)
    except ValueError:
        # The requested axis could not be squeezed; hand the tensor back untouched.
        return tensor
    return squeezed
117
+
118
+
119
def max_across_indices(values: Iterable[Any]) -> List[Any]:
    """
    Return the maximum value across all indices of an iterable of values.

    e.g. ``max_across_indices([(1, 8), (5, 2)]) == [5, 8]``
    """
    # zip(*values) transposes, so `max` runs per index position.
    return list(map(max, zip(*values)))
124
+
125
+
126
def get_max_height_width(images: List["torch.Tensor"]) -> Tuple[int]:
    """
    Get the maximum height and width across all images in a batch.

    Expects 3-D channels-first tensors (the leading shape entry is ignored).
    """
    # Per-dimension maxima over every image shape; only H and W are kept.
    _, tallest, widest = max_across_indices([image.shape for image in images])
    return (tallest, widest)
.venv/lib/python3.11/site-packages/transformers/image_transforms.py ADDED
@@ -0,0 +1,860 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import warnings
17
+ from math import ceil
18
+ from typing import Iterable, List, Optional, Tuple, Union
19
+
20
+ import numpy as np
21
+
22
+ from .image_utils import (
23
+ ChannelDimension,
24
+ ImageInput,
25
+ get_channel_dimension_axis,
26
+ get_image_size,
27
+ infer_channel_dimension_format,
28
+ )
29
+ from .utils import ExplicitEnum, TensorType, is_jax_tensor, is_tf_tensor, is_torch_tensor
30
+ from .utils.import_utils import (
31
+ is_flax_available,
32
+ is_tf_available,
33
+ is_torch_available,
34
+ is_torchvision_available,
35
+ is_torchvision_v2_available,
36
+ is_vision_available,
37
+ requires_backends,
38
+ )
39
+
40
+
41
+ if is_vision_available():
42
+ import PIL
43
+
44
+ from .image_utils import PILImageResampling
45
+
46
+ if is_torch_available():
47
+ import torch
48
+
49
+ if is_tf_available():
50
+ import tensorflow as tf
51
+
52
+ if is_flax_available():
53
+ import jax.numpy as jnp
54
+
55
+ if is_torchvision_v2_available():
56
+ from torchvision.transforms.v2 import functional as F
57
+ elif is_torchvision_available():
58
+ from torchvision.transforms import functional as F
59
+
60
+
61
def to_channel_dimension_format(
    image: np.ndarray,
    channel_dim: Union[ChannelDimension, str],
    input_channel_dim: Optional[Union[ChannelDimension, str]] = None,
) -> np.ndarray:
    """
    Converts `image` to the channel dimension format specified by `channel_dim`.

    Args:
        image (`numpy.ndarray`):
            The image to have its channel dimension set.
        channel_dim (`ChannelDimension`):
            The channel dimension format to use.
        input_channel_dim (`ChannelDimension`, *optional*):
            The channel dimension format of the input image. If not provided, it will be inferred from the input image.

    Returns:
        `np.ndarray`: The image with the channel dimension set to `channel_dim`.
    """
    if not isinstance(image, np.ndarray):
        raise TypeError(f"Input image must be of type np.ndarray, got {type(image)}")

    if input_channel_dim is None:
        input_channel_dim = infer_channel_dimension_format(image)

    target_channel_dim = ChannelDimension(channel_dim)
    if input_channel_dim == target_channel_dim:
        # Already in the requested layout; nothing to do.
        return image

    if target_channel_dim == ChannelDimension.FIRST:
        return image.transpose((2, 0, 1))
    if target_channel_dim == ChannelDimension.LAST:
        return image.transpose((1, 2, 0))
    raise ValueError("Unsupported channel dimension format: {}".format(channel_dim))
98
+
99
+
100
def rescale(
    image: np.ndarray,
    scale: float,
    data_format: Optional[ChannelDimension] = None,
    dtype: np.dtype = np.float32,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
    """
    Rescales `image` by `scale`.

    Args:
        image (`np.ndarray`):
            The image to rescale.
        scale (`float`):
            The scale to use for rescaling the image.
        data_format (`ChannelDimension`, *optional*):
            The channel dimension format of the image. If not provided, it will be the same as the input image.
        dtype (`np.dtype`, *optional*, defaults to `np.float32`):
            The dtype of the output image. Used for backwards compatibility with feature extractors.
        input_data_format (`ChannelDimension`, *optional*):
            The channel dimension format of the input image. If not provided, it will be inferred from the input image.

    Returns:
        `np.ndarray`: The rescaled image.
    """
    if not isinstance(image, np.ndarray):
        raise TypeError(f"Input image must be of type np.ndarray, got {type(image)}")

    # Numpy's type-promotion rules changed across versions, so always upcast to
    # float64 before multiplying to keep the arithmetic deterministic.
    scaled = image.astype(np.float64) * scale
    if data_format is not None:
        scaled = to_channel_dimension_format(scaled, data_format, input_data_format)

    # Downcast to the requested dtype only at the very end.
    return scaled.astype(dtype)
136
+
137
+
138
+ def _rescale_for_pil_conversion(image):
139
+ """
140
+ Detects whether or not the image needs to be rescaled before being converted to a PIL image.
141
+
142
+ The assumption is that if the image is of type `np.float` and all values are between 0 and 1, it needs to be
143
+ rescaled.
144
+ """
145
+ if image.dtype == np.uint8:
146
+ do_rescale = False
147
+ elif np.allclose(image, image.astype(int)):
148
+ if np.all(0 <= image) and np.all(image <= 255):
149
+ do_rescale = False
150
+ else:
151
+ raise ValueError(
152
+ "The image to be converted to a PIL image contains values outside the range [0, 255], "
153
+ f"got [{image.min()}, {image.max()}] which cannot be converted to uint8."
154
+ )
155
+ elif np.all(0 <= image) and np.all(image <= 1):
156
+ do_rescale = True
157
+ else:
158
+ raise ValueError(
159
+ "The image to be converted to a PIL image contains values outside the range [0, 1], "
160
+ f"got [{image.min()}, {image.max()}] which cannot be converted to uint8."
161
+ )
162
+ return do_rescale
163
+
164
+
165
def to_pil_image(
    image: Union[np.ndarray, "PIL.Image.Image", "torch.Tensor", "tf.Tensor", "jnp.ndarray"],
    do_rescale: Optional[bool] = None,
    image_mode: Optional[str] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> "PIL.Image.Image":
    """
    Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
    needed.

    Args:
        image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor` or `tf.Tensor`):
            The image to convert to the `PIL.Image` format.
        do_rescale (`bool`, *optional*):
            Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will default
            to `True` if the image type is a floating type and casting to `int` would result in a loss of precision,
            and `False` otherwise.
        image_mode (`str`, *optional*):
            The mode to use for the PIL image. If unset, will use the default mode for the input image type.
        input_data_format (`ChannelDimension`, *optional*):
            The channel dimension format of the input image. If unset, will use the inferred format from the input.

    Returns:
        `PIL.Image.Image`: The converted image.
    """
    requires_backends(to_pil_image, ["vision"])

    if isinstance(image, PIL.Image.Image):
        return image

    # Normalise every supported tensor type into a numpy array first.
    if is_torch_tensor(image) or is_tf_tensor(image):
        image = image.numpy()
    elif is_jax_tensor(image):
        image = np.array(image)
    elif not isinstance(image, np.ndarray):
        raise ValueError("Input image type not supported: {}".format(type(image)))

    # PIL expects channels last; move the channel axis back if it was first.
    image = to_channel_dimension_format(image, ChannelDimension.LAST, input_data_format)

    # A trailing singleton channel confuses PIL, so drop it.
    if image.shape[-1] == 1:
        image = np.squeeze(image, axis=-1)

    # PIL.Image can only store uint8 values, so bring floats in [0, 1] up to [0, 255] when needed.
    if do_rescale is None:
        do_rescale = _rescale_for_pil_conversion(image)
    if do_rescale:
        image = rescale(image, 255)

    return PIL.Image.fromarray(image.astype(np.uint8), mode=image_mode)
217
+
218
+
219
+ # Logic adapted from torchvision resizing logic: https://github.com/pytorch/vision/blob/511924c1ced4ce0461197e5caa64ce5b9e558aab/torchvision/transforms/functional.py#L366
220
# Logic adapted from torchvision resizing logic: https://github.com/pytorch/vision/blob/511924c1ced4ce0461197e5caa64ce5b9e558aab/torchvision/transforms/functional.py#L366
def get_resize_output_image_size(
    input_image: np.ndarray,
    size: Union[int, Tuple[int, int], List[int], Tuple[int]],
    default_to_square: bool = True,
    max_size: Optional[int] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> tuple:
    """
    Find the target (height, width) dimension of the output image after resizing given the input image and the desired
    size, mirroring `torchvision.transforms.Resize` semantics.

    Args:
        input_image (`np.ndarray`):
            The image to resize.
        size (`int` or `Tuple[int, int]` or List[int] or `Tuple[int]`):
            A pair (h, w) is returned as-is. An int (or one-element sequence) either means a square
            output (when `default_to_square`) or the target length of the smaller edge.
        default_to_square (`bool`, *optional*, defaults to `True`):
            How to interpret a single-int `size` (see above).
        max_size (`int`, *optional*):
            Cap on the longer edge; if the proportionally scaled longer edge exceeds it, the image is
            shrunk again so the longer edge equals `max_size`. Only used when `default_to_square` is `False`.
        input_data_format (`ChannelDimension`, *optional*):
            The channel dimension format of the input image. If unset, will use the inferred format from the input.

    Returns:
        `tuple`: The target (height, width) dimension of the output image after resizing.
    """
    if isinstance(size, (tuple, list)):
        if len(size) == 2:
            # Explicit (height, width): nothing to compute.
            return tuple(size)
        if len(size) == 1:
            # Treat a one-element sequence exactly like a plain int.
            size = size[0]
        else:
            raise ValueError("size must have 1 or 2 elements if it is a list or tuple")

    if default_to_square:
        return (size, size)

    # Match the shorter edge to `size`, scaling the longer edge proportionally.
    height, width = get_image_size(input_image, input_data_format)
    short, long = (width, height) if width <= height else (height, width)

    new_short = size
    new_long = int(new_short * long / short)

    if max_size is not None:
        if max_size <= size:
            raise ValueError(
                f"max_size = {max_size} must be strictly greater than the requested "
                f"size for the smaller edge size = {size}"
            )
        if new_long > max_size:
            # Shrink again so the longer edge does not exceed max_size.
            new_short, new_long = int(max_size * new_short / new_long), max_size

    return (new_long, new_short) if width <= height else (new_short, new_long)
285
+
286
+
287
def resize(
    image: np.ndarray,
    size: Tuple[int, int],
    resample: "PILImageResampling" = None,
    reducing_gap: Optional[int] = None,
    data_format: Optional[ChannelDimension] = None,
    return_numpy: bool = True,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
    """
    Resizes `image` to `(height, width)` specified by `size` using the PIL library.

    Args:
        image (`np.ndarray`):
            The image to resize.
        size (`Tuple[int, int]`):
            The size to use for resizing the image.
        resample (`int`, *optional*, defaults to `PILImageResampling.BILINEAR`):
            The filter to use for resampling.
        reducing_gap (`int`, *optional*):
            Apply optimization by resizing the image in two steps. The bigger `reducing_gap`, the closer the result to
            the fair resampling. See corresponding Pillow documentation for more details.
        data_format (`ChannelDimension`, *optional*):
            The channel dimension format of the output image. If unset, will use the inferred format from the input.
        return_numpy (`bool`, *optional*, defaults to `True`):
            Whether or not to return the resized image as a numpy array. If False a `PIL.Image.Image` object is
            returned.
        input_data_format (`ChannelDimension`, *optional*):
            The channel dimension format of the input image. If unset, will use the inferred format from the input.

    Returns:
        `np.ndarray`: The resized image.
    """
    requires_backends(resize, ["vision"])

    resample = resample if resample is not None else PILImageResampling.BILINEAR

    if not len(size) == 2:
        raise ValueError("size must have 2 elements")

    # For all transformations, we want to keep the same data format as the input image unless otherwise specified.
    # The resized image from PIL will always have channels last, so find the input format first.
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)
    data_format = input_data_format if data_format is None else data_format

    # To maintain backwards compatibility with the resizing done in previous image feature extractors, we use
    # the pillow library to resize the image and then convert back to numpy
    do_rescale = False
    if not isinstance(image, PIL.Image.Image):
        # `_rescale_for_pil_conversion` presumably decides whether values need scaling to [0, 255] for the PIL
        # round-trip; the flag is remembered so the scaling can be undone after resizing.
        do_rescale = _rescale_for_pil_conversion(image)
        image = to_pil_image(image, do_rescale=do_rescale, input_data_format=input_data_format)
    height, width = size
    # PIL images are in the format (width, height)
    resized_image = image.resize((width, height), resample=resample, reducing_gap=reducing_gap)

    if return_numpy:
        resized_image = np.array(resized_image)
        # If the input image channel dimension was of size 1, then it is dropped when converting to a PIL image
        # so we need to add it back if necessary.
        resized_image = np.expand_dims(resized_image, axis=-1) if resized_image.ndim == 2 else resized_image
        # The image is always in channels last format after converting from a PIL image
        resized_image = to_channel_dimension_format(
            resized_image, data_format, input_channel_dim=ChannelDimension.LAST
        )
        # If an image was rescaled to be in the range [0, 255] before converting to a PIL image, then we need to
        # rescale it back to the original range.
        resized_image = rescale(resized_image, 1 / 255) if do_rescale else resized_image
    return resized_image
356
+
357
+
358
def normalize(
    image: np.ndarray,
    mean: Union[float, Iterable[float]],
    std: Union[float, Iterable[float]],
    data_format: Optional[ChannelDimension] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
    """
    Normalizes `image` using the mean and standard deviation specified by `mean` and `std`:

        image = (image - mean) / std

    Args:
        image (`np.ndarray`):
            The image to normalize.
        mean (`float` or `Iterable[float]`):
            The mean to use for normalization, either a scalar or one value per channel.
        std (`float` or `Iterable[float]`):
            The standard deviation to use for normalization, either a scalar or one value per channel.
        data_format (`ChannelDimension`, *optional*):
            The channel dimension format of the output image. If unset, will use the inferred format from the input.
        input_data_format (`ChannelDimension`, *optional*):
            The channel dimension format of the input image. If unset, will use the inferred format from the input.
    """
    if not isinstance(image, np.ndarray):
        raise ValueError("image must be a numpy array")

    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)

    channel_axis = get_channel_dimension_axis(image, input_data_format=input_data_format)
    num_channels = image.shape[channel_axis]

    # Integer images are cast to float32 so the subtraction below cannot wrap; float images keep
    # their dtype so e.g. float16 inputs are not upcast.
    if not np.issubdtype(image.dtype, np.floating):
        image = image.astype(np.float32)

    def _per_channel(values, name):
        # Broadcast a scalar to all channels, or validate a per-channel iterable.
        if isinstance(values, Iterable):
            if len(values) != num_channels:
                raise ValueError(f"{name} must have {num_channels} elements if it is an iterable, got {len(values)}")
        else:
            values = [values] * num_channels
        return np.array(values, dtype=image.dtype)

    mean = _per_channel(mean, "mean")
    std = _per_channel(std, "std")

    if input_data_format == ChannelDimension.LAST:
        image = (image - mean) / std
    else:
        # Transposing moves the channel axis last so numpy broadcasting lines up, then transposes back.
        image = ((image.T - mean) / std).T

    if data_format is not None:
        image = to_channel_dimension_format(image, data_format, input_data_format)
    return image
417
+
418
+
419
def center_crop(
    image: np.ndarray,
    size: Tuple[int, int],
    data_format: Optional[Union[str, ChannelDimension]] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
    return_numpy: Optional[bool] = None,
) -> np.ndarray:
    """
    Crops the `image` to the specified `size` using a center crop. Note that if the image is too small to be cropped to
    the size given, it will be padded (so the returned result will always be of size `size`).

    Args:
        image (`np.ndarray`):
            The image to crop.
        size (`Tuple[int, int]`):
            The target size for the cropped image.
        data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format for the output image. Can be one of:
            - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
            - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            If unset, will use the inferred format of the input image.
        input_data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format for the input image. Can be one of:
            - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
            - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            If unset, will use the inferred format of the input image.
        return_numpy (`bool`, *optional*):
            Whether or not to return the cropped image as a numpy array. Used for backwards compatibility with the
            previous ImageFeatureExtractionMixin method.
            - Unset: will return the same type as the input image.
            - `True`: will return a numpy array.
            - `False`: will return a `PIL.Image.Image` object.
    Returns:
        `np.ndarray`: The cropped image.
    """
    requires_backends(center_crop, ["vision"])

    if return_numpy is not None:
        warnings.warn("return_numpy is deprecated and will be removed in v.4.33", FutureWarning)

    return_numpy = True if return_numpy is None else return_numpy

    if not isinstance(image, np.ndarray):
        raise TypeError(f"Input image must be of type np.ndarray, got {type(image)}")

    if not isinstance(size, Iterable) or len(size) != 2:
        raise ValueError("size must have 2 elements representing the height and width of the output image")

    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)
    output_data_format = data_format if data_format is not None else input_data_format

    # We perform the crop in (C, H, W) format and then convert to the output format
    image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_data_format)

    orig_height, orig_width = get_image_size(image, ChannelDimension.FIRST)
    crop_height, crop_width = size
    crop_height, crop_width = int(crop_height), int(crop_width)

    # In case size is odd, (image_shape[0] + size[0]) // 2 won't give the proper result.
    top = (orig_height - crop_height) // 2
    bottom = top + crop_height
    # In case size is odd, (image_shape[1] + size[1]) // 2 won't give the proper result.
    left = (orig_width - crop_width) // 2
    right = left + crop_width

    # Check if cropped area is within image boundaries
    if top >= 0 and bottom <= orig_height and left >= 0 and right <= orig_width:
        # Fast path: the crop window fits entirely inside the image.
        image = image[..., top:bottom, left:right]
        image = to_channel_dimension_format(image, output_data_format, ChannelDimension.FIRST)
        return image

    # Otherwise, we may need to pad if the image is too small. Oh joy...
    # Build a zero canvas large enough for both the original image and the crop window.
    new_height = max(crop_height, orig_height)
    new_width = max(crop_width, orig_width)
    new_shape = image.shape[:-2] + (new_height, new_width)
    new_image = np.zeros_like(image, shape=new_shape)

    # If the image is too small, pad it with zeros: center the original image on the canvas.
    top_pad = ceil((new_height - orig_height) / 2)
    bottom_pad = top_pad + orig_height
    left_pad = ceil((new_width - orig_width) / 2)
    right_pad = left_pad + orig_width
    new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image

    # Shift the crop window into the padded canvas coordinates.
    top += top_pad
    bottom += top_pad
    left += left_pad
    right += left_pad

    new_image = new_image[..., max(0, top) : min(new_height, bottom), max(0, left) : min(new_width, right)]
    new_image = to_channel_dimension_format(new_image, output_data_format, ChannelDimension.FIRST)

    if not return_numpy:
        new_image = to_pil_image(new_image)

    return new_image
516
+
517
+
518
+ def _center_to_corners_format_torch(bboxes_center: "torch.Tensor") -> "torch.Tensor":
519
+ center_x, center_y, width, height = bboxes_center.unbind(-1)
520
+ bbox_corners = torch.stack(
521
+ # top left x, top left y, bottom right x, bottom right y
522
+ [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)],
523
+ dim=-1,
524
+ )
525
+ return bbox_corners
526
+
527
+
528
+ def _center_to_corners_format_numpy(bboxes_center: np.ndarray) -> np.ndarray:
529
+ center_x, center_y, width, height = bboxes_center.T
530
+ bboxes_corners = np.stack(
531
+ # top left x, top left y, bottom right x, bottom right y
532
+ [center_x - 0.5 * width, center_y - 0.5 * height, center_x + 0.5 * width, center_y + 0.5 * height],
533
+ axis=-1,
534
+ )
535
+ return bboxes_corners
536
+
537
+
538
def _center_to_corners_format_tf(bboxes_center: "tf.Tensor") -> "tf.Tensor":
    """Convert (center_x, center_y, width, height) boxes to (top_left_x, top_left_y, bottom_right_x, bottom_right_y)."""
    center_x, center_y, width, height = tf.unstack(bboxes_center, axis=-1)
    half_w = 0.5 * width
    half_h = 0.5 * height
    # top left x, top left y, bottom right x, bottom right y
    return tf.stack(
        [center_x - half_w, center_y - half_h, center_x + half_w, center_y + half_h],
        axis=-1,
    )
546
+
547
+
548
# 2 functions below inspired by https://github.com/facebookresearch/detr/blob/master/util/box_ops.py
def center_to_corners_format(bboxes_center: TensorType) -> TensorType:
    """
    Converts bounding boxes from center format to corners format.

    center format: contains the coordinate for the center of the box and its width, height dimensions
        (center_x, center_y, width, height)
    corners format: contains the coordinates for the top-left and bottom-right corners of the box
        (top_left_x, top_left_y, bottom_right_x, bottom_right_y)
    """
    # Dispatch on the input framework so no numpy round-trip happens during a model forward pass.
    if is_torch_tensor(bboxes_center):
        converter = _center_to_corners_format_torch
    elif isinstance(bboxes_center, np.ndarray):
        converter = _center_to_corners_format_numpy
    elif is_tf_tensor(bboxes_center):
        converter = _center_to_corners_format_tf
    else:
        raise ValueError(f"Unsupported input type {type(bboxes_center)}")
    return converter(bboxes_center)
568
+
569
+
570
+ def _corners_to_center_format_torch(bboxes_corners: "torch.Tensor") -> "torch.Tensor":
571
+ top_left_x, top_left_y, bottom_right_x, bottom_right_y = bboxes_corners.unbind(-1)
572
+ b = [
573
+ (top_left_x + bottom_right_x) / 2, # center x
574
+ (top_left_y + bottom_right_y) / 2, # center y
575
+ (bottom_right_x - top_left_x), # width
576
+ (bottom_right_y - top_left_y), # height
577
+ ]
578
+ return torch.stack(b, dim=-1)
579
+
580
+
581
+ def _corners_to_center_format_numpy(bboxes_corners: np.ndarray) -> np.ndarray:
582
+ top_left_x, top_left_y, bottom_right_x, bottom_right_y = bboxes_corners.T
583
+ bboxes_center = np.stack(
584
+ [
585
+ (top_left_x + bottom_right_x) / 2, # center x
586
+ (top_left_y + bottom_right_y) / 2, # center y
587
+ (bottom_right_x - top_left_x), # width
588
+ (bottom_right_y - top_left_y), # height
589
+ ],
590
+ axis=-1,
591
+ )
592
+ return bboxes_center
593
+
594
+
595
def _corners_to_center_format_tf(bboxes_corners: "tf.Tensor") -> "tf.Tensor":
    """Convert (top_left_x, top_left_y, bottom_right_x, bottom_right_y) boxes to (center_x, center_y, width, height)."""
    top_left_x, top_left_y, bottom_right_x, bottom_right_y = tf.unstack(bboxes_corners, axis=-1)
    center_x = (top_left_x + bottom_right_x) / 2
    center_y = (top_left_y + bottom_right_y) / 2
    width = bottom_right_x - top_left_x
    height = bottom_right_y - top_left_y
    return tf.stack([center_x, center_y, width, height], axis=-1)
607
+
608
+
609
def corners_to_center_format(bboxes_corners: TensorType) -> TensorType:
    """
    Converts bounding boxes from corners format to center format.

    corners format: contains the coordinates for the top-left and bottom-right corners of the box
        (top_left_x, top_left_y, bottom_right_x, bottom_right_y)
    center format: contains the coordinate for the center of the box and its width, height dimensions
        (center_x, center_y, width, height)
    """
    # Same framework dispatch as the inverse `center_to_corners_format`.
    if is_torch_tensor(bboxes_corners):
        converter = _corners_to_center_format_torch
    elif isinstance(bboxes_corners, np.ndarray):
        converter = _corners_to_center_format_numpy
    elif is_tf_tensor(bboxes_corners):
        converter = _corners_to_center_format_tf
    else:
        raise ValueError(f"Unsupported input type {type(bboxes_corners)}")
    return converter(bboxes_corners)
627
+
628
+
629
+ # 2 functions below copied from https://github.com/cocodataset/panopticapi/blob/master/panopticapi/utils.py
630
+ # Copyright (c) 2018, Alexander Kirillov
631
+ # All rights reserved.
632
def rgb_to_id(color):
    """
    Converts an RGB color to a unique integer ID: id = R + 256 * G + 256**2 * B.

    Accepts either a (height, width, 3) numpy array (returns a (height, width) id map) or a
    single (R, G, B) triple (returns an `int`).
    """
    if isinstance(color, np.ndarray) and color.ndim == 3:
        if color.dtype == np.uint8:
            # Widen first: combining channels would overflow uint8.
            color = color.astype(np.int32)
        red, green, blue = color[:, :, 0], color[:, :, 1], color[:, :, 2]
        return red + 256 * green + 256 * 256 * blue
    return int(color[0] + 256 * color[1] + 256 * 256 * color[2])
641
+
642
+
643
def id_to_rgb(id_map):
    """
    Converts a unique integer ID back to an RGB color (inverse of `rgb_to_id`).

    Accepts either a numpy id map (returns a `(*shape, 3)` uint8 array) or a single integer
    (returns a `[R, G, B]` list).
    """
    if isinstance(id_map, np.ndarray):
        remainder = id_map.copy()
        rgb_map = np.zeros(tuple(id_map.shape) + (3,), dtype=np.uint8)
        for channel in range(3):
            # Peel off one base-256 digit per channel.
            rgb_map[..., channel] = remainder % 256
            remainder //= 256
        return rgb_map
    quotient, red = divmod(id_map, 256)
    quotient, green = divmod(quotient, 256)
    _, blue = divmod(quotient, 256)
    return [red, green, blue]
660
+
661
+
662
class PaddingMode(ExplicitEnum):
    """
    Enum class for the different padding modes to use when padding images (see `pad` for how each
    mode maps onto a `np.pad` mode).
    """

    # Pad with a constant value (`constant_values` argument of `pad`).
    CONSTANT = "constant"
    # Mirror about the edge, excluding the edge pixel itself (np.pad "reflect").
    REFLECT = "reflect"
    # Repeat the edge value (np.pad "edge").
    REPLICATE = "replicate"
    # Mirror about the edge, including the edge pixel (np.pad "symmetric").
    SYMMETRIC = "symmetric"
671
+
672
+
673
def pad(
    image: np.ndarray,
    padding: Union[int, Tuple[int, int], Iterable[Tuple[int, int]]],
    mode: PaddingMode = PaddingMode.CONSTANT,
    constant_values: Union[float, Iterable[float]] = 0.0,
    data_format: Optional[Union[str, ChannelDimension]] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
    """
    Pads the `image` with the specified (height, width) `padding` and `mode`.

    Args:
        image (`np.ndarray`):
            The image to pad.
        padding (`int` or `Tuple[int, int]` or `Iterable[Tuple[int, int]]`):
            Padding to apply to the edges of the height, width axes. Can be one of three formats:
            - `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis.
            - `((before, after),)` yields same before and after pad for height and width.
            - `(pad,)` or int is a shortcut for before = after = pad width for all axes.
        mode (`PaddingMode`):
            The padding mode to use. Can be one of:
                - `"constant"`: pads with a constant value.
                - `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the
                  vector along each axis.
                - `"replicate"`: pads with the replication of the last value on the edge of the array along each axis.
                - `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array.
        constant_values (`float` or `Iterable[float]`, *optional*):
            The value to use for the padding if `mode` is `"constant"`.
        data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format for the output image. Can be one of:
            - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
            - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            If unset, will use same as the input image.
        input_data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format for the input image. Can be one of:
            - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
            - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            If unset, will use the inferred format of the input image.

    Returns:
        `np.ndarray`: The padded image.
    """
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)

    def _expand_for_data_format(values):
        """
        Convert values to be in the format expected by np.pad based on the data format:
        one `(before, after)` pair per axis of `image`.
        """
        if isinstance(values, (int, float)):
            values = ((values, values), (values, values))
        elif isinstance(values, tuple) and len(values) == 1:
            values = ((values[0], values[0]), (values[0], values[0]))
        elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], int):
            values = (values, values)
        elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], tuple):
            values = values
        else:
            raise ValueError(f"Unsupported format: {values}")

        # add (0, 0) for the channel dimension — it is never padded
        values = ((0, 0), *values) if input_data_format == ChannelDimension.FIRST else (*values, (0, 0))

        # Add additional padding if there's a batch dimension.
        # Bug fix: this was `(0, *values)`, but np.pad requires a (before, after) pair per axis —
        # mixing a bare scalar with pairs makes the pad-width array ragged and np.pad raises.
        values = ((0, 0), *values) if image.ndim == 4 else values
        return values

    padding = _expand_for_data_format(padding)

    if mode == PaddingMode.CONSTANT:
        constant_values = _expand_for_data_format(constant_values)
        image = np.pad(image, padding, mode="constant", constant_values=constant_values)
    elif mode == PaddingMode.REFLECT:
        image = np.pad(image, padding, mode="reflect")
    elif mode == PaddingMode.REPLICATE:
        image = np.pad(image, padding, mode="edge")
    elif mode == PaddingMode.SYMMETRIC:
        image = np.pad(image, padding, mode="symmetric")
    else:
        raise ValueError(f"Invalid padding mode: {mode}")

    image = to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
    return image
757
+
758
+
759
+ # TODO (Amy): Accept 1/3/4 channel numpy array as input and return np.array as default
760
def convert_to_rgb(image: ImageInput) -> ImageInput:
    """
    Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the
    image as is.

    Args:
        image (Image):
            The image to convert.
    """
    requires_backends(convert_to_rgb, ["vision"])

    # Non-PIL inputs (numpy arrays, tensors, ...) pass through untouched.
    if not isinstance(image, PIL.Image.Image):
        return image

    # Already RGB: nothing to do; otherwise let PIL perform the mode conversion.
    return image if image.mode == "RGB" else image.convert("RGB")
778
+
779
+
780
def flip_channel_order(
    image: np.ndarray,
    data_format: Optional[ChannelDimension] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
    """
    Flips the channel order of the image: an RGB image becomes BGR and vice versa.

    Args:
        image (`np.ndarray`):
            The image to flip.
        data_format (`ChannelDimension`, *optional*):
            The channel dimension format for the output image. If unset, the input format is kept.
        input_data_format (`ChannelDimension`, *optional*):
            The channel dimension format for the input image. If unset, it is inferred from the input.
    """
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)

    if input_data_format == ChannelDimension.LAST:
        flipped = image[..., ::-1]
    elif input_data_format == ChannelDimension.FIRST:
        # NOTE(review): for a 4D batched channels-first array this reverses axis 0 — the batch
        # axis, not the channel axis. Confirm callers only pass 3D images here.
        flipped = image[::-1, ...]
    else:
        raise ValueError(f"Unsupported channel dimension: {input_data_format}")

    if data_format is not None:
        flipped = to_channel_dimension_format(flipped, data_format, input_channel_dim=input_data_format)
    return flipped
816
+
817
+
818
+ def _cast_tensor_to_float(x):
819
+ if x.is_floating_point():
820
+ return x
821
+ return x.float()
822
+
823
+
824
class FusedRescaleNormalize:
    """
    Rescale and normalize the input image in one step.

    Normalizing a rescaled image `(image * r - mean) / std` equals `(image - mean / r) / (std / r)`,
    so the rescale factor is folded into the mean/std and only a single normalize pass is needed.
    """

    def __init__(self, mean, std, rescale_factor: float = 1.0, inplace: bool = False):
        # Fold the rescale factor into the statistics (see class docstring).
        self.mean = torch.tensor(mean) * (1.0 / rescale_factor)
        self.std = torch.tensor(std) * (1.0 / rescale_factor)
        self.inplace = inplace

    def __call__(self, image: "torch.Tensor"):
        # Cast integer tensors to float first so the normalize can be applied.
        # NOTE(review): `F` is presumably torchvision.transforms.functional — confirm at the import site.
        image = _cast_tensor_to_float(image)
        return F.normalize(image, self.mean, self.std, inplace=self.inplace)
837
+
838
+
839
class Rescale:
    """
    Rescale the input image by rescale factor: image *= rescale_factor.
    """

    def __init__(self, rescale_factor: float = 1.0):
        # Multiplicative factor applied to every pixel value.
        self.rescale_factor = rescale_factor

    def __call__(self, image: "torch.Tensor"):
        """Return `image` scaled by the configured factor."""
        return image * self.rescale_factor
850
+
851
+
852
class NumpyToTensor:
    """
    Convert a numpy array to a PyTorch tensor.
    """

    def __call__(self, image: np.ndarray):
        # Same as in PyTorch, we assume incoming numpy images are in HWC format
        # c.f. https://github.com/pytorch/vision/blob/61d97f41bc209e1407dcfbd685d2ee2da9c1cdad/torchvision/transforms/functional.py#L154
        chw = image.transpose(2, 0, 1)
        return torch.from_numpy(chw).contiguous()
.venv/lib/python3.11/site-packages/transformers/image_utils.py ADDED
@@ -0,0 +1,871 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import base64
17
+ import os
18
+ from io import BytesIO
19
+ from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import requests
23
+ from packaging import version
24
+
25
+ from .utils import (
26
+ ExplicitEnum,
27
+ TensorType,
28
+ is_jax_tensor,
29
+ is_numpy_array,
30
+ is_tf_tensor,
31
+ is_torch_available,
32
+ is_torch_tensor,
33
+ is_torchvision_available,
34
+ is_vision_available,
35
+ logging,
36
+ requires_backends,
37
+ to_numpy,
38
+ )
39
+ from .utils.constants import ( # noqa: F401
40
+ IMAGENET_DEFAULT_MEAN,
41
+ IMAGENET_DEFAULT_STD,
42
+ IMAGENET_STANDARD_MEAN,
43
+ IMAGENET_STANDARD_STD,
44
+ OPENAI_CLIP_MEAN,
45
+ OPENAI_CLIP_STD,
46
+ )
47
+
48
+
49
+ if is_vision_available():
50
+ import PIL.Image
51
+ import PIL.ImageOps
52
+
53
+ if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
54
+ PILImageResampling = PIL.Image.Resampling
55
+ else:
56
+ PILImageResampling = PIL.Image
57
+
58
+ if is_torchvision_available():
59
+ from torchvision.transforms import InterpolationMode
60
+
61
+ pil_torch_interpolation_mapping = {
62
+ PILImageResampling.NEAREST: InterpolationMode.NEAREST,
63
+ PILImageResampling.BOX: InterpolationMode.BOX,
64
+ PILImageResampling.BILINEAR: InterpolationMode.BILINEAR,
65
+ PILImageResampling.HAMMING: InterpolationMode.HAMMING,
66
+ PILImageResampling.BICUBIC: InterpolationMode.BICUBIC,
67
+ PILImageResampling.LANCZOS: InterpolationMode.LANCZOS,
68
+ }
69
+
70
+
71
+ if TYPE_CHECKING:
72
+ if is_torch_available():
73
+ import torch
74
+
75
+
76
+ logger = logging.get_logger(__name__)
77
+
78
+
79
# Any single image or flat list of images accepted by the image processors.
ImageInput = Union[
    "PIL.Image.Image", np.ndarray, "torch.Tensor", List["PIL.Image.Image"], List[np.ndarray], List["torch.Tensor"]
]  # noqa


# A video is a sequence of frames, or a batch of such sequences, in any supported framework.
VideoInput = Union[
    List["PIL.Image.Image"],
    "np.ndarray",
    "torch.Tensor",
    List["np.ndarray"],
    List["torch.Tensor"],
    List[List["PIL.Image.Image"]],
    List[List["np.ndarray"]],  # fixed forward-ref typo: was "np.ndarrray"
    List[List["torch.Tensor"]],
]  # noqa
94
+
95
+
96
class ChannelDimension(ExplicitEnum):
    """Axis layout of an image array: `FIRST` -> (num_channels, height, width), `LAST` -> (height, width, num_channels)."""

    FIRST = "channels_first"
    LAST = "channels_last"
99
+
100
+
101
class AnnotationFormat(ExplicitEnum):
    """Supported COCO-style annotation format identifiers."""

    COCO_DETECTION = "coco_detection"
    COCO_PANOPTIC = "coco_panoptic"
104
+
105
+
106
class AnnotionFormat(ExplicitEnum):
    """
    Misspelled variant of `AnnotationFormat` whose members mirror its values — presumably kept
    for backward compatibility with code referencing the old name.
    """

    COCO_DETECTION = AnnotationFormat.COCO_DETECTION.value
    COCO_PANOPTIC = AnnotationFormat.COCO_PANOPTIC.value
109
+
110
+
111
# Type alias for a single annotation record: a dict whose values may be ids, strings, or lists of nested dicts.
AnnotationType = Dict[str, Union[int, str, List[Dict]]]
112
+
113
+
114
def is_pil_image(img):
    """Return True if `img` is a `PIL.Image.Image` (and the vision backend is available)."""
    if not is_vision_available():
        # PIL may not even be importable; without it nothing can be a PIL image.
        return False
    return isinstance(img, PIL.Image.Image)
116
+
117
+
118
class ImageType(ExplicitEnum):
    """Identifier for the framework/library an image object comes from (see `get_image_type`)."""

    PIL = "pillow"
    TORCH = "torch"
    NUMPY = "numpy"
    TENSORFLOW = "tensorflow"
    JAX = "jax"
124
+
125
+
126
def get_image_type(image):
    """Classify `image` into an `ImageType` member based on which framework it belongs to."""
    checks = (
        (is_pil_image, ImageType.PIL),
        (is_torch_tensor, ImageType.TORCH),
        (is_numpy_array, ImageType.NUMPY),
        (is_tf_tensor, ImageType.TENSORFLOW),
        (is_jax_tensor, ImageType.JAX),
    )
    for predicate, image_type in checks:
        if predicate(image):
            return image_type
    raise ValueError(f"Unrecognised image type {type(image)}")
138
+
139
+
140
def is_valid_image(img):
    """Return True if `img` belongs to any supported image framework (PIL, numpy, torch, TF, JAX)."""
    return any(
        check(img) for check in (is_pil_image, is_numpy_array, is_torch_tensor, is_tf_tensor, is_jax_tensor)
    )
142
+
143
+
144
def valid_images(imgs):
    """Recursively validate that `imgs` is a valid image, or an arbitrarily nested list/tuple of valid images."""
    if isinstance(imgs, (list, tuple)):
        # Every element must validate; an empty container is vacuously valid.
        return all(valid_images(img) for img in imgs)
    # Not a list or tuple: a single image or a batched tensor of images.
    return is_valid_image(imgs)
154
+
155
+
156
def is_batched(img):
    """Return True if `img` is a list/tuple batch, i.e. its first element is itself a valid image."""
    if not isinstance(img, (list, tuple)):
        return False
    return is_valid_image(img[0])
160
+
161
+
162
def is_scaled_image(image: np.ndarray) -> bool:
    """
    Checks to see whether the pixel values have already been rescaled to [0, 1].
    """
    if image.dtype == np.uint8:
        # uint8 images are by convention raw [0, 255] data.
        return False

    # It's possible the image has pixel values in [0, 255] but is of floating type
    lo, hi = np.min(image), np.max(image)
    return lo >= 0 and hi <= 1
171
+
172
+
173
def make_list_of_images(images, expected_ndims: int = 3) -> List[ImageInput]:
    """
    Ensure that the input is a list of images. If the input is a single image, it is converted to a list of length 1.
    If the input is a batch of images, it is converted to a list of images.

    Args:
        images (`ImageInput`):
            Image or images to turn into a list of images.
        expected_ndims (`int`, *optional*, defaults to 3):
            Expected number of dimensions for a single input image. If the input image has a different number of
            dimensions, an error is raised.
    """
    if is_batched(images):
        return images

    # PIL images are never batched, so a single image becomes a singleton list.
    if isinstance(images, PIL.Image.Image):
        return [images]

    if not is_valid_image(images):
        raise ValueError(
            "Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or "
            f"jax.ndarray, but got {type(images)}."
        )

    if images.ndim == expected_ndims + 1:
        # A batched array/tensor: split along the leading (batch) dimension.
        return list(images)
    if images.ndim == expected_ndims:
        return [images]
    raise ValueError(
        f"Invalid image shape. Expected either {expected_ndims + 1} or {expected_ndims} dimensions, but got"
        f" {images.ndim} dimensions."
    )
210
+
211
+
212
def to_numpy_array(img) -> np.ndarray:
    """Convert a valid image (PIL, numpy, torch, TF or JAX) to a numpy array."""
    if is_valid_image(img):
        # PIL images need an explicit conversion; all other supported types go through `to_numpy`.
        if is_vision_available() and isinstance(img, PIL.Image.Image):
            return np.array(img)
        return to_numpy(img)
    raise ValueError(f"Invalid image type: {type(img)}")
219
+
220
+
221
def infer_channel_dimension_format(
    image: np.ndarray, num_channels: Optional[Union[int, Tuple[int, ...]]] = None
) -> ChannelDimension:
    """
    Infers the channel dimension format of `image`.

    Args:
        image (`np.ndarray`):
            The image to infer the channel dimension of.
        num_channels (`int` or `Tuple[int, ...]`, *optional*, defaults to `(1, 3)`):
            The number of channels of the image.

    Returns:
        The channel dimension of the image.
    """
    # Normalize `num_channels` to a tuple of candidate channel counts.
    if num_channels is None:
        num_channels = (1, 3)
    if isinstance(num_channels, int):
        num_channels = (num_channels,)

    if image.ndim == 3:
        first_dim, last_dim = 0, 2
    elif image.ndim == 4:
        # Batched images: skip the leading batch dimension.
        first_dim, last_dim = 1, 3
    else:
        raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")

    first_matches = image.shape[first_dim] in num_channels
    last_matches = image.shape[last_dim] in num_channels

    if first_matches and last_matches:
        # Both candidate axes could be channels (e.g. a 3x3 image) — default to channels-first.
        logger.warning(
            f"The channel dimension is ambiguous. Got image shape {image.shape}. Assuming channels are the first dimension."
        )
        return ChannelDimension.FIRST
    if first_matches:
        return ChannelDimension.FIRST
    if last_matches:
        return ChannelDimension.LAST
    raise ValueError("Unable to infer channel dimension format")
256
+
257
+
258
def get_channel_dimension_axis(
    image: np.ndarray, input_data_format: Optional[Union[ChannelDimension, str]] = None
) -> int:
    """
    Returns the channel dimension axis of the image.

    Args:
        image (`np.ndarray`):
            The image to get the channel dimension axis of.
        input_data_format (`ChannelDimension` or `str`, *optional*):
            The channel dimension format of the image. If `None`, will infer the channel dimension from the image.

    Returns:
        The channel dimension axis of the image.
    """
    # Fall back to inference when the caller did not specify the layout.
    data_format = input_data_format if input_data_format is not None else infer_channel_dimension_format(image)
    if data_format == ChannelDimension.FIRST:
        # Layout (..., C, H, W): channels sit three axes from the end.
        return image.ndim - 3
    if data_format == ChannelDimension.LAST:
        # Layout (..., H, W, C): channels are the trailing axis.
        return image.ndim - 1
    raise ValueError(f"Unsupported data format: {data_format}")
280
+
281
+
282
def get_image_size(image: np.ndarray, channel_dim: ChannelDimension = None) -> Tuple[int, int]:
    """
    Returns the (height, width) dimensions of the image.

    Args:
        image (`np.ndarray`):
            The image to get the dimensions of.
        channel_dim (`ChannelDimension`, *optional*):
            Which dimension the channel dimension is in. If `None`, will infer the channel dimension from the image.

    Returns:
        A tuple of the image's height and width.
    """
    if channel_dim is None:
        channel_dim = infer_channel_dimension_format(image)

    if channel_dim == ChannelDimension.FIRST:
        # Layout (..., C, H, W)
        return image.shape[-2], image.shape[-1]
    if channel_dim == ChannelDimension.LAST:
        # Layout (..., H, W, C)
        return image.shape[-3], image.shape[-2]
    raise ValueError(f"Unsupported data format: {channel_dim}")
304
+
305
+
306
def is_valid_annotation_coco_detection(annotation: Dict[str, Union[List, Tuple]]) -> bool:
    """Return `True` if `annotation` has the structure of a COCO-detection annotation dict."""
    if not isinstance(annotation, dict):
        return False
    if "image_id" not in annotation or "annotations" not in annotation:
        return False
    entries = annotation["annotations"]
    if not isinstance(entries, (list, tuple)):
        return False
    # An image can have no annotations; otherwise the first entry must be a dict.
    return len(entries) == 0 or isinstance(entries[0], dict)
319
+
320
+
321
def is_valid_annotation_coco_panoptic(annotation: Dict[str, Union[List, Tuple]]) -> bool:
    """Return `True` if `annotation` has the structure of a COCO-panoptic annotation dict."""
    if not isinstance(annotation, dict):
        return False
    if any(key not in annotation for key in ("image_id", "segments_info", "file_name")):
        return False
    segments = annotation["segments_info"]
    if not isinstance(segments, (list, tuple)):
        return False
    # An image can have no segments; otherwise the first segment must be a dict.
    return len(segments) == 0 or isinstance(segments[0], dict)
335
+
336
+
337
def valid_coco_detection_annotations(annotations: Iterable[Dict[str, Union[List, Tuple]]]) -> bool:
    """Return `True` when every annotation in `annotations` is in valid COCO-detection format."""
    for annotation in annotations:
        if not is_valid_annotation_coco_detection(annotation):
            return False
    return True
339
+
340
+
341
def valid_coco_panoptic_annotations(annotations: Iterable[Dict[str, Union[List, Tuple]]]) -> bool:
    """Return `True` when every annotation in `annotations` is in valid COCO-panoptic format."""
    for annotation in annotations:
        if not is_valid_annotation_coco_panoptic(annotation):
            return False
    return True
343
+
344
+
345
def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] = None) -> "PIL.Image.Image":
    """
    Loads `image` to a PIL Image.

    Args:
        image (`str` or `PIL.Image.Image`):
            The image to convert to the PIL Image format. Strings may be an http(s) URL,
            a local file path, a `data:image/...` URI, or raw base64 data.
        timeout (`float`, *optional*):
            The timeout value in seconds for the URL request.

    Returns:
        `PIL.Image.Image`: A PIL Image (EXIF-transposed and converted to RGB).
    """
    requires_backends(load_image, ["vision"])
    if isinstance(image, PIL.Image.Image):
        pil_image = image
    elif isinstance(image, str):
        if image.startswith(("http://", "https://")):
            # We need to actually check for a real protocol, otherwise it's impossible to use a local file
            # like http_huggingface_co.png
            pil_image = PIL.Image.open(BytesIO(requests.get(image, timeout=timeout).content))
        elif os.path.isfile(image):
            pil_image = PIL.Image.open(image)
        else:
            # Strip the data-URI header, if present, keeping only the base64 payload.
            if image.startswith("data:image/"):
                image = image.split(",")[1]

            # Try to load as base64
            try:
                b64 = base64.decodebytes(image.encode())
                pil_image = PIL.Image.open(BytesIO(b64))
            except Exception as e:
                raise ValueError(
                    f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path to an image file, or a base64 encoded string. Got {image}. Failed with {e}"
                )
    else:
        raise TypeError(
            "Incorrect format used for image. Should be an url linking to an image, a base64 string, a local path, or a PIL image."
        )
    # Honor EXIF orientation, then normalize the mode for downstream processors.
    pil_image = PIL.ImageOps.exif_transpose(pil_image)
    return pil_image.convert("RGB")
387
+
388
+
389
def load_images(
    images: Union[List, Tuple, str, "PIL.Image.Image"], timeout: Optional[float] = None
) -> Union["PIL.Image.Image", List["PIL.Image.Image"], List[List["PIL.Image.Image"]]]:
    """Loads images, handling different levels of nesting.

    Args:
        images: A single image, a list of images, or a list of lists of images to load.
        timeout: Timeout for loading images.

    Returns:
        A single image, a list of images, or a list of lists of images, mirroring the input nesting.
    """
    if not isinstance(images, (list, tuple)):
        # Single image (path, URL, base64 string or PIL image).
        return load_image(images, timeout=timeout)
    if len(images) and isinstance(images[0], (list, tuple)):
        # Doubly-nested input: a list of image groups.
        return [[load_image(entry, timeout=timeout) for entry in group] for group in images]
    # Flat list of images.
    return [load_image(entry, timeout=timeout) for entry in images]
408
+
409
+
410
def validate_preprocess_arguments(
    do_rescale: Optional[bool] = None,
    rescale_factor: Optional[float] = None,
    do_normalize: Optional[bool] = None,
    image_mean: Optional[Union[float, List[float]]] = None,
    image_std: Optional[Union[float, List[float]]] = None,
    do_pad: Optional[bool] = None,
    size_divisibility: Optional[int] = None,
    do_center_crop: Optional[bool] = None,
    crop_size: Optional[Dict[str, int]] = None,
    do_resize: Optional[bool] = None,
    size: Optional[Dict[str, int]] = None,
    resample: Optional["PILImageResampling"] = None,
):
    """
    Checks validity of typically used arguments in an `ImageProcessor` `preprocess` method.
    Raises `ValueError` if arguments incompatibility is caught.
    Many incompatibilities are model-specific. `do_pad` sometimes needs `size_divisor`,
    sometimes `size_divisibility`, and sometimes `size`. New models and processors added should follow
    existing arguments when possible.

    """
    # Each entry pairs a failing condition with its error message; the first failing
    # check raises, preserving the historic check order.
    checks = (
        (do_rescale and rescale_factor is None, "`rescale_factor` must be specified if `do_rescale` is `True`."),
        (
            # Here, size_divisor might be passed as the value of size
            do_pad and size_divisibility is None,
            "Depending on the model, `size_divisibility`, `size_divisor`, `pad_size` or `size` must be specified if `do_pad` is `True`.",
        ),
        (
            do_normalize and (image_mean is None or image_std is None),
            "`image_mean` and `image_std` must both be specified if `do_normalize` is `True`.",
        ),
        (do_center_crop and crop_size is None, "`crop_size` must be specified if `do_center_crop` is `True`."),
        (
            do_resize and (size is None or resample is None),
            "`size` and `resample` must be specified if `do_resize` is `True`.",
        ),
    )
    for failed, message in checks:
        if failed:
            raise ValueError(message)
449
+
450
+
451
def validate_fast_preprocess_arguments(
    do_rescale: Optional[bool] = None,
    rescale_factor: Optional[float] = None,
    do_normalize: Optional[bool] = None,
    image_mean: Optional[Union[float, List[float]]] = None,
    image_std: Optional[Union[float, List[float]]] = None,
    do_pad: Optional[bool] = None,
    size_divisibility: Optional[int] = None,
    do_center_crop: Optional[bool] = None,
    crop_size: Optional[Dict[str, int]] = None,
    do_resize: Optional[bool] = None,
    size: Optional[Dict[str, int]] = None,
    resample: Optional["PILImageResampling"] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
):
    """
    Checks validity of typically used arguments in an `ImageProcessorFast` `preprocess` method.
    Raises `ValueError` if arguments incompatibility is caught.
    """
    # NOTE(review): `do_pad`/`size_divisibility` and `do_center_crop`/`crop_size` are accepted
    # (for signature parity with `validate_preprocess_arguments`) but are NOT forwarded below —
    # confirm whether the fast processors validate padding/cropping elsewhere.
    validate_preprocess_arguments(
        do_rescale=do_rescale,
        rescale_factor=rescale_factor,
        do_normalize=do_normalize,
        image_mean=image_mean,
        image_std=image_std,
        do_resize=do_resize,
        size=size,
        resample=resample,
    )
    # Extra checks for ImageProcessorFast
    # Note this rejects `return_tensors=None` too: the fast path always returns torch tensors.
    if return_tensors != "pt":
        raise ValueError("Only returning PyTorch tensors is currently supported.")

    if data_format != ChannelDimension.FIRST:
        raise ValueError("Only channel first data format is currently supported.")
487
+
488
+
489
+ # In the future we can add a TF implementation here when we have TF models.
490
# In the future we can add a TF implementation here when we have TF models.
class ImageFeatureExtractionMixin:
    """
    Mixin that contain utilities for preparing image features.

    All public methods accept `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` inputs and
    validate the input type via `_ensure_format_supported` before processing.
    """

    def _ensure_format_supported(self, image):
        # Single gatekeeper used by every public method below.
        if not isinstance(image, (PIL.Image.Image, np.ndarray)) and not is_torch_tensor(image):
            raise ValueError(
                f"Got type {type(image)} which is not supported, only `PIL.Image.Image`, `np.array` and "
                "`torch.Tensor` are."
            )

    def to_pil_image(self, image, rescale=None):
        """
        Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
        needed.

        Args:
            image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor`):
                The image to convert to the PIL Image format.
            rescale (`bool`, *optional*):
                Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will
                default to `True` if the image type is a floating type, `False` otherwise.
        """
        self._ensure_format_supported(image)

        # Torch tensors go through numpy on the way to PIL.
        if is_torch_tensor(image):
            image = image.numpy()

        if isinstance(image, np.ndarray):
            if rescale is None:
                # rescale defaults to the array being of floating type.
                rescale = isinstance(image.flat[0], np.floating)
            # If the channel has been moved to first dim, we put it back at the end.
            if image.ndim == 3 and image.shape[0] in [1, 3]:
                image = image.transpose(1, 2, 0)
            if rescale:
                image = image * 255
            # PIL expects uint8 pixel data.
            image = image.astype(np.uint8)
            return PIL.Image.fromarray(image)
        # Already a PIL image: returned unchanged.
        return image

    def convert_rgb(self, image):
        """
        Converts `PIL.Image.Image` to RGB format.

        Args:
            image (`PIL.Image.Image`):
                The image to convert.
        """
        self._ensure_format_supported(image)
        # Non-PIL inputs are passed through untouched.
        if not isinstance(image, PIL.Image.Image):
            return image

        return image.convert("RGB")

    def rescale(self, image: np.ndarray, scale: Union[float, int]) -> np.ndarray:
        """
        Rescale a numpy image by scale amount
        """
        self._ensure_format_supported(image)
        return image * scale

    def to_numpy_array(self, image, rescale=None, channel_first=True):
        """
        Converts `image` to a numpy array. Optionally rescales it and puts the channel dimension as the first
        dimension.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image to convert to a NumPy array.
            rescale (`bool`, *optional*):
                Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.). Will
                default to `True` if the image is a PIL Image or an array/tensor of integers, `False` otherwise.
            channel_first (`bool`, *optional*, defaults to `True`):
                Whether or not to permute the dimensions of the image to put the channel dimension first.
        """
        self._ensure_format_supported(image)

        if isinstance(image, PIL.Image.Image):
            image = np.array(image)

        if is_torch_tensor(image):
            image = image.numpy()

        # Integer pixel data (including PIL, just converted above) defaults to being rescaled to [0, 1].
        rescale = isinstance(image.flat[0], np.integer) if rescale is None else rescale

        if rescale:
            image = self.rescale(image.astype(np.float32), 1 / 255.0)

        # Move channels from last to first axis (HWC -> CHW) when requested.
        if channel_first and image.ndim == 3:
            image = image.transpose(2, 0, 1)

        return image

    def expand_dims(self, image):
        """
        Expands 2-dimensional `image` to 3 dimensions.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image to expand.
        """
        self._ensure_format_supported(image)

        # Do nothing if PIL image
        if isinstance(image, PIL.Image.Image):
            return image

        if is_torch_tensor(image):
            image = image.unsqueeze(0)
        else:
            image = np.expand_dims(image, axis=0)
        return image

    def normalize(self, image, mean, std, rescale=False):
        """
        Normalizes `image` with `mean` and `std`. Note that this will trigger a conversion of `image` to a NumPy array
        if it's a PIL Image.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image to normalize.
            mean (`List[float]` or `np.ndarray` or `torch.Tensor`):
                The mean (per channel) to use for normalization.
            std (`List[float]` or `np.ndarray` or `torch.Tensor`):
                The standard deviation (per channel) to use for normalization.
            rescale (`bool`, *optional*, defaults to `False`):
                Whether or not to rescale the image to be between 0 and 1. If a PIL image is provided, scaling will
                happen automatically.
        """
        self._ensure_format_supported(image)

        if isinstance(image, PIL.Image.Image):
            image = self.to_numpy_array(image, rescale=True)
        # If the input image is a PIL image, it automatically gets rescaled. If it's another
        # type it may need rescaling.
        elif rescale:
            if isinstance(image, np.ndarray):
                image = self.rescale(image.astype(np.float32), 1 / 255.0)
            elif is_torch_tensor(image):
                image = self.rescale(image.float(), 1 / 255.0)

        # Coerce mean/std to the same framework (and dtype, for numpy) as the image.
        if isinstance(image, np.ndarray):
            if not isinstance(mean, np.ndarray):
                mean = np.array(mean).astype(image.dtype)
            if not isinstance(std, np.ndarray):
                std = np.array(std).astype(image.dtype)
        elif is_torch_tensor(image):
            import torch

            if not isinstance(mean, torch.Tensor):
                if isinstance(mean, np.ndarray):
                    mean = torch.from_numpy(mean)
                else:
                    mean = torch.tensor(mean)
            if not isinstance(std, torch.Tensor):
                if isinstance(std, np.ndarray):
                    std = torch.from_numpy(std)
                else:
                    std = torch.tensor(std)

        # Channel-first images need mean/std broadcast over the spatial axes.
        if image.ndim == 3 and image.shape[0] in [1, 3]:
            return (image - mean[:, None, None]) / std[:, None, None]
        else:
            return (image - mean) / std

    def resize(self, image, size, resample=None, default_to_square=True, max_size=None):
        """
        Resizes `image`. Enforces conversion of input to PIL.Image.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image to resize.
            size (`int` or `Tuple[int, int]`):
                The size to use for resizing the image. If `size` is a sequence like (h, w), output size will be
                matched to this.

                If `size` is an int and `default_to_square` is `True`, then image will be resized to (size, size). If
                `size` is an int and `default_to_square` is `False`, then smaller edge of the image will be matched to
                this number. i.e, if height > width, then image will be rescaled to (size * height / width, size).
            resample (`int`, *optional*, defaults to `PILImageResampling.BILINEAR`):
                The filter to user for resampling.
            default_to_square (`bool`, *optional*, defaults to `True`):
                How to convert `size` when it is a single int. If set to `True`, the `size` will be converted to a
                square (`size`,`size`). If set to `False`, will replicate
                [`torchvision.transforms.Resize`](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Resize)
                with support for resizing only the smallest edge and providing an optional `max_size`.
            max_size (`int`, *optional*, defaults to `None`):
                The maximum allowed for the longer edge of the resized image: if the longer edge of the image is
                greater than `max_size` after being resized according to `size`, then the image is resized again so
                that the longer edge is equal to `max_size`. As a result, `size` might be overruled, i.e the smaller
                edge may be shorter than `size`. Only used if `default_to_square` is `False`.

        Returns:
            image: A resized `PIL.Image.Image`.
        """
        resample = resample if resample is not None else PILImageResampling.BILINEAR

        self._ensure_format_supported(image)

        if not isinstance(image, PIL.Image.Image):
            image = self.to_pil_image(image)

        if isinstance(size, list):
            size = tuple(size)

        if isinstance(size, int) or len(size) == 1:
            if default_to_square:
                size = (size, size) if isinstance(size, int) else (size[0], size[0])
            else:
                width, height = image.size
                # specified size only for the smallest edge
                short, long = (width, height) if width <= height else (height, width)
                requested_new_short = size if isinstance(size, int) else size[0]

                if short == requested_new_short:
                    # Smallest edge already matches: no resize needed.
                    return image

                # Scale the long edge proportionally to the new short edge.
                new_short, new_long = requested_new_short, int(requested_new_short * long / short)

                if max_size is not None:
                    if max_size <= requested_new_short:
                        raise ValueError(
                            f"max_size = {max_size} must be strictly greater than the requested "
                            f"size for the smaller edge size = {size}"
                        )
                    if new_long > max_size:
                        # Cap the long edge, shrinking the short edge below the request if needed.
                        new_short, new_long = int(max_size * new_short / new_long), max_size

                # PIL.Image.resize expects (width, height).
                size = (new_short, new_long) if width <= height else (new_long, new_short)

        return image.resize(size, resample=resample)

    def center_crop(self, image, size):
        """
        Crops `image` to the given size using a center crop. Note that if the image is too small to be cropped to the
        size given, it will be padded (so the returned result has the size asked).

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape (n_channels, height, width) or (height, width, n_channels)):
                The image to resize.
            size (`int` or `Tuple[int, int]`):
                The size to which crop the image.

        Returns:
            new_image: A center cropped `PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape: (n_channels,
            height, width).
        """
        self._ensure_format_supported(image)

        if not isinstance(size, tuple):
            size = (size, size)

        # PIL Image.size is (width, height) but NumPy array and torch Tensors have (height, width)
        if is_torch_tensor(image) or isinstance(image, np.ndarray):
            if image.ndim == 2:
                # Promote grayscale (H, W) to (1, H, W) so the channel checks below work.
                image = self.expand_dims(image)
            image_shape = image.shape[1:] if image.shape[0] in [1, 3] else image.shape[:2]
        else:
            image_shape = (image.size[1], image.size[0])

        top = (image_shape[0] - size[0]) // 2
        bottom = top + size[0]  # In case size is odd, (image_shape[0] + size[0]) // 2 won't give the proper result.
        left = (image_shape[1] - size[1]) // 2
        right = left + size[1]  # In case size is odd, (image_shape[1] + size[1]) // 2 won't give the proper result.

        # For PIL Images we have a method to crop directly.
        if isinstance(image, PIL.Image.Image):
            return image.crop((left, top, right, bottom))

        # Check if image is in (n_channels, height, width) or (height, width, n_channels) format
        channel_first = True if image.shape[0] in [1, 3] else False

        # Transpose (height, width, n_channels) format images
        if not channel_first:
            if isinstance(image, np.ndarray):
                image = image.transpose(2, 0, 1)
            if is_torch_tensor(image):
                image = image.permute(2, 0, 1)

        # Check if cropped area is within image boundaries
        if top >= 0 and bottom <= image_shape[0] and left >= 0 and right <= image_shape[1]:
            return image[..., top:bottom, left:right]

        # Otherwise, we may need to pad if the image is too small. Oh joy...
        new_shape = image.shape[:-2] + (max(size[0], image_shape[0]), max(size[1], image_shape[1]))
        if isinstance(image, np.ndarray):
            new_image = np.zeros_like(image, shape=new_shape)
        elif is_torch_tensor(image):
            new_image = image.new_zeros(new_shape)

        # Center the original image inside the zero-padded canvas.
        top_pad = (new_shape[-2] - image_shape[0]) // 2
        bottom_pad = top_pad + image_shape[0]
        left_pad = (new_shape[-1] - image_shape[1]) // 2
        right_pad = left_pad + image_shape[1]
        new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image

        # Shift the crop window into the padded canvas's coordinates.
        top += top_pad
        bottom += top_pad
        left += left_pad
        right += left_pad

        new_image = new_image[
            ..., max(0, top) : min(new_image.shape[-2], bottom), max(0, left) : min(new_image.shape[-1], right)
        ]

        return new_image

    def flip_channel_order(self, image):
        """
        Flips the channel order of `image` from RGB to BGR, or vice versa. Note that this will trigger a conversion of
        `image` to a NumPy array if it's a PIL Image.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image whose color channels to flip. If `np.ndarray` or `torch.Tensor`, the channel dimension should
                be first.
        """
        self._ensure_format_supported(image)

        if isinstance(image, PIL.Image.Image):
            image = self.to_numpy_array(image)

        # Reverses the (leading) channel axis only.
        return image[::-1, :, :]

    def rotate(self, image, angle, resample=None, expand=0, center=None, translate=None, fillcolor=None):
        """
        Returns a rotated copy of `image`. This method returns a copy of `image`, rotated the given number of degrees
        counter clockwise around its centre.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image to rotate. If `np.ndarray` or `torch.Tensor`, will be converted to `PIL.Image.Image` before
                rotating.

        Returns:
            image: A rotated `PIL.Image.Image`.
        """
        resample = resample if resample is not None else PIL.Image.NEAREST

        self._ensure_format_supported(image)

        if not isinstance(image, PIL.Image.Image):
            image = self.to_pil_image(image)

        # Remaining keyword arguments are forwarded verbatim to PIL.Image.rotate.
        return image.rotate(
            angle, resample=resample, expand=expand, center=center, translate=translate, fillcolor=fillcolor
        )
839
+
840
+
841
def validate_annotations(
    annotation_format: AnnotationFormat,
    supported_annotation_formats: Tuple[AnnotationFormat, ...],
    annotations: List[Dict],
) -> None:
    """
    Validate that `annotations` match the structure required by `annotation_format`.

    Args:
        annotation_format (`AnnotationFormat`):
            The annotation format expected by the processor.
        supported_annotation_formats (`Tuple[AnnotationFormat, ...]`):
            The annotation formats this processor supports.
        annotations (`List[Dict]`):
            The annotations to validate.

    Raises:
        ValueError: If `annotation_format` is not supported, or the annotations do not
            match the expected COCO detection/panoptic structure.
    """
    if annotation_format not in supported_annotation_formats:
        # Bug fix: this message previously interpolated the *builtin* `format` function
        # (rendering as "<built-in function format>") instead of the argument.
        raise ValueError(
            f"Unsupported annotation format: {annotation_format} must be one of {supported_annotation_formats}"
        )

    if annotation_format is AnnotationFormat.COCO_DETECTION:
        if not valid_coco_detection_annotations(annotations):
            raise ValueError(
                "Invalid COCO detection annotations. Annotations must a dict (single image) or list of dicts "
                "(batch of images) with the following keys: `image_id` and `annotations`, with the latter "
                "being a list of annotations in the COCO format."
            )

    if annotation_format is AnnotationFormat.COCO_PANOPTIC:
        if not valid_coco_panoptic_annotations(annotations):
            raise ValueError(
                "Invalid COCO panoptic annotations. Annotations must a dict (single image) or list of dicts "
                "(batch of images) with the following keys: `image_id`, `file_name` and `segments_info`, with "
                "the latter being a list of annotations in the COCO format."
            )
864
+
865
+
866
def validate_kwargs(valid_processor_keys: List[str], captured_kwargs: List[str]):
    """Log a warning listing any captured kwargs the processor does not recognize."""
    unused_keys = set(captured_kwargs) - set(valid_processor_keys)
    if not unused_keys:
        return
    unused_key_str = ", ".join(unused_keys)
    # TODO raise a warning here instead of simply logging?
    logger.warning(f"Unused or unrecognized kwargs: {unused_key_str}.")
.venv/lib/python3.11/site-packages/transformers/keras_callbacks.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from pathlib import Path
4
+ from time import sleep
5
+ from typing import Callable, List, Optional, Union
6
+
7
+ import numpy as np
8
+ import tensorflow as tf
9
+ from huggingface_hub import Repository, create_repo
10
+ from packaging.version import parse
11
+
12
+ from . import IntervalStrategy, PreTrainedTokenizerBase
13
+ from .modelcard import TrainingSummary
14
+ from .modeling_tf_utils import keras
15
+
16
+
17
# Module-level logger for this file (stdlib `logging`, keyed by module name).
logger = logging.getLogger(__name__)
18
+
19
+
20
+ class KerasMetricCallback(keras.callbacks.Callback):
21
+ """
22
+ Callback to compute metrics at the end of every epoch. Unlike normal Keras metrics, these do not need to be
23
+ compilable by TF. It is particularly useful for common NLP metrics like BLEU and ROUGE that require string
24
+ operations or generation loops that cannot be compiled. Predictions (or generations) will be computed on the
25
+ `eval_dataset` before being passed to the `metric_fn` in `np.ndarray` format. The `metric_fn` should compute
26
+ metrics and return a dict mapping metric names to metric values.
27
+
28
+ We provide an example of a suitable metric_fn that computes ROUGE scores for a summarization model below. Note that
29
+ this example skips some post-processing for readability and simplicity, and should probably not be used as-is!
30
+
31
+ ```py
32
+ from datasets import load_metric
33
+
34
+ rouge_metric = load_metric("rouge")
35
+
36
+
37
+ def rouge_fn(predictions, labels):
38
+ decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
39
+ decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
40
+ result = rouge_metric.compute(predictions=decoded_predictions, references=decoded_labels)
41
+ return {key: value.mid.fmeasure * 100 for key, value in result.items()}
42
+ ```
43
+
44
+ The above function will return a dict containing values which will be logged like any other Keras metric:
45
+
46
+ ```
47
+ {'rouge1': 37.4199, 'rouge2': 13.9768, 'rougeL': 34.361, 'rougeLsum': 35.0781
48
+ ```
49
+
50
+ Args:
51
+ metric_fn (`Callable`):
52
+ Metric function provided by the user. It will be called with two arguments - `predictions` and `labels`.
53
+ These contain the model's outputs and matching labels from the dataset. It should return a dict mapping
54
+ metric names to numerical values.
55
+ eval_dataset (`tf.data.Dataset` or `dict` or `tuple` or `np.ndarray` or `tf.Tensor`):
56
+ Validation data to be used to generate predictions for the `metric_fn`.
57
+ output_cols (`List[str], *optional*):
58
+ A list of columns to be retained from the model output as the predictions. Defaults to all.
59
+ label_cols ('`List[str]`, *optional*'):
60
+ A list of columns to be retained from the input dataset as the labels. Will be autodetected if this is not
61
+ supplied.
62
+ batch_size (`int`, *optional*):
63
+ Batch size. Only used when the data is not a pre-batched `tf.data.Dataset`.
64
+ predict_with_generate (`bool`, *optional*, defaults to `False`):
65
+ Whether we should use `model.generate()` to get outputs for the model.
66
+ use_xla_generation (`bool`, *optional*, defaults to `False`):
67
+ If we're generating, whether to compile model generation with XLA. This can massively increase the speed of
68
+ generation (up to 100X speedup) but will require a new XLA compilation for each input shape. When using XLA
69
+ generation, it's a good idea to pad your inputs to the same size, or to use the `pad_to_multiple_of`
70
+ argument in your `tokenizer` or `DataCollator`, which will reduce the number of unique input shapes and
71
+ save a lot of compilation time. This option has no effect is `predict_with_generate` is `False`.
72
+ generate_kwargs (`dict`, *optional*):
73
+ Keyword arguments to pass to `model.generate()` when generating. Has no effect if `predict_with_generate`
74
+ is `False`.
75
+
76
+ """
77
+
78
+ def __init__(
79
+ self,
80
+ metric_fn: Callable,
81
+ eval_dataset: Union[tf.data.Dataset, np.ndarray, tf.Tensor, tuple, dict],
82
+ output_cols: Optional[List[str]] = None,
83
+ label_cols: Optional[List[str]] = None,
84
+ batch_size: Optional[int] = None,
85
+ predict_with_generate: bool = False,
86
+ use_xla_generation: bool = False,
87
+ generate_kwargs: Optional[dict] = None,
88
+ ):
89
+ super().__init__()
90
+ self.metric_fn = metric_fn
91
+ self.batch_size = batch_size
92
+ if not isinstance(eval_dataset, tf.data.Dataset):
93
+ if batch_size is None:
94
+ raise ValueError(
95
+ "When passing data to KerasMetricCallback that is not a pre-batched tf.data.Dataset "
96
+ "the batch_size argument must be set."
97
+ )
98
+ # Wrap a tf.data.Dataset around it
99
+ eval_dataset = tf.data.Dataset.from_tensor_slices(eval_dataset).batch(batch_size, drop_remainder=False)
100
+ self.eval_dataset = eval_dataset
101
+ self.predict_with_generate = predict_with_generate
102
+ self.output_cols = output_cols
103
+
104
+ # This next block attempts to parse out which elements of the dataset should be appended to the labels list
105
+ # that is passed to the metric_fn
106
+ if isinstance(eval_dataset.element_spec, tuple) and len(eval_dataset.element_spec) == 2:
107
+ input_spec, label_spec = eval_dataset.element_spec
108
+ else:
109
+ input_spec = eval_dataset.element_spec
110
+ label_spec = None
111
+ if label_cols is not None:
112
+ for label in label_cols:
113
+ if label not in input_spec:
114
+ raise ValueError(f"Label {label} is in label_cols but could not be found in the dataset inputs!")
115
+ self.label_cols = label_cols
116
+ self.use_keras_label = False
117
+ elif label_spec is not None:
118
+ # If the dataset inputs are split into a 2-tuple of inputs and labels,
119
+ # assume the second element is the labels
120
+ self.label_cols = None
121
+ self.use_keras_label = True
122
+ elif "labels" in input_spec:
123
+ self.label_cols = ["labels"]
124
+ self.use_keras_label = False
125
+ logging.warning("No label_cols specified for KerasMetricCallback, assuming you want the 'labels' key.")
126
+ elif "start_positions" in input_spec and "end_positions" in input_spec:
127
+ self.label_cols = ["start_positions", "end_positions"]
128
+ self.use_keras_label = False
129
+ logging.warning(
130
+ "No label_cols specified for KerasMetricCallback, assuming you want the "
131
+ "start_positions and end_positions keys."
132
+ )
133
+ else:
134
+ raise ValueError("Could not autodetect label_cols for KerasMetricCallback, please specify them!")
135
+ if parse(tf.__version__) < parse("2.7"):
136
+ logging.warning("TF versions less than 2.7 may encounter issues with KerasMetricCallback!")
137
+
138
+ self.use_xla_generation = use_xla_generation
139
+ self.generate_kwargs = {} if generate_kwargs is None else generate_kwargs
140
+
141
+ self.generation_function = None
142
+
143
+ @staticmethod
144
+ def _concatenate_batches(batches, padding_index=-100):
145
+ # If all batches are unidimensional or same length, do a simple concatenation
146
+ if batches[0].ndim == 1 or all(batch.shape[1] == batches[0].shape[1] for batch in batches):
147
+ return np.concatenate(batches, axis=0)
148
+
149
+ # Welp, they're not the same length. Let's do some padding
150
+ max_len = max([batch.shape[1] for batch in batches])
151
+ num_samples = sum([batch.shape[0] for batch in batches])
152
+ output = np.full_like(
153
+ batches[0], fill_value=padding_index, shape=[num_samples, max_len] + list(batches[0].shape[2:])
154
+ )
155
+ # i keeps track of which part of the concatenated array we're writing the next batch to
156
+ i = 0
157
+ for batch in batches:
158
+ output[i : i + len(batch), : batch.shape[1]] = batch
159
+ i += len(batch)
160
+ return output
161
+
162
+ def _postprocess_predictions_or_labels(self, inputs):
163
+ if isinstance(inputs[0], dict):
164
+ outputs = {}
165
+ for key in inputs[0].keys():
166
+ outputs[key] = self._concatenate_batches([batch[key] for batch in inputs])
167
+ # If it's a dict with only one key, just return the array
168
+ if len(outputs) == 1:
169
+ outputs = list(outputs.values())[0]
170
+ elif isinstance(inputs[0], list) or isinstance(inputs[0], tuple):
171
+ outputs = []
172
+ for input_list in zip(*inputs):
173
+ outputs.append(self._concatenate_batches(input_list))
174
+ if len(outputs) == 1:
175
+ outputs = outputs[0] # If it's a list with only one element, just return the array
176
+ elif isinstance(inputs[0], np.ndarray):
177
+ outputs = self._concatenate_batches(inputs)
178
+ elif isinstance(inputs[0], tf.Tensor):
179
+ outputs = self._concatenate_batches([tensor.numpy() for tensor in inputs])
180
+ else:
181
+ raise TypeError(f"Couldn't handle batch of type {type(inputs[0])}!")
182
+ return outputs
183
+
184
+ def on_epoch_end(self, epoch, logs=None):
185
+ if hasattr(self.model, "config"):
186
+ ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
187
+ else:
188
+ ignore_keys = []
189
+
190
+ main_input_name = None
191
+ if self.predict_with_generate:
192
+ # This dense conditional recognizes the case where we have an encoder-decoder model, but
193
+ # avoids getting tangled up when we just have a model with a layer called 'encoder'
194
+ if hasattr(self.model, "encoder") and hasattr(self.model.encoder, "main_input_name"):
195
+ main_input_name = self.model.encoder.main_input_name
196
+ else:
197
+ main_input_name = getattr(self.model, "main_input_name", "input_ids")
198
+
199
+ if self.use_xla_generation and self.generation_function is None:
200
+
201
+ def generation_function(inputs, attention_mask):
202
+ return self.model.generate(inputs, attention_mask=attention_mask, **self.generate_kwargs)
203
+
204
+ self.generation_function = tf.function(generation_function, jit_compile=True)
205
+
206
+ prediction_list = []
207
+ label_list = []
208
+
209
+ # The whole predict/generate loop is handled inside this method
210
+ for batch in self.eval_dataset:
211
+ if isinstance(batch, tuple):
212
+ batch, labels = batch
213
+ else:
214
+ labels = None
215
+ if self.predict_with_generate:
216
+ if isinstance(batch, dict):
217
+ generation_inputs = batch[main_input_name]
218
+ attention_mask = batch.get("attention_mask", None)
219
+ else:
220
+ generation_inputs = batch
221
+ attention_mask = None
222
+ if self.use_xla_generation:
223
+ predictions = self.generation_function(generation_inputs, attention_mask=attention_mask)
224
+ else:
225
+ predictions = self.model.generate(
226
+ generation_inputs, attention_mask=attention_mask, **self.generate_kwargs
227
+ )
228
+ else:
229
+ predictions = self.model.predict_on_batch(batch)
230
+ if isinstance(predictions, dict):
231
+ # This converts any dict-subclass to a regular dict
232
+ # Keras REALLY doesn't like it when we pass around a BatchEncoding or other derived class
233
+ predictions = dict(predictions)
234
+ if self.output_cols is not None:
235
+ predictions = {key: predictions[key] for key in self.output_cols}
236
+ else:
237
+ predictions = {
238
+ key: val for key, val in predictions.items() if key not in ignore_keys + ["loss"]
239
+ }
240
+ prediction_list.append(predictions)
241
+ if not self.use_keras_label:
242
+ labels = {key: batch[key].numpy() for key in self.label_cols}
243
+ elif isinstance(labels, dict):
244
+ labels = {key: array.numpy() for key, array in labels.items()}
245
+ elif isinstance(labels, list) or isinstance(labels, tuple):
246
+ labels = [array.numpy() for array in labels]
247
+ elif isinstance(labels, tf.Tensor):
248
+ labels = labels.numpy()
249
+ else:
250
+ raise TypeError(f"Confused by labels of type {type(labels)}")
251
+ label_list.append(labels)
252
+
253
+ all_preds = self._postprocess_predictions_or_labels(prediction_list)
254
+ all_labels = self._postprocess_predictions_or_labels(label_list)
255
+
256
+ metric_output = self.metric_fn((all_preds, all_labels))
257
+ if not isinstance(metric_output, dict):
258
+ raise TypeError(
259
+ f"metric_fn should return a dict mapping metric names to values but instead returned {metric_output}"
260
+ )
261
+ # This is the critical bit - Keras passes a dict containing the loss and standard metric values for this epoch
262
+ # in the logs argument. Ordinarily, this is so the callback can read them, but in this case we write a bunch of
263
+ # new keys in there, which will then get read by the History callback and treated like any other metric value.
264
+ # I promise that I have it in writing from Chollet that this is okay.
265
+ logs.update(metric_output)
266
+
267
+
268
+ class PushToHubCallback(keras.callbacks.Callback):
269
+ """
270
+ Callback that will save and push the model to the Hub regularly. By default, it pushes once per epoch, but this can
271
+ be changed with the `save_strategy` argument. Pushed models can be accessed like any other model on the hub, such
272
+ as with the `from_pretrained` method.
273
+
274
+ ```py
275
+ from transformers.keras_callbacks import PushToHubCallback
276
+
277
+ push_to_hub_callback = PushToHubCallback(
278
+ output_dir="./model_save",
279
+ tokenizer=tokenizer,
280
+ hub_model_id="gpt5-7xlarge",
281
+ )
282
+
283
+ model.fit(train_dataset, callbacks=[push_to_hub_callback])
284
+ ```
285
+
286
+ Args:
287
+ output_dir (`str`):
288
+ The output directory where the model predictions and checkpoints will be written and synced with the
289
+ repository on the Hub.
290
+ save_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"epoch"`):
291
+ The checkpoint save strategy to adopt during training. Possible values are:
292
+
293
+ - `"no"`: Save is done at the end of training.
294
+ - `"epoch"`: Save is done at the end of each epoch.
295
+ - `"steps"`: Save is done every `save_steps`
296
+ save_steps (`int`, *optional*):
297
+ The number of steps between saves when using the "steps" `save_strategy`.
298
+ tokenizer (`PreTrainedTokenizerBase`, *optional*):
299
+ The tokenizer used by the model. If supplied, will be uploaded to the repo alongside the weights.
300
+ hub_model_id (`str`, *optional*):
301
+ The name of the repository to keep in sync with the local `output_dir`. It can be a simple model ID in
302
+ which case the model will be pushed in your namespace. Otherwise it should be the whole repository name,
303
+ for instance `"user_name/model"`, which allows you to push to an organization you are a member of with
304
+ `"organization_name/model"`.
305
+
306
+ Will default to the name of `output_dir`.
307
+ hub_token (`str`, *optional*):
308
+ The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
309
+ `huggingface-cli login`.
310
+ checkpoint (`bool`, *optional*, defaults to `False`):
311
+ Whether to save full training checkpoints (including epoch and optimizer state) to allow training to be
312
+ resumed. Only usable when `save_strategy` is `"epoch"`.
313
+ """
314
+
315
+ def __init__(
316
+ self,
317
+ output_dir: Union[str, Path],
318
+ save_strategy: Union[str, IntervalStrategy] = "epoch",
319
+ save_steps: Optional[int] = None,
320
+ tokenizer: Optional[PreTrainedTokenizerBase] = None,
321
+ hub_model_id: Optional[str] = None,
322
+ hub_token: Optional[str] = None,
323
+ checkpoint: bool = False,
324
+ **model_card_args,
325
+ ):
326
+ super().__init__()
327
+ if checkpoint and save_strategy != "epoch":
328
+ raise ValueError("Cannot save checkpoints when save_strategy is not 'epoch'!")
329
+ if isinstance(save_strategy, str):
330
+ save_strategy = IntervalStrategy(save_strategy.lower())
331
+ self.save_strategy = save_strategy
332
+ if self.save_strategy == IntervalStrategy.STEPS and (not isinstance(save_steps, int) or save_steps <= 0):
333
+ raise ValueError("Please supply a positive integer argument for save_steps when save_strategy == 'steps'!")
334
+ self.save_steps = save_steps
335
+ output_dir = Path(output_dir)
336
+
337
+ # Create repo and retrieve repo_id
338
+ if hub_model_id is None:
339
+ hub_model_id = output_dir.absolute().name
340
+ self.hub_model_id = create_repo(repo_id=hub_model_id, exist_ok=True, token=hub_token).repo_id
341
+
342
+ self.output_dir = output_dir
343
+ self.repo = Repository(str(self.output_dir), clone_from=self.hub_model_id, token=hub_token)
344
+
345
+ self.tokenizer = tokenizer
346
+ self.last_job = None
347
+ self.checkpoint = checkpoint
348
+ self.training_history = None
349
+ self.model_card_args = model_card_args
350
+
351
+ def on_train_begin(self, logs=None):
352
+ # Although we can access model.history, we have no guarantees that the History callback will fire before this
353
+ # one, so we keep track of it here too
354
+ self.training_history = []
355
+
356
+ def on_train_batch_end(self, batch, logs=None):
357
+ if self.save_strategy == IntervalStrategy.STEPS and (batch + 1) % self.save_steps == 0:
358
+ if self.last_job is not None and not self.last_job.is_done:
359
+ return # The last upload is still running, don't start another
360
+ self.model.save_pretrained(self.output_dir)
361
+ if self.tokenizer is not None:
362
+ self.tokenizer.save_pretrained(self.output_dir)
363
+ _, self.last_job = self.repo.push_to_hub(
364
+ commit_message=f"Training in progress steps {batch}", blocking=False
365
+ )
366
+
367
+ def on_epoch_end(self, epoch, logs=None):
368
+ logs = logs.copy() # Don't accidentally write things that Keras will read later
369
+ if "epoch" not in logs:
370
+ logs["epoch"] = epoch
371
+ self.training_history.append(logs)
372
+ if self.save_strategy == IntervalStrategy.EPOCH:
373
+ if self.last_job is not None and not self.last_job.is_done:
374
+ return # The last upload is still running, don't start another
375
+ self.model.save_pretrained(self.output_dir)
376
+ if self.tokenizer is not None:
377
+ self.tokenizer.save_pretrained(self.output_dir)
378
+ if self.checkpoint:
379
+ checkpoint_dir = os.path.join(self.output_dir, "checkpoint")
380
+ self.model._save_checkpoint(checkpoint_dir, epoch)
381
+ train_summary = TrainingSummary.from_keras(
382
+ model=self.model,
383
+ model_name=self.hub_model_id,
384
+ keras_history=self.training_history,
385
+ **self.model_card_args,
386
+ )
387
+ model_card = train_summary.to_model_card()
388
+ with (self.output_dir / "README.md").open("w") as f:
389
+ f.write(model_card)
390
+ _, self.last_job = self.repo.push_to_hub(
391
+ commit_message=f"Training in progress epoch {epoch}", blocking=False
392
+ )
393
+
394
+ def on_train_end(self, logs=None):
395
+ # Makes sure the latest version of the model is uploaded
396
+ if self.last_job is not None and not self.last_job.is_done:
397
+ logging.info("Pushing the last epoch to the Hub, this may take a while...")
398
+ while not self.last_job.is_done:
399
+ sleep(1)
400
+ else:
401
+ self.model.save_pretrained(self.output_dir)
402
+ if self.tokenizer is not None:
403
+ self.tokenizer.save_pretrained(self.output_dir)
404
+ train_summary = TrainingSummary.from_keras(
405
+ model=self.model,
406
+ model_name=self.hub_model_id,
407
+ keras_history=self.training_history,
408
+ **self.model_card_args,
409
+ )
410
+ model_card = train_summary.to_model_card()
411
+ with (self.output_dir / "README.md").open("w") as f:
412
+ f.write(model_card)
413
+ self.repo.push_to_hub(commit_message="End of training", blocking=True)
.venv/lib/python3.11/site-packages/transformers/kernels/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/transformers/kernels/deformable_detr/cuda/ms_deform_im2col_cuda.cuh ADDED
@@ -0,0 +1,1327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************
7
+ * Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
8
+ * Copyright (c) 2018 Microsoft
9
+ **************************************************************************
10
+ */
11
+
12
+ #include <cstdio>
13
+ #include <algorithm>
14
+ #include <cstring>
15
+
16
+ #include <ATen/ATen.h>
17
+ #include <ATen/cuda/CUDAContext.h>
18
+
19
+ #include <THC/THCAtomics.cuh>
20
+
21
+ #define CUDA_KERNEL_LOOP(i, n) \
22
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
23
+ i < (n); \
24
+ i += blockDim.x * gridDim.x)
25
+
26
+ const int CUDA_NUM_THREADS = 1024;
27
+ inline int GET_BLOCKS(const int N, const int num_threads)
28
+ {
29
+ return (N + num_threads - 1) / num_threads;
30
+ }
31
+
32
+
33
+ template <typename scalar_t>
34
+ __device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
35
+ const int &height, const int &width, const int &nheads, const int &channels,
36
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c)
37
+ {
38
+ const int h_low = floor(h);
39
+ const int w_low = floor(w);
40
+ const int h_high = h_low + 1;
41
+ const int w_high = w_low + 1;
42
+
43
+ const scalar_t lh = h - h_low;
44
+ const scalar_t lw = w - w_low;
45
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
46
+
47
+ const int w_stride = nheads * channels;
48
+ const int h_stride = width * w_stride;
49
+ const int h_low_ptr_offset = h_low * h_stride;
50
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
51
+ const int w_low_ptr_offset = w_low * w_stride;
52
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
53
+ const int base_ptr = m * channels + c;
54
+
55
+ scalar_t v1 = 0;
56
+ if (h_low >= 0 && w_low >= 0)
57
+ {
58
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
59
+ v1 = bottom_data[ptr1];
60
+ }
61
+ scalar_t v2 = 0;
62
+ if (h_low >= 0 && w_high <= width - 1)
63
+ {
64
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
65
+ v2 = bottom_data[ptr2];
66
+ }
67
+ scalar_t v3 = 0;
68
+ if (h_high <= height - 1 && w_low >= 0)
69
+ {
70
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
71
+ v3 = bottom_data[ptr3];
72
+ }
73
+ scalar_t v4 = 0;
74
+ if (h_high <= height - 1 && w_high <= width - 1)
75
+ {
76
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
77
+ v4 = bottom_data[ptr4];
78
+ }
79
+
80
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
81
+
82
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
83
+ return val;
84
+ }
85
+
86
+
87
+ template <typename scalar_t>
88
+ __device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
89
+ const int &height, const int &width, const int &nheads, const int &channels,
90
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
91
+ const scalar_t &top_grad,
92
+ const scalar_t &attn_weight,
93
+ scalar_t* &grad_value,
94
+ scalar_t* grad_sampling_loc,
95
+ scalar_t* grad_attn_weight)
96
+ {
97
+ const int h_low = floor(h);
98
+ const int w_low = floor(w);
99
+ const int h_high = h_low + 1;
100
+ const int w_high = w_low + 1;
101
+
102
+ const scalar_t lh = h - h_low;
103
+ const scalar_t lw = w - w_low;
104
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
105
+
106
+ const int w_stride = nheads * channels;
107
+ const int h_stride = width * w_stride;
108
+ const int h_low_ptr_offset = h_low * h_stride;
109
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
110
+ const int w_low_ptr_offset = w_low * w_stride;
111
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
112
+ const int base_ptr = m * channels + c;
113
+
114
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
115
+ const scalar_t top_grad_value = top_grad * attn_weight;
116
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
117
+
118
+ scalar_t v1 = 0;
119
+ if (h_low >= 0 && w_low >= 0)
120
+ {
121
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
122
+ v1 = bottom_data[ptr1];
123
+ grad_h_weight -= hw * v1;
124
+ grad_w_weight -= hh * v1;
125
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
126
+ }
127
+ scalar_t v2 = 0;
128
+ if (h_low >= 0 && w_high <= width - 1)
129
+ {
130
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
131
+ v2 = bottom_data[ptr2];
132
+ grad_h_weight -= lw * v2;
133
+ grad_w_weight += hh * v2;
134
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
135
+ }
136
+ scalar_t v3 = 0;
137
+ if (h_high <= height - 1 && w_low >= 0)
138
+ {
139
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
140
+ v3 = bottom_data[ptr3];
141
+ grad_h_weight += hw * v3;
142
+ grad_w_weight -= lh * v3;
143
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
144
+ }
145
+ scalar_t v4 = 0;
146
+ if (h_high <= height - 1 && w_high <= width - 1)
147
+ {
148
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
149
+ v4 = bottom_data[ptr4];
150
+ grad_h_weight += lw * v4;
151
+ grad_w_weight += lh * v4;
152
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
153
+ }
154
+
155
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
156
+ *grad_attn_weight = top_grad * val;
157
+ *grad_sampling_loc = width * grad_w_weight * top_grad_value;
158
+ *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
159
+ }
160
+
161
+
162
+ template <typename scalar_t>
163
+ __device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
164
+ const int &height, const int &width, const int &nheads, const int &channels,
165
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
166
+ const scalar_t &top_grad,
167
+ const scalar_t &attn_weight,
168
+ scalar_t* &grad_value,
169
+ scalar_t* grad_sampling_loc,
170
+ scalar_t* grad_attn_weight)
171
+ {
172
+ const int h_low = floor(h);
173
+ const int w_low = floor(w);
174
+ const int h_high = h_low + 1;
175
+ const int w_high = w_low + 1;
176
+
177
+ const scalar_t lh = h - h_low;
178
+ const scalar_t lw = w - w_low;
179
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
180
+
181
+ const int w_stride = nheads * channels;
182
+ const int h_stride = width * w_stride;
183
+ const int h_low_ptr_offset = h_low * h_stride;
184
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
185
+ const int w_low_ptr_offset = w_low * w_stride;
186
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
187
+ const int base_ptr = m * channels + c;
188
+
189
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
190
+ const scalar_t top_grad_value = top_grad * attn_weight;
191
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
192
+
193
+ scalar_t v1 = 0;
194
+ if (h_low >= 0 && w_low >= 0)
195
+ {
196
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
197
+ v1 = bottom_data[ptr1];
198
+ grad_h_weight -= hw * v1;
199
+ grad_w_weight -= hh * v1;
200
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
201
+ }
202
+ scalar_t v2 = 0;
203
+ if (h_low >= 0 && w_high <= width - 1)
204
+ {
205
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
206
+ v2 = bottom_data[ptr2];
207
+ grad_h_weight -= lw * v2;
208
+ grad_w_weight += hh * v2;
209
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
210
+ }
211
+ scalar_t v3 = 0;
212
+ if (h_high <= height - 1 && w_low >= 0)
213
+ {
214
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
215
+ v3 = bottom_data[ptr3];
216
+ grad_h_weight += hw * v3;
217
+ grad_w_weight -= lh * v3;
218
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
219
+ }
220
+ scalar_t v4 = 0;
221
+ if (h_high <= height - 1 && w_high <= width - 1)
222
+ {
223
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
224
+ v4 = bottom_data[ptr4];
225
+ grad_h_weight += lw * v4;
226
+ grad_w_weight += lh * v4;
227
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
228
+ }
229
+
230
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
231
+ atomicAdd(grad_attn_weight, top_grad * val);
232
+ atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
233
+ atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
234
+ }
235
+
236
+
237
+ template <typename scalar_t>
238
+ __global__ void ms_deformable_im2col_gpu_kernel(const int n,
239
+ const scalar_t *data_value,
240
+ const int64_t *data_spatial_shapes,
241
+ const int64_t *data_level_start_index,
242
+ const scalar_t *data_sampling_loc,
243
+ const scalar_t *data_attn_weight,
244
+ const int batch_size,
245
+ const int spatial_size,
246
+ const int num_heads,
247
+ const int channels,
248
+ const int num_levels,
249
+ const int num_query,
250
+ const int num_point,
251
+ scalar_t *data_col)
252
+ {
253
+ CUDA_KERNEL_LOOP(index, n)
254
+ {
255
+ int _temp = index;
256
+ const int c_col = _temp % channels;
257
+ _temp /= channels;
258
+ const int sampling_index = _temp;
259
+ const int m_col = _temp % num_heads;
260
+ _temp /= num_heads;
261
+ const int q_col = _temp % num_query;
262
+ _temp /= num_query;
263
+ const int b_col = _temp;
264
+
265
+ scalar_t *data_col_ptr = data_col + index;
266
+ int data_weight_ptr = sampling_index * num_levels * num_point;
267
+ int data_loc_w_ptr = data_weight_ptr << 1;
268
+ const int qid_stride = num_heads * channels;
269
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
270
+ scalar_t col = 0;
271
+
272
+ for (int l_col=0; l_col < num_levels; ++l_col)
273
+ {
274
+ const int level_start_id = data_level_start_index[l_col];
275
+ const int spatial_h_ptr = l_col << 1;
276
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
277
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
278
+ const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
279
+ for (int p_col=0; p_col < num_point; ++p_col)
280
+ {
281
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
282
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
283
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
284
+
285
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
286
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
287
+
288
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
289
+ {
290
+ col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
291
+ }
292
+
293
+ data_weight_ptr += 1;
294
+ data_loc_w_ptr += 2;
295
+ }
296
+ }
297
+ *data_col_ptr = col;
298
+ }
299
+ }
300
+
301
+ template <typename scalar_t, unsigned int blockSize>
302
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
303
+ const scalar_t *grad_col,
304
+ const scalar_t *data_value,
305
+ const int64_t *data_spatial_shapes,
306
+ const int64_t *data_level_start_index,
307
+ const scalar_t *data_sampling_loc,
308
+ const scalar_t *data_attn_weight,
309
+ const int batch_size,
310
+ const int spatial_size,
311
+ const int num_heads,
312
+ const int channels,
313
+ const int num_levels,
314
+ const int num_query,
315
+ const int num_point,
316
+ scalar_t *grad_value,
317
+ scalar_t *grad_sampling_loc,
318
+ scalar_t *grad_attn_weight)
319
+ {
320
+ CUDA_KERNEL_LOOP(index, n)
321
+ {
322
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
323
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
324
+ unsigned int tid = threadIdx.x;
325
+ int _temp = index;
326
+ const int c_col = _temp % channels;
327
+ _temp /= channels;
328
+ const int sampling_index = _temp;
329
+ const int m_col = _temp % num_heads;
330
+ _temp /= num_heads;
331
+ const int q_col = _temp % num_query;
332
+ _temp /= num_query;
333
+ const int b_col = _temp;
334
+
335
+ const scalar_t top_grad = grad_col[index];
336
+
337
+ int data_weight_ptr = sampling_index * num_levels * num_point;
338
+ int data_loc_w_ptr = data_weight_ptr << 1;
339
+ const int grad_sampling_ptr = data_weight_ptr;
340
+ grad_sampling_loc += grad_sampling_ptr << 1;
341
+ grad_attn_weight += grad_sampling_ptr;
342
+ const int grad_weight_stride = 1;
343
+ const int grad_loc_stride = 2;
344
+ const int qid_stride = num_heads * channels;
345
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
346
+
347
+ for (int l_col=0; l_col < num_levels; ++l_col)
348
+ {
349
+ const int level_start_id = data_level_start_index[l_col];
350
+ const int spatial_h_ptr = l_col << 1;
351
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
352
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
353
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
354
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
355
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
356
+
357
+ for (int p_col=0; p_col < num_point; ++p_col)
358
+ {
359
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
360
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
361
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
362
+
363
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
364
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
365
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
366
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
367
+ *(cache_grad_attn_weight+threadIdx.x)=0;
368
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
369
+ {
370
+ ms_deform_attn_col2im_bilinear(
371
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
372
+ top_grad, weight, grad_value_ptr,
373
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
374
+ }
375
+
376
+ __syncthreads();
377
+ if (tid == 0)
378
+ {
379
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
380
+ int sid=2;
381
+ for (unsigned int tid = 1; tid < blockSize; ++tid)
382
+ {
383
+ _grad_w += cache_grad_sampling_loc[sid];
384
+ _grad_h += cache_grad_sampling_loc[sid + 1];
385
+ _grad_a += cache_grad_attn_weight[tid];
386
+ sid += 2;
387
+ }
388
+
389
+
390
+ *grad_sampling_loc = _grad_w;
391
+ *(grad_sampling_loc + 1) = _grad_h;
392
+ *grad_attn_weight = _grad_a;
393
+ }
394
+ __syncthreads();
395
+
396
+ data_weight_ptr += 1;
397
+ data_loc_w_ptr += 2;
398
+ grad_attn_weight += grad_weight_stride;
399
+ grad_sampling_loc += grad_loc_stride;
400
+ }
401
+ }
402
+ }
403
+ }
404
+
405
+
406
// col2im (backward) kernel for multi-scale deformable attention, specialised at
// compile time for power-of-two block sizes (the launcher instantiates it with
// blockSize == channels == blockDim.x for channels in {64..1024}).
// One thread handles one (batch, query, head, channel) output element.
// Per-thread partial gradients for the sampling location and attention weight
// are staged in statically sized shared-memory caches and combined with a
// binary tree reduction; thread 0 then writes the block totals to global memory.
template <typename scalar_t, unsigned int blockSize>
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
    const scalar_t *grad_col,               // incoming gradient w.r.t. the output columns
    const scalar_t *data_value,             // flattened multi-level value tensor
    const int64_t *data_spatial_shapes,     // per-level (h, w) pairs
    const int64_t *data_level_start_index,  // per-level start offset into data_value
    const scalar_t *data_sampling_loc,      // normalised sampling locations, stored as (w, h) pairs
    const scalar_t *data_attn_weight,       // attention weight per sampling point
    const int batch_size,
    const int spatial_size,
    const int num_heads,
    const int channels,
    const int num_levels,
    const int num_query,
    const int num_point,
    scalar_t *grad_value,                   // out: gradient w.r.t. data_value
    scalar_t *grad_sampling_loc,            // out: gradient w.r.t. sampling locations
    scalar_t *grad_attn_weight)             // out: gradient w.r.t. attention weights
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Two slots per thread for the (w, h) location gradient, one for the weight.
    __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
    __shared__ scalar_t cache_grad_attn_weight[blockSize];
    unsigned int tid = threadIdx.x;
    // Unflatten index -> (b_col, q_col, m_col, c_col); channels vary fastest.
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;  // flat (batch, query, head) id, addresses loc/weight arrays
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;  // NOTE(review): computed but never used below
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;  // x2: locations are (w, h) pairs
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Map normalised [0, 1] locations to pixel coordinates.
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Zero this thread's cache slots; they stay zero if the sample is out of bounds.
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          // Accumulates grad_value and writes this thread's loc/weight partials
          // into its shared-memory slots (helper defined earlier in this file).
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();

        // Binary tree reduction over the block; valid because blockSize is a
        // power of two for every instantiation the launcher emits.
        for (unsigned int s=blockSize/2; s>0; s>>=1)
        {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
          }
          __syncthreads();
        }

        if (tid == 0)
        {
          // Thread 0 writes the reduced block totals (plain stores, no atomics).
          *grad_sampling_loc = cache_grad_sampling_loc[0];
          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
          *grad_attn_weight = cache_grad_attn_weight[0];
        }
        __syncthreads();  // caches are reused by the next sampling point

        // Advance to the next sampling point.
        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
511
+
512
+
513
// col2im (backward) kernel variant for small runtime block sizes (< 64
// channels, not a power of two). Partial gradients are staged in dynamically
// allocated shared memory and summed *serially* by thread 0, which keeps the
// reduction correct for arbitrary blockDim.x at the cost of parallelism.
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
    const scalar_t *grad_col,               // incoming gradient w.r.t. the output columns
    const scalar_t *data_value,             // flattened multi-level value tensor
    const int64_t *data_spatial_shapes,     // per-level (h, w) pairs
    const int64_t *data_level_start_index,  // per-level start offset into data_value
    const scalar_t *data_sampling_loc,      // normalised (w, h) sampling locations
    const scalar_t *data_attn_weight,       // attention weight per sampling point
    const int batch_size,
    const int spatial_size,
    const int num_heads,
    const int channels,
    const int num_levels,
    const int num_query,
    const int num_point,
    scalar_t *grad_value,                   // out: gradient w.r.t. data_value
    scalar_t *grad_sampling_loc,            // out: gradient w.r.t. sampling locations
    scalar_t *grad_attn_weight)             // out: gradient w.r.t. attention weights
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Dynamic shared memory: launcher passes blockDim.x * 3 * sizeof(scalar_t);
    // first 2*blockDim.x entries cache (w, h) gradients, the rest the weights.
    extern __shared__ int _s[];
    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;
    // Unflatten index -> (b_col, q_col, m_col, c_col); channels vary fastest.
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;  // flat (batch, query, head) id
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;  // NOTE(review): computed but never used below
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;  // x2: locations are (w, h) pairs
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Map normalised [0, 1] locations to pixel coordinates.
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Zero this thread's cache slots; they stay zero if the sample is out of bounds.
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          // Accumulates grad_value and this thread's loc/weight partials into
          // its shared-memory slots (helper defined earlier in this file).
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();
        if (tid == 0)
        {
          // Serial reduction by thread 0 — works for any blockDim.x.
          scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
          int sid=2;
          // NOTE(review): the loop counter intentionally shadows the outer `tid`,
          // which is only ever read as 0 inside this branch.
          for (unsigned int tid = 1; tid < blockDim.x; ++tid)
          {
            _grad_w += cache_grad_sampling_loc[sid];
            _grad_h += cache_grad_sampling_loc[sid + 1];
            _grad_a += cache_grad_attn_weight[tid];
            sid += 2;
          }


          *grad_sampling_loc = _grad_w;
          *(grad_sampling_loc + 1) = _grad_h;
          *grad_attn_weight = _grad_a;
        }
        __syncthreads();  // caches are reused by the next sampling point

        // Advance to the next sampling point.
        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
617
+
618
// col2im (backward) kernel variant for runtime block sizes >= 64 that are not
// among the power-of-two specialisations. Uses dynamic shared memory and a
// tree reduction that tolerates non-power-of-two blockDim.x: at each step the
// `spre` bookkeeping folds in the one straggler element that a plain halving
// reduction would drop.
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
    const scalar_t *grad_col,               // incoming gradient w.r.t. the output columns
    const scalar_t *data_value,             // flattened multi-level value tensor
    const int64_t *data_spatial_shapes,     // per-level (h, w) pairs
    const int64_t *data_level_start_index,  // per-level start offset into data_value
    const scalar_t *data_sampling_loc,      // normalised (w, h) sampling locations
    const scalar_t *data_attn_weight,       // attention weight per sampling point
    const int batch_size,
    const int spatial_size,
    const int num_heads,
    const int channels,
    const int num_levels,
    const int num_query,
    const int num_point,
    scalar_t *grad_value,                   // out: gradient w.r.t. data_value
    scalar_t *grad_sampling_loc,            // out: gradient w.r.t. sampling locations
    scalar_t *grad_attn_weight)             // out: gradient w.r.t. attention weights
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Dynamic shared memory: launcher passes blockDim.x * 3 * sizeof(scalar_t).
    extern __shared__ int _s[];
    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;
    // Unflatten index -> (b_col, q_col, m_col, c_col); channels vary fastest.
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;  // flat (batch, query, head) id
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;  // NOTE(review): computed but never used below
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;  // x2: locations are (w, h) pairs
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Map normalised [0, 1] locations to pixel coordinates.
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Zero this thread's cache slots; they stay zero if the sample is out of bounds.
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          // Accumulates grad_value and this thread's loc/weight partials into
          // its shared-memory slots (helper defined earlier in this file).
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();

        // Tree reduction; `spre` is the live element count of the previous step
        // so an odd straggler (tid + 2s < spre) is folded in and not lost.
        for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
        {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
            if (tid + (s << 1) < spre)
            {
              cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
              cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
              cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
            }
          }
          __syncthreads();
        }

        if (tid == 0)
        {
          // Thread 0 writes the reduced block totals (plain stores, no atomics).
          *grad_sampling_loc = cache_grad_sampling_loc[0];
          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
          *grad_attn_weight = cache_grad_attn_weight[0];
        }
        __syncthreads();  // caches are reused by the next sampling point

        // Advance to the next sampling point.
        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
730
+
731
// col2im (backward) kernel for very wide channels (channels > 1024 and a
// multiple of 1024), where several thread blocks contribute to the same
// (query, head, level, point) gradient slot. Identical tree reduction to the
// single-block shm_reduce_v2 variant, but the final per-block totals are
// combined across blocks with atomicAdd instead of plain stores.
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
    const scalar_t *grad_col,               // incoming gradient w.r.t. the output columns
    const scalar_t *data_value,             // flattened multi-level value tensor
    const int64_t *data_spatial_shapes,     // per-level (h, w) pairs
    const int64_t *data_level_start_index,  // per-level start offset into data_value
    const scalar_t *data_sampling_loc,      // normalised (w, h) sampling locations
    const scalar_t *data_attn_weight,       // attention weight per sampling point
    const int batch_size,
    const int spatial_size,
    const int num_heads,
    const int channels,
    const int num_levels,
    const int num_query,
    const int num_point,
    scalar_t *grad_value,                   // out: gradient w.r.t. data_value
    scalar_t *grad_sampling_loc,            // out: gradient w.r.t. sampling locations (accumulated atomically)
    scalar_t *grad_attn_weight)             // out: gradient w.r.t. attention weights (accumulated atomically)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Dynamic shared memory: launcher passes blockDim.x * 3 * sizeof(scalar_t).
    extern __shared__ int _s[];
    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;
    // Unflatten index -> (b_col, q_col, m_col, c_col); channels vary fastest.
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;  // flat (batch, query, head) id
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;  // NOTE(review): computed but never used below
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;  // x2: locations are (w, h) pairs
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Map normalised [0, 1] locations to pixel coordinates.
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Zero this thread's cache slots; they stay zero if the sample is out of bounds.
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          // Accumulates grad_value and this thread's loc/weight partials into
          // its shared-memory slots (helper defined earlier in this file).
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();

        // Tree reduction tolerant of non-power-of-two blockDim.x; `spre` lets
        // the odd straggler of each halving step be folded in.
        for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
        {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
            if (tid + (s << 1) < spre)
            {
              cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
              cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
              cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
            }
          }
          __syncthreads();
        }

        if (tid == 0)
        {
          // Multiple blocks share this output slot, hence atomicAdd (not stores).
          atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
          atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
          atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
        }
        __syncthreads();  // caches are reused by the next sampling point

        // Advance to the next sampling point.
        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
843
+
844
+
845
// col2im (backward) fallback kernel that accumulates all three gradients
// directly in global memory through the `_gm` bilinear helper — no shared
// memory and no block synchronisation. Used by the launcher when channels is
// large and irregular (> 1024 and not a multiple of 1024).
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
    const scalar_t *grad_col,               // incoming gradient w.r.t. the output columns
    const scalar_t *data_value,             // flattened multi-level value tensor
    const int64_t *data_spatial_shapes,     // per-level (h, w) pairs
    const int64_t *data_level_start_index,  // per-level start offset into data_value
    const scalar_t *data_sampling_loc,      // normalised (w, h) sampling locations
    const scalar_t *data_attn_weight,       // attention weight per sampling point
    const int batch_size,
    const int spatial_size,
    const int num_heads,
    const int channels,
    const int num_levels,
    const int num_query,
    const int num_point,
    scalar_t *grad_value,                   // out: gradient w.r.t. data_value
    scalar_t *grad_sampling_loc,            // out: gradient w.r.t. sampling locations
    scalar_t *grad_attn_weight)             // out: gradient w.r.t. attention weights
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Unflatten index -> (b_col, q_col, m_col, c_col); channels vary fastest.
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;  // flat (batch, query, head) id
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;  // NOTE(review): computed but never used below
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;  // x2: locations are (w, h) pairs
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Map normalised [0, 1] locations to pixel coordinates.
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          // `_gm` variant accumulates all gradients straight into global memory
          // (helper defined earlier in this file); out-of-bounds samples
          // contribute nothing.
          ms_deform_attn_col2im_bilinear_gm(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            grad_sampling_loc, grad_attn_weight);
        }
        // Advance to the next sampling point.
        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
921
+
922
+
923
// Host-side launcher for the forward ("im2col") pass of multi-scale deformable
// attention. Launches one thread per (batch, query, head, channel) output
// element of data_col on the given stream and reports (but does not propagate)
// any launch error via printf.
template <typename scalar_t>
void ms_deformable_im2col_cuda(cudaStream_t stream,
                              const scalar_t* data_value,             // flattened multi-level value tensor
                              const int64_t* data_spatial_shapes,     // per-level (h, w) pairs
                              const int64_t* data_level_start_index,  // per-level start offset into data_value
                              const scalar_t* data_sampling_loc,      // normalised (w, h) sampling locations
                              const scalar_t* data_attn_weight,       // attention weight per sampling point
                              const int batch_size,
                              const int spatial_size,
                              const int num_heads,
                              const int channels,
                              const int num_levels,
                              const int num_query,
                              const int num_point,
                              scalar_t* data_col)                     // out: attention output columns
{
  // NOTE(review): num_kernels and num_actual_kernels are always identical here;
  // the distinction matters only in the col2im launcher below.
  const int num_kernels = batch_size * num_query * num_heads * channels;
  const int num_actual_kernels = batch_size * num_query * num_heads * channels;
  const int num_threads = CUDA_NUM_THREADS;
  ms_deformable_im2col_gpu_kernel<scalar_t>
      <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
          0, stream>>>(
      num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
      batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);

  // Asynchronous launch: this only catches launch-configuration errors, not
  // failures that occur while the kernel runs.
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
  }

}
955
+
956
// Host-side launcher for the backward ("col2im") pass of multi-scale
// deformable attention. Picks a reduction strategy based on `channels`
// (== threads per block, capped at CUDA_NUM_THREADS):
//   * channels in {1,2,4,8,16,32}: static-shared-memory kernel, serial
//     thread-0 reduction (blocksize-aware v1);
//   * channels in {64,...,1024}: static-shared-memory kernel, tree
//     reduction (blocksize-aware v2);
//   * other channels < 64: dynamic-shared-memory serial reduction (v1);
//   * other channels in (64, 1024): dynamic-shared-memory tree reduction (v2);
//   * channels > 1024, multiple of 1024: multi-block variant with atomics;
//   * channels > 1024 otherwise: pure global-memory-atomics fallback.
// All launches share one argument list and launch geometry; the macros below
// factor that out so each dispatch arm is a single readable line. Launch
// errors are reported (not propagated) via printf.
template <typename scalar_t>
void ms_deformable_col2im_cuda(cudaStream_t stream,
                              const scalar_t* grad_col,               // incoming gradient w.r.t. the output columns
                              const scalar_t* data_value,             // flattened multi-level value tensor
                              const int64_t * data_spatial_shapes,    // per-level (h, w) pairs
                              const int64_t * data_level_start_index, // per-level start offset into data_value
                              const scalar_t * data_sampling_loc,     // normalised (w, h) sampling locations
                              const scalar_t * data_attn_weight,      // attention weight per sampling point
                              const int batch_size,
                              const int spatial_size,
                              const int num_heads,
                              const int channels,
                              const int num_levels,
                              const int num_query,
                              const int num_point,
                              scalar_t* grad_value,                   // out: gradient w.r.t. data_value
                              scalar_t* grad_sampling_loc,            // out: gradient w.r.t. sampling locations
                              scalar_t* grad_attn_weight)             // out: gradient w.r.t. attention weights
{
  // One block spans the channel dimension so a block-level reduction covers
  // all channels of a (batch, query, head) element.
  const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
  const int num_kernels = batch_size * num_query * num_heads * channels;
  const int num_actual_kernels = batch_size * num_query * num_heads * channels;

// Shared argument list of every col2im kernel launch below.
#define MS_COL2IM_ARGS                                                        \
  num_kernels, grad_col, data_value, data_spatial_shapes,                     \
      data_level_start_index, data_sampling_loc, data_attn_weight,            \
      batch_size, spatial_size, num_heads, channels, num_levels, num_query,   \
      num_point, grad_value, grad_sampling_loc, grad_attn_weight

// One switch arm: launch KERNEL specialised for compile-time block size BS.
// The blocksize-aware kernels use static __shared__ arrays, so dynamic shared
// memory is 0 for every case.
#define MS_COL2IM_CASE(KERNEL, BS)                                            \
  case BS:                                                                    \
    KERNEL<scalar_t, BS>                                                      \
        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,          \
           0, stream>>>(MS_COL2IM_ARGS);                                      \
    break;

  if (channels > 1024)
  {
    if ((channels & 1023) == 0)
    {
      // channels is a multiple of 1024: several blocks per element, each
      // reduces its slice in dynamic shared memory and combines via atomics.
      ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
             num_threads*3*sizeof(scalar_t), stream>>>(MS_COL2IM_ARGS);
    }
    else
    {
      // Irregular wide channel count: pure global-memory-atomics fallback.
      ms_deformable_col2im_gpu_kernel_gm<scalar_t>
          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
             0, stream>>>(MS_COL2IM_ARGS);
    }
  }
  else{
    switch(channels)
    {
      // Small power-of-two blocks: serial thread-0 reduction (v1).
      MS_COL2IM_CASE(ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1, 1)
      MS_COL2IM_CASE(ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1, 2)
      MS_COL2IM_CASE(ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1, 4)
      MS_COL2IM_CASE(ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1, 8)
      MS_COL2IM_CASE(ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1, 16)
      MS_COL2IM_CASE(ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1, 32)
      // Larger power-of-two blocks: shared-memory tree reduction (v2).
      MS_COL2IM_CASE(ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2, 64)
      MS_COL2IM_CASE(ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2, 128)
      MS_COL2IM_CASE(ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2, 256)
      MS_COL2IM_CASE(ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2, 512)
      MS_COL2IM_CASE(ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2, 1024)
      default:
        // Non-power-of-two channel counts use the runtime-block-size kernels
        // with dynamically sized shared memory (2 floats for the location
        // gradient + 1 for the weight gradient, per thread).
        if (channels < 64)
        {
          ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
              <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
                 num_threads*3*sizeof(scalar_t), stream>>>(MS_COL2IM_ARGS);
        }
        else
        {
          ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
              <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
                 num_threads*3*sizeof(scalar_t), stream>>>(MS_COL2IM_ARGS);
        }
    }
  }
#undef MS_COL2IM_CASE
#undef MS_COL2IM_ARGS

  // Asynchronous launch: this only catches launch-configuration errors, not
  // failures that occur while the kernel runs.
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
  }

}
.venv/lib/python3.11/site-packages/transformers/kernels/deformable_detr/ms_deform_attn.h ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #pragma once
12
+
13
+ #include "cpu/ms_deform_attn_cpu.h"
14
+
15
+ #ifdef WITH_CUDA
16
+ #include "cuda/ms_deform_attn_cuda.h"
17
+ #endif
18
+
19
+
20
+ at::Tensor
21
+ ms_deform_attn_forward(
22
+ const at::Tensor &value,
23
+ const at::Tensor &spatial_shapes,
24
+ const at::Tensor &level_start_index,
25
+ const at::Tensor &sampling_loc,
26
+ const at::Tensor &attn_weight,
27
+ const int im2col_step)
28
+ {
29
+ if (value.type().is_cuda())
30
+ {
31
+ #ifdef WITH_CUDA
32
+ return ms_deform_attn_cuda_forward(
33
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
34
+ #else
35
+ AT_ERROR("Not compiled with GPU support");
36
+ #endif
37
+ }
38
+ AT_ERROR("Not implemented on the CPU");
39
+ }
40
+
41
+ std::vector<at::Tensor>
42
+ ms_deform_attn_backward(
43
+ const at::Tensor &value,
44
+ const at::Tensor &spatial_shapes,
45
+ const at::Tensor &level_start_index,
46
+ const at::Tensor &sampling_loc,
47
+ const at::Tensor &attn_weight,
48
+ const at::Tensor &grad_output,
49
+ const int im2col_step)
50
+ {
51
+ if (value.type().is_cuda())
52
+ {
53
+ #ifdef WITH_CUDA
54
+ return ms_deform_attn_cuda_backward(
55
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
56
+ #else
57
+ AT_ERROR("Not compiled with GPU support");
58
+ #endif
59
+ }
60
+ AT_ERROR("Not implemented on the CPU");
61
+ }
.venv/lib/python3.11/site-packages/transformers/kernels/deta/cpu/ms_deform_attn_cpu.cpp ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #include <vector>
12
+
13
+ #include <ATen/ATen.h>
14
+ #include <ATen/cuda/CUDAContext.h>
15
+
16
+
17
+ at::Tensor
18
+ ms_deform_attn_cpu_forward(
19
+ const at::Tensor &value,
20
+ const at::Tensor &spatial_shapes,
21
+ const at::Tensor &level_start_index,
22
+ const at::Tensor &sampling_loc,
23
+ const at::Tensor &attn_weight,
24
+ const int im2col_step)
25
+ {
26
+ AT_ERROR("Not implement on cpu");
27
+ }
28
+
29
+ std::vector<at::Tensor>
30
+ ms_deform_attn_cpu_backward(
31
+ const at::Tensor &value,
32
+ const at::Tensor &spatial_shapes,
33
+ const at::Tensor &level_start_index,
34
+ const at::Tensor &sampling_loc,
35
+ const at::Tensor &attn_weight,
36
+ const at::Tensor &grad_output,
37
+ const int im2col_step)
38
+ {
39
+ AT_ERROR("Not implement on cpu");
40
+ }
.venv/lib/python3.11/site-packages/transformers/kernels/deta/cpu/ms_deform_attn_cpu.h ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #pragma once
12
+ #include <torch/extension.h>
13
+
14
+ at::Tensor
15
+ ms_deform_attn_cpu_forward(
16
+ const at::Tensor &value,
17
+ const at::Tensor &spatial_shapes,
18
+ const at::Tensor &level_start_index,
19
+ const at::Tensor &sampling_loc,
20
+ const at::Tensor &attn_weight,
21
+ const int im2col_step);
22
+
23
+ std::vector<at::Tensor>
24
+ ms_deform_attn_cpu_backward(
25
+ const at::Tensor &value,
26
+ const at::Tensor &spatial_shapes,
27
+ const at::Tensor &level_start_index,
28
+ const at::Tensor &sampling_loc,
29
+ const at::Tensor &attn_weight,
30
+ const at::Tensor &grad_output,
31
+ const int im2col_step);
32
+
.venv/lib/python3.11/site-packages/transformers/kernels/deta/cuda/ms_deform_attn_cuda.cu ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #include <vector>
12
+ #include "cuda/ms_deform_im2col_cuda.cuh"
13
+
14
+ #include <ATen/ATen.h>
15
+ #include <ATen/cuda/CUDAContext.h>
16
+ #include <cuda.h>
17
+ #include <cuda_runtime.h>
18
+
19
+ #pragma once
20
+ #include <torch/extension.h>
21
+
22
+
23
+ at::Tensor ms_deform_attn_cuda_forward(
24
+ const at::Tensor &value,
25
+ const at::Tensor &spatial_shapes,
26
+ const at::Tensor &level_start_index,
27
+ const at::Tensor &sampling_loc,
28
+ const at::Tensor &attn_weight,
29
+ const int im2col_step)
30
+ {
31
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
32
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
33
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
34
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
35
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
36
+
37
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
38
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
39
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
40
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
41
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
42
+
43
+ const int batch = value.size(0);
44
+ const int spatial_size = value.size(1);
45
+ const int num_heads = value.size(2);
46
+ const int channels = value.size(3);
47
+
48
+ const int num_levels = spatial_shapes.size(0);
49
+
50
+ const int num_query = sampling_loc.size(1);
51
+ const int num_point = sampling_loc.size(4);
52
+
53
+ const int im2col_step_ = std::min(batch, im2col_step);
54
+
55
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
56
+
57
+ auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
58
+
59
+ const int batch_n = im2col_step_;
60
+ auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
61
+ auto per_value_size = spatial_size * num_heads * channels;
62
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
63
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
64
+ for (int n = 0; n < batch/im2col_step_; ++n)
65
+ {
66
+ auto columns = output_n.select(0, n);
67
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
68
+ ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
69
+ value.data<scalar_t>() + n * im2col_step_ * per_value_size,
70
+ spatial_shapes.data<int64_t>(),
71
+ level_start_index.data<int64_t>(),
72
+ sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
73
+ attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
74
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
75
+ columns.data<scalar_t>());
76
+
77
+ }));
78
+ }
79
+
80
+ output = output.view({batch, num_query, num_heads*channels});
81
+
82
+ return output;
83
+ }
84
+
85
+
86
+ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
87
+ const at::Tensor &value,
88
+ const at::Tensor &spatial_shapes,
89
+ const at::Tensor &level_start_index,
90
+ const at::Tensor &sampling_loc,
91
+ const at::Tensor &attn_weight,
92
+ const at::Tensor &grad_output,
93
+ const int im2col_step)
94
+ {
95
+
96
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
97
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
98
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
99
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
100
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
101
+ AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
102
+
103
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
104
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
105
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
106
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
107
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
108
+ AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
109
+
110
+ const int batch = value.size(0);
111
+ const int spatial_size = value.size(1);
112
+ const int num_heads = value.size(2);
113
+ const int channels = value.size(3);
114
+
115
+ const int num_levels = spatial_shapes.size(0);
116
+
117
+ const int num_query = sampling_loc.size(1);
118
+ const int num_point = sampling_loc.size(4);
119
+
120
+ const int im2col_step_ = std::min(batch, im2col_step);
121
+
122
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
123
+
124
+ auto grad_value = at::zeros_like(value);
125
+ auto grad_sampling_loc = at::zeros_like(sampling_loc);
126
+ auto grad_attn_weight = at::zeros_like(attn_weight);
127
+
128
+ const int batch_n = im2col_step_;
129
+ auto per_value_size = spatial_size * num_heads * channels;
130
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
131
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
132
+ auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
133
+
134
+ for (int n = 0; n < batch/im2col_step_; ++n)
135
+ {
136
+ auto grad_output_g = grad_output_n.select(0, n);
137
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
138
+ ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
139
+ grad_output_g.data<scalar_t>(),
140
+ value.data<scalar_t>() + n * im2col_step_ * per_value_size,
141
+ spatial_shapes.data<int64_t>(),
142
+ level_start_index.data<int64_t>(),
143
+ sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
144
+ attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
145
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
146
+ grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
147
+ grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
148
+ grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
149
+
150
+ }));
151
+ }
152
+
153
+ return {
154
+ grad_value, grad_sampling_loc, grad_attn_weight
155
+ };
156
+ }
.venv/lib/python3.11/site-packages/transformers/kernels/deta/cuda/ms_deform_attn_cuda.cuh ADDED
@@ -0,0 +1,1467 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #include <vector>
12
+
13
+ #include <cuda.h>
14
+ #include <cuda_runtime.h>
15
+
16
+ #include <cstdio>
17
+ #include <algorithm>
18
+ #include <cstring>
19
+
20
+ #include <ATen/ATen.h>
21
+ #include <ATen/cuda/CUDAContext.h>
22
+
23
+ #include <THC/THCAtomics.cuh>
24
+
25
+ #define CUDA_KERNEL_LOOP(i, n) \
26
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
27
+ i < (n); \
28
+ i += blockDim.x * gridDim.x)
29
+
30
+
31
+ at::Tensor ms_deform_attn_cuda_forward(
32
+ const at::Tensor &value,
33
+ const at::Tensor &spatial_shapes,
34
+ const at::Tensor &level_start_index,
35
+ const at::Tensor &sampling_loc,
36
+ const at::Tensor &attn_weight,
37
+ const int im2col_step)
38
+ {
39
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
40
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
41
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
42
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
43
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
44
+
45
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
46
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
47
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
48
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
49
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
50
+
51
+ const int batch = value.size(0);
52
+ const int spatial_size = value.size(1);
53
+ const int num_heads = value.size(2);
54
+ const int channels = value.size(3);
55
+
56
+ const int num_levels = spatial_shapes.size(0);
57
+
58
+ const int num_query = sampling_loc.size(1);
59
+ const int num_point = sampling_loc.size(4);
60
+
61
+ const int im2col_step_ = std::min(batch, im2col_step);
62
+
63
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
64
+
65
+ auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
66
+
67
+ const int batch_n = im2col_step_;
68
+ auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
69
+ auto per_value_size = spatial_size * num_heads * channels;
70
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
71
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
72
+ for (int n = 0; n < batch/im2col_step_; ++n)
73
+ {
74
+ auto columns = output_n.select(0, n);
75
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
76
+ ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
77
+ value.data<scalar_t>() + n * im2col_step_ * per_value_size,
78
+ spatial_shapes.data<int64_t>(),
79
+ level_start_index.data<int64_t>(),
80
+ sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
81
+ attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
82
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
83
+ columns.data<scalar_t>());
84
+
85
+ }));
86
+ }
87
+
88
+ output = output.view({batch, num_query, num_heads*channels});
89
+
90
+ return output;
91
+ }
92
+
93
+
94
+ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
95
+ const at::Tensor &value,
96
+ const at::Tensor &spatial_shapes,
97
+ const at::Tensor &level_start_index,
98
+ const at::Tensor &sampling_loc,
99
+ const at::Tensor &attn_weight,
100
+ const at::Tensor &grad_output,
101
+ const int im2col_step)
102
+ {
103
+
104
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
105
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
106
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
107
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
108
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
109
+ AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
110
+
111
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
112
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
113
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
114
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
115
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
116
+ AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
117
+
118
+ const int batch = value.size(0);
119
+ const int spatial_size = value.size(1);
120
+ const int num_heads = value.size(2);
121
+ const int channels = value.size(3);
122
+
123
+ const int num_levels = spatial_shapes.size(0);
124
+
125
+ const int num_query = sampling_loc.size(1);
126
+ const int num_point = sampling_loc.size(4);
127
+
128
+ const int im2col_step_ = std::min(batch, im2col_step);
129
+
130
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
131
+
132
+ auto grad_value = at::zeros_like(value);
133
+ auto grad_sampling_loc = at::zeros_like(sampling_loc);
134
+ auto grad_attn_weight = at::zeros_like(attn_weight);
135
+
136
+ const int batch_n = im2col_step_;
137
+ auto per_value_size = spatial_size * num_heads * channels;
138
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
139
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
140
+ auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
141
+
142
+ for (int n = 0; n < batch/im2col_step_; ++n)
143
+ {
144
+ auto grad_output_g = grad_output_n.select(0, n);
145
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
146
+ ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
147
+ grad_output_g.data<scalar_t>(),
148
+ value.data<scalar_t>() + n * im2col_step_ * per_value_size,
149
+ spatial_shapes.data<int64_t>(),
150
+ level_start_index.data<int64_t>(),
151
+ sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
152
+ attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
153
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
154
+ grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
155
+ grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
156
+ grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
157
+
158
+ }));
159
+ }
160
+
161
+ return {
162
+ grad_value, grad_sampling_loc, grad_attn_weight
163
+ };
164
+ }
165
+
166
+ const int CUDA_NUM_THREADS = 1024;
167
+ inline int GET_BLOCKS(const int N, const int num_threads)
168
+ {
169
+ return (N + num_threads - 1) / num_threads;
170
+ }
171
+
172
+
173
+ template <typename scalar_t>
174
+ __device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
175
+ const int &height, const int &width, const int &nheads, const int &channels,
176
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c)
177
+ {
178
+ const int h_low = floor(h);
179
+ const int w_low = floor(w);
180
+ const int h_high = h_low + 1;
181
+ const int w_high = w_low + 1;
182
+
183
+ const scalar_t lh = h - h_low;
184
+ const scalar_t lw = w - w_low;
185
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
186
+
187
+ const int w_stride = nheads * channels;
188
+ const int h_stride = width * w_stride;
189
+ const int h_low_ptr_offset = h_low * h_stride;
190
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
191
+ const int w_low_ptr_offset = w_low * w_stride;
192
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
193
+ const int base_ptr = m * channels + c;
194
+
195
+ scalar_t v1 = 0;
196
+ if (h_low >= 0 && w_low >= 0)
197
+ {
198
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
199
+ v1 = bottom_data[ptr1];
200
+ }
201
+ scalar_t v2 = 0;
202
+ if (h_low >= 0 && w_high <= width - 1)
203
+ {
204
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
205
+ v2 = bottom_data[ptr2];
206
+ }
207
+ scalar_t v3 = 0;
208
+ if (h_high <= height - 1 && w_low >= 0)
209
+ {
210
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
211
+ v3 = bottom_data[ptr3];
212
+ }
213
+ scalar_t v4 = 0;
214
+ if (h_high <= height - 1 && w_high <= width - 1)
215
+ {
216
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
217
+ v4 = bottom_data[ptr4];
218
+ }
219
+
220
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
221
+
222
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
223
+ return val;
224
+ }
225
+
226
+
227
+ template <typename scalar_t>
228
+ __device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
229
+ const int &height, const int &width, const int &nheads, const int &channels,
230
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
231
+ const scalar_t &top_grad,
232
+ const scalar_t &attn_weight,
233
+ scalar_t* &grad_value,
234
+ scalar_t* grad_sampling_loc,
235
+ scalar_t* grad_attn_weight)
236
+ {
237
+ const int h_low = floor(h);
238
+ const int w_low = floor(w);
239
+ const int h_high = h_low + 1;
240
+ const int w_high = w_low + 1;
241
+
242
+ const scalar_t lh = h - h_low;
243
+ const scalar_t lw = w - w_low;
244
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
245
+
246
+ const int w_stride = nheads * channels;
247
+ const int h_stride = width * w_stride;
248
+ const int h_low_ptr_offset = h_low * h_stride;
249
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
250
+ const int w_low_ptr_offset = w_low * w_stride;
251
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
252
+ const int base_ptr = m * channels + c;
253
+
254
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
255
+ const scalar_t top_grad_value = top_grad * attn_weight;
256
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
257
+
258
+ scalar_t v1 = 0;
259
+ if (h_low >= 0 && w_low >= 0)
260
+ {
261
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
262
+ v1 = bottom_data[ptr1];
263
+ grad_h_weight -= hw * v1;
264
+ grad_w_weight -= hh * v1;
265
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
266
+ }
267
+ scalar_t v2 = 0;
268
+ if (h_low >= 0 && w_high <= width - 1)
269
+ {
270
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
271
+ v2 = bottom_data[ptr2];
272
+ grad_h_weight -= lw * v2;
273
+ grad_w_weight += hh * v2;
274
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
275
+ }
276
+ scalar_t v3 = 0;
277
+ if (h_high <= height - 1 && w_low >= 0)
278
+ {
279
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
280
+ v3 = bottom_data[ptr3];
281
+ grad_h_weight += hw * v3;
282
+ grad_w_weight -= lh * v3;
283
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
284
+ }
285
+ scalar_t v4 = 0;
286
+ if (h_high <= height - 1 && w_high <= width - 1)
287
+ {
288
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
289
+ v4 = bottom_data[ptr4];
290
+ grad_h_weight += lw * v4;
291
+ grad_w_weight += lh * v4;
292
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
293
+ }
294
+
295
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
296
+ *grad_attn_weight = top_grad * val;
297
+ *grad_sampling_loc = width * grad_w_weight * top_grad_value;
298
+ *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
299
+ }
300
+
301
+
302
+ template <typename scalar_t>
303
+ __device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
304
+ const int &height, const int &width, const int &nheads, const int &channels,
305
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
306
+ const scalar_t &top_grad,
307
+ const scalar_t &attn_weight,
308
+ scalar_t* &grad_value,
309
+ scalar_t* grad_sampling_loc,
310
+ scalar_t* grad_attn_weight)
311
+ {
312
+ const int h_low = floor(h);
313
+ const int w_low = floor(w);
314
+ const int h_high = h_low + 1;
315
+ const int w_high = w_low + 1;
316
+
317
+ const scalar_t lh = h - h_low;
318
+ const scalar_t lw = w - w_low;
319
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
320
+
321
+ const int w_stride = nheads * channels;
322
+ const int h_stride = width * w_stride;
323
+ const int h_low_ptr_offset = h_low * h_stride;
324
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
325
+ const int w_low_ptr_offset = w_low * w_stride;
326
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
327
+ const int base_ptr = m * channels + c;
328
+
329
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
330
+ const scalar_t top_grad_value = top_grad * attn_weight;
331
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
332
+
333
+ scalar_t v1 = 0;
334
+ if (h_low >= 0 && w_low >= 0)
335
+ {
336
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
337
+ v1 = bottom_data[ptr1];
338
+ grad_h_weight -= hw * v1;
339
+ grad_w_weight -= hh * v1;
340
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
341
+ }
342
+ scalar_t v2 = 0;
343
+ if (h_low >= 0 && w_high <= width - 1)
344
+ {
345
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
346
+ v2 = bottom_data[ptr2];
347
+ grad_h_weight -= lw * v2;
348
+ grad_w_weight += hh * v2;
349
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
350
+ }
351
+ scalar_t v3 = 0;
352
+ if (h_high <= height - 1 && w_low >= 0)
353
+ {
354
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
355
+ v3 = bottom_data[ptr3];
356
+ grad_h_weight += hw * v3;
357
+ grad_w_weight -= lh * v3;
358
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
359
+ }
360
+ scalar_t v4 = 0;
361
+ if (h_high <= height - 1 && w_high <= width - 1)
362
+ {
363
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
364
+ v4 = bottom_data[ptr4];
365
+ grad_h_weight += lw * v4;
366
+ grad_w_weight += lh * v4;
367
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
368
+ }
369
+
370
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
371
+ atomicAdd(grad_attn_weight, top_grad * val);
372
+ atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
373
+ atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
374
+ }
375
+
376
+
377
// Forward "im2col" kernel for multi-scale deformable attention.
// Each thread computes exactly one element of data_col — the output for a
// single (batch b_col, query q_col, head m_col, channel c_col) slot — by
// accumulating attn_weight * bilinear_sample(value) over every feature
// level and sampling point.
//
//   n                       total number of output elements
//                           (batch_size * num_query * num_heads * channels)
//   data_value              flattened multi-level value tensor
//   data_spatial_shapes     per-level (h, w) pairs, 2 int64 per level
//   data_level_start_index  flat spatial offset where each level begins
//   data_sampling_loc       normalized sampling coords, 2 scalars per point
//                           (presumably in [0, 1] — see -0.5 shift below)
//   data_attn_weight        one attention weight per sampling point
//   data_col                output buffer, same layout as the flat index
template <typename scalar_t>
__global__ void ms_deformable_im2col_gpu_kernel(const int n,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *data_col)
{
  // CUDA_KERNEL_LOOP: strided loop over the n output elements (macro is
  // defined elsewhere in this file).
  CUDA_KERNEL_LOOP(index, n)
  {
    // Decompose the flat index into (b_col, q_col, m_col, c_col);
    // layout is batch-major, channel-minor.
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    // (b, q, m) index — shared by all points of this output element.
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    scalar_t *data_col_ptr = data_col + index;
    // Per-point bases: one weight and two coordinates per sampling point.
    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    // Stride of one spatial position in the value tensor.
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
    scalar_t col = 0;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      // Base of this level's feature map for batch b_col.
      const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Map normalized coordinates into pixel space; the -0.5 shift
        // centers samples on pixel centers.
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;

        // Points falling entirely outside the feature map contribute 0.
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
        }

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
      }
    }
    *data_col_ptr = col;
  }
}
440
+
441
// Backward (col2im) kernel, compile-time specialised for small block sizes
// (the host launcher dispatches this variant when channels <= 32 and uses
// blockSize == channels).  Each thread handles one (batch, query, head,
// channel) element of grad_col.  Per-point gradients w.r.t. the sampling
// location and attention weight are staged in per-thread shared-memory
// slots and then summed *sequentially by thread 0*, since all threads of a
// block share the same (b, q, m) and hence the same grad_sampling_loc /
// grad_attn_weight destination.  grad_value is updated inside
// ms_deform_attn_col2im_bilinear (defined earlier in this file).
template <typename scalar_t, unsigned int blockSize>
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // NOTE(review): __syncthreads() below executes inside this strided loop;
    // this assumes every thread of a block runs the same number of
    // iterations (n divisible by the block size) — confirm at the launcher.
    __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];  // (w, h) pair per thread
    __shared__ scalar_t cache_grad_attn_weight[blockSize];
    unsigned int tid = threadIdx.x;
    // Decompose the flat index into (b_col, q_col, m_col, c_col).
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    // Upstream gradient for this output element.
    const scalar_t top_grad = grad_col[index];

    // Advance gradient pointers to this (b, q, m)'s first sampling point.
    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;   // two coords per point
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Normalized -> pixel coordinates (matches the forward kernel).
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Zero this thread's shared slots so out-of-range samples add 0.
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();
        // Thread 0 folds all per-thread partials and writes the result.
        if (tid == 0)
        {
          scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
          int sid=2;
          // NOTE: the loop variable deliberately shadows the outer `tid`;
          // only thread 0 reaches this branch, so the shadowing is benign.
          for (unsigned int tid = 1; tid < blockSize; ++tid)
          {
            _grad_w += cache_grad_sampling_loc[sid];
            _grad_h += cache_grad_sampling_loc[sid + 1];
            _grad_a += cache_grad_attn_weight[tid];
            sid += 2;
          }


          // Plain (non-atomic) stores: this variant assumes exactly one
          // block owns each (b, q, m) gradient slot.
          *grad_sampling_loc = _grad_w;
          *(grad_sampling_loc + 1) = _grad_h;
          *grad_attn_weight = _grad_a;
        }
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
544
+
545
+
546
// Backward (col2im) kernel, compile-time specialised for larger block sizes
// (the host launcher dispatches this variant for channels 64/128/256/512).
// Identical staging to the _v1 variant — per-thread shared-memory slots for
// the sampling-location and attention-weight gradients — but the partials
// are combined with a parallel tree reduction instead of a sequential sweep
// by thread 0.  The halving loop (s = blockSize/2; s >>= 1) is only exact
// for power-of-two blockSize, which holds for every instantiation at the
// call sites.  grad_value is updated inside ms_deform_attn_col2im_bilinear.
template <typename scalar_t, unsigned int blockSize>
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // NOTE(review): __syncthreads() inside the strided loop assumes all
    // threads of a block iterate the same number of times.
    __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];  // (w, h) per thread
    __shared__ scalar_t cache_grad_attn_weight[blockSize];
    unsigned int tid = threadIdx.x;
    // Decompose the flat index into (b_col, q_col, m_col, c_col).
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    // Advance gradient pointers to this (b, q, m)'s first sampling point.
    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Normalized -> pixel coordinates (matches the forward kernel).
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Zero this thread's shared slots so out-of-range samples add 0.
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();

        // Binary tree reduction over the per-thread partials; xid1/xid2 map
        // the thread pair onto the interleaved (w, h) location buffer.
        for (unsigned int s=blockSize/2; s>0; s>>=1)
        {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
          }
          __syncthreads();
        }

        // Non-atomic store: this variant assumes one block owns each slot.
        if (tid == 0)
        {
          *grad_sampling_loc = cache_grad_sampling_loc[0];
          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
          *grad_attn_weight = cache_grad_attn_weight[0];
        }
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
651
+
652
+
653
// Backward (col2im) kernel using *dynamically sized* shared memory instead
// of a compile-time blockSize: the launcher must pass
// blockDim.x * 3 * sizeof(scalar_t) as the dynamic shared-memory size
// (2 slots per thread for the location gradient, 1 for the weight).
// Partials are combined sequentially by thread 0, as in the
// _shm_blocksize_aware_reduce_v1 variant.
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Carve the dynamic shared-memory arena into the two caches.
    extern __shared__ int _s[];
    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;
    // Decompose the flat index into (b_col, q_col, m_col, c_col).
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    // Advance gradient pointers to this (b, q, m)'s first sampling point.
    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Normalized -> pixel coordinates (matches the forward kernel).
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Zero this thread's shared slots so out-of-range samples add 0.
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();
        // Thread 0 folds all per-thread partials sequentially.
        if (tid == 0)
        {
          scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
          int sid=2;
          // NOTE: loop variable shadows the outer `tid`; only thread 0
          // reaches this branch, so the shadowing is benign.
          for (unsigned int tid = 1; tid < blockDim.x; ++tid)
          {
            _grad_w += cache_grad_sampling_loc[sid];
            _grad_h += cache_grad_sampling_loc[sid + 1];
            _grad_a += cache_grad_attn_weight[tid];
            sid += 2;
          }


          // Non-atomic store: assumes one block owns each gradient slot.
          *grad_sampling_loc = _grad_w;
          *(grad_sampling_loc + 1) = _grad_h;
          *grad_attn_weight = _grad_a;
        }
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
757
+
758
// Backward (col2im) kernel with dynamically sized shared memory
// (launcher passes blockDim.x * 3 * sizeof(scalar_t)) and a parallel tree
// reduction.  Unlike the compile-time _blocksize_aware_reduce_v2 variant,
// this one handles non-power-of-two blockDim.x: `spre` tracks the previous
// level's width so the leftover element (tid + 2s < spre) is folded in.
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Carve the dynamic shared-memory arena into the two caches.
    extern __shared__ int _s[];
    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;
    // Decompose the flat index into (b_col, q_col, m_col, c_col).
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    // Advance gradient pointers to this (b, q, m)'s first sampling point.
    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Normalized -> pixel coordinates (matches the forward kernel).
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Zero this thread's shared slots so out-of-range samples add 0.
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();

        // Tree reduction; the extra branch folds the odd leftover element
        // when the previous level width `spre` was not exactly 2s.
        for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
        {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
            if (tid + (s << 1) < spre)
            {
              cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
              cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
              cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
            }
          }
          __syncthreads();
        }

        // Non-atomic store: assumes one block owns each gradient slot.
        if (tid == 0)
        {
          *grad_sampling_loc = cache_grad_sampling_loc[0];
          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
          *grad_attn_weight = cache_grad_attn_weight[0];
        }
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
870
+
871
// Backward (col2im) kernel for the case where the channel dimension spans
// *multiple* blocks (launcher uses it when channels > 1024 and divisible by
// 1024).  Same dynamic-shared-memory tree reduction as
// ms_deformable_col2im_gpu_kernel_shm_reduce_v2, but because several blocks
// contribute to the same (b, q, m) gradient slot, the final write-back uses
// atomicAdd instead of a plain store.
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Carve the dynamic shared-memory arena into the two caches
    // (launcher passes blockDim.x * 3 * sizeof(scalar_t)).
    extern __shared__ int _s[];
    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;
    // Decompose the flat index into (b_col, q_col, m_col, c_col).
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    // Advance gradient pointers to this (b, q, m)'s first sampling point.
    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Normalized -> pixel coordinates (matches the forward kernel).
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Zero this thread's shared slots so out-of-range samples add 0.
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();

        // Tree reduction; the extra branch folds the odd leftover element
        // when the previous level width `spre` was not exactly 2s.
        for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
        {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
            if (tid + (s << 1) < spre)
            {
              cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
              cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
              cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
            }
          }
          __syncthreads();
        }

        // Atomic write-back: multiple blocks may target the same slot.
        if (tid == 0)
        {
          atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
          atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
          atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
        }
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
983
+
984
+
985
// Backward (col2im) fallback kernel with no shared memory: every thread
// accumulates its gradients directly into global memory via
// ms_deform_attn_col2im_bilinear_gm, which (per its "_gm" counterpart
// visible earlier in this file) uses atomicAdd for grad_value,
// grad_sampling_loc and grad_attn_weight.  The host launcher dispatches
// this variant when channels > 1024 and not a multiple of 1024 — simplest,
// but the heaviest atomic contention.
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Decompose the flat index into (b_col, q_col, m_col, c_col).
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    // Advance gradient pointers to this (b, q, m)'s first sampling point.
    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;   // two coords per point
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Normalized -> pixel coordinates (matches the forward kernel).
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Out-of-range samples contribute no gradient.
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear_gm(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            grad_sampling_loc, grad_attn_weight);
        }
        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
1061
+
1062
+
1063
// Host-side launcher for the forward pass: runs
// ms_deformable_im2col_gpu_kernel on `stream`, one thread per output
// element of data_col.  GET_BLOCKS / CUDA_NUM_THREADS are project macros
// defined elsewhere in this file.  Launch failures are only reported via
// printf (upstream convention); no exception is raised here.
template <typename scalar_t>
void ms_deformable_im2col_cuda(cudaStream_t stream,
                              const scalar_t* data_value,
                              const int64_t* data_spatial_shapes,
                              const int64_t* data_level_start_index,
                              const scalar_t* data_sampling_loc,
                              const scalar_t* data_attn_weight,
                              const int batch_size,
                              const int spatial_size,
                              const int num_heads,
                              const int channels,
                              const int num_levels,
                              const int num_query,
                              const int num_point,
                              scalar_t* data_col)
{
  // One kernel "index" per output element (batch * query * head * channel).
  const int num_kernels = batch_size * num_query * num_heads * channels;
  const int num_actual_kernels = batch_size * num_query * num_heads * channels;
  const int num_threads = CUDA_NUM_THREADS;
  ms_deformable_im2col_gpu_kernel<scalar_t>
      <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
          0, stream>>>(
      num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
      batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);

  // Surface launch errors (asynchronous execution errors may appear later).
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
  }

}
1095
+
1096
+ template <typename scalar_t>
1097
+ void ms_deformable_col2im_cuda(cudaStream_t stream,
1098
+ const scalar_t* grad_col,
1099
+ const scalar_t* data_value,
1100
+ const int64_t * data_spatial_shapes,
1101
+ const int64_t * data_level_start_index,
1102
+ const scalar_t * data_sampling_loc,
1103
+ const scalar_t * data_attn_weight,
1104
+ const int batch_size,
1105
+ const int spatial_size,
1106
+ const int num_heads,
1107
+ const int channels,
1108
+ const int num_levels,
1109
+ const int num_query,
1110
+ const int num_point,
1111
+ scalar_t* grad_value,
1112
+ scalar_t* grad_sampling_loc,
1113
+ scalar_t* grad_attn_weight)
1114
+ {
1115
+ const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
1116
+ const int num_kernels = batch_size * num_query * num_heads * channels;
1117
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
1118
+ if (channels > 1024)
1119
+ {
1120
+ if ((channels & 1023) == 0)
1121
+ {
1122
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
1123
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1124
+ num_threads*3*sizeof(scalar_t), stream>>>(
1125
+ num_kernels,
1126
+ grad_col,
1127
+ data_value,
1128
+ data_spatial_shapes,
1129
+ data_level_start_index,
1130
+ data_sampling_loc,
1131
+ data_attn_weight,
1132
+ batch_size,
1133
+ spatial_size,
1134
+ num_heads,
1135
+ channels,
1136
+ num_levels,
1137
+ num_query,
1138
+ num_point,
1139
+ grad_value,
1140
+ grad_sampling_loc,
1141
+ grad_attn_weight);
1142
+ }
1143
+ else
1144
+ {
1145
+ ms_deformable_col2im_gpu_kernel_gm<scalar_t>
1146
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1147
+ 0, stream>>>(
1148
+ num_kernels,
1149
+ grad_col,
1150
+ data_value,
1151
+ data_spatial_shapes,
1152
+ data_level_start_index,
1153
+ data_sampling_loc,
1154
+ data_attn_weight,
1155
+ batch_size,
1156
+ spatial_size,
1157
+ num_heads,
1158
+ channels,
1159
+ num_levels,
1160
+ num_query,
1161
+ num_point,
1162
+ grad_value,
1163
+ grad_sampling_loc,
1164
+ grad_attn_weight);
1165
+ }
1166
+ }
1167
+ else{
1168
+ switch(channels)
1169
+ {
1170
+ case 1:
1171
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
1172
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1173
+ 0, stream>>>(
1174
+ num_kernels,
1175
+ grad_col,
1176
+ data_value,
1177
+ data_spatial_shapes,
1178
+ data_level_start_index,
1179
+ data_sampling_loc,
1180
+ data_attn_weight,
1181
+ batch_size,
1182
+ spatial_size,
1183
+ num_heads,
1184
+ channels,
1185
+ num_levels,
1186
+ num_query,
1187
+ num_point,
1188
+ grad_value,
1189
+ grad_sampling_loc,
1190
+ grad_attn_weight);
1191
+ break;
1192
+ case 2:
1193
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
1194
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1195
+ 0, stream>>>(
1196
+ num_kernels,
1197
+ grad_col,
1198
+ data_value,
1199
+ data_spatial_shapes,
1200
+ data_level_start_index,
1201
+ data_sampling_loc,
1202
+ data_attn_weight,
1203
+ batch_size,
1204
+ spatial_size,
1205
+ num_heads,
1206
+ channels,
1207
+ num_levels,
1208
+ num_query,
1209
+ num_point,
1210
+ grad_value,
1211
+ grad_sampling_loc,
1212
+ grad_attn_weight);
1213
+ break;
1214
+ case 4:
1215
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
1216
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1217
+ 0, stream>>>(
1218
+ num_kernels,
1219
+ grad_col,
1220
+ data_value,
1221
+ data_spatial_shapes,
1222
+ data_level_start_index,
1223
+ data_sampling_loc,
1224
+ data_attn_weight,
1225
+ batch_size,
1226
+ spatial_size,
1227
+ num_heads,
1228
+ channels,
1229
+ num_levels,
1230
+ num_query,
1231
+ num_point,
1232
+ grad_value,
1233
+ grad_sampling_loc,
1234
+ grad_attn_weight);
1235
+ break;
1236
+ case 8:
1237
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
1238
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1239
+ 0, stream>>>(
1240
+ num_kernels,
1241
+ grad_col,
1242
+ data_value,
1243
+ data_spatial_shapes,
1244
+ data_level_start_index,
1245
+ data_sampling_loc,
1246
+ data_attn_weight,
1247
+ batch_size,
1248
+ spatial_size,
1249
+ num_heads,
1250
+ channels,
1251
+ num_levels,
1252
+ num_query,
1253
+ num_point,
1254
+ grad_value,
1255
+ grad_sampling_loc,
1256
+ grad_attn_weight);
1257
+ break;
1258
+ case 16:
1259
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
1260
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1261
+ 0, stream>>>(
1262
+ num_kernels,
1263
+ grad_col,
1264
+ data_value,
1265
+ data_spatial_shapes,
1266
+ data_level_start_index,
1267
+ data_sampling_loc,
1268
+ data_attn_weight,
1269
+ batch_size,
1270
+ spatial_size,
1271
+ num_heads,
1272
+ channels,
1273
+ num_levels,
1274
+ num_query,
1275
+ num_point,
1276
+ grad_value,
1277
+ grad_sampling_loc,
1278
+ grad_attn_weight);
1279
+ break;
1280
+ case 32:
1281
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
1282
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1283
+ 0, stream>>>(
1284
+ num_kernels,
1285
+ grad_col,
1286
+ data_value,
1287
+ data_spatial_shapes,
1288
+ data_level_start_index,
1289
+ data_sampling_loc,
1290
+ data_attn_weight,
1291
+ batch_size,
1292
+ spatial_size,
1293
+ num_heads,
1294
+ channels,
1295
+ num_levels,
1296
+ num_query,
1297
+ num_point,
1298
+ grad_value,
1299
+ grad_sampling_loc,
1300
+ grad_attn_weight);
1301
+ break;
1302
+ case 64:
1303
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
1304
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1305
+ 0, stream>>>(
1306
+ num_kernels,
1307
+ grad_col,
1308
+ data_value,
1309
+ data_spatial_shapes,
1310
+ data_level_start_index,
1311
+ data_sampling_loc,
1312
+ data_attn_weight,
1313
+ batch_size,
1314
+ spatial_size,
1315
+ num_heads,
1316
+ channels,
1317
+ num_levels,
1318
+ num_query,
1319
+ num_point,
1320
+ grad_value,
1321
+ grad_sampling_loc,
1322
+ grad_attn_weight);
1323
+ break;
1324
+ case 128:
1325
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
1326
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1327
+ 0, stream>>>(
1328
+ num_kernels,
1329
+ grad_col,
1330
+ data_value,
1331
+ data_spatial_shapes,
1332
+ data_level_start_index,
1333
+ data_sampling_loc,
1334
+ data_attn_weight,
1335
+ batch_size,
1336
+ spatial_size,
1337
+ num_heads,
1338
+ channels,
1339
+ num_levels,
1340
+ num_query,
1341
+ num_point,
1342
+ grad_value,
1343
+ grad_sampling_loc,
1344
+ grad_attn_weight);
1345
+ break;
1346
+ case 256:
1347
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
1348
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1349
+ 0, stream>>>(
1350
+ num_kernels,
1351
+ grad_col,
1352
+ data_value,
1353
+ data_spatial_shapes,
1354
+ data_level_start_index,
1355
+ data_sampling_loc,
1356
+ data_attn_weight,
1357
+ batch_size,
1358
+ spatial_size,
1359
+ num_heads,
1360
+ channels,
1361
+ num_levels,
1362
+ num_query,
1363
+ num_point,
1364
+ grad_value,
1365
+ grad_sampling_loc,
1366
+ grad_attn_weight);
1367
+ break;
1368
+ case 512:
1369
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
1370
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1371
+ 0, stream>>>(
1372
+ num_kernels,
1373
+ grad_col,
1374
+ data_value,
1375
+ data_spatial_shapes,
1376
+ data_level_start_index,
1377
+ data_sampling_loc,
1378
+ data_attn_weight,
1379
+ batch_size,
1380
+ spatial_size,
1381
+ num_heads,
1382
+ channels,
1383
+ num_levels,
1384
+ num_query,
1385
+ num_point,
1386
+ grad_value,
1387
+ grad_sampling_loc,
1388
+ grad_attn_weight);
1389
+ break;
1390
+ case 1024:
1391
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
1392
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1393
+ 0, stream>>>(
1394
+ num_kernels,
1395
+ grad_col,
1396
+ data_value,
1397
+ data_spatial_shapes,
1398
+ data_level_start_index,
1399
+ data_sampling_loc,
1400
+ data_attn_weight,
1401
+ batch_size,
1402
+ spatial_size,
1403
+ num_heads,
1404
+ channels,
1405
+ num_levels,
1406
+ num_query,
1407
+ num_point,
1408
+ grad_value,
1409
+ grad_sampling_loc,
1410
+ grad_attn_weight);
1411
+ break;
1412
+ default:
1413
+ if (channels < 64)
1414
+ {
1415
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
1416
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1417
+ num_threads*3*sizeof(scalar_t), stream>>>(
1418
+ num_kernels,
1419
+ grad_col,
1420
+ data_value,
1421
+ data_spatial_shapes,
1422
+ data_level_start_index,
1423
+ data_sampling_loc,
1424
+ data_attn_weight,
1425
+ batch_size,
1426
+ spatial_size,
1427
+ num_heads,
1428
+ channels,
1429
+ num_levels,
1430
+ num_query,
1431
+ num_point,
1432
+ grad_value,
1433
+ grad_sampling_loc,
1434
+ grad_attn_weight);
1435
+ }
1436
+ else
1437
+ {
1438
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
1439
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1440
+ num_threads*3*sizeof(scalar_t), stream>>>(
1441
+ num_kernels,
1442
+ grad_col,
1443
+ data_value,
1444
+ data_spatial_shapes,
1445
+ data_level_start_index,
1446
+ data_sampling_loc,
1447
+ data_attn_weight,
1448
+ batch_size,
1449
+ spatial_size,
1450
+ num_heads,
1451
+ channels,
1452
+ num_levels,
1453
+ num_query,
1454
+ num_point,
1455
+ grad_value,
1456
+ grad_sampling_loc,
1457
+ grad_attn_weight);
1458
+ }
1459
+ }
1460
+ }
1461
+ cudaError_t err = cudaGetLastError();
1462
+ if (err != cudaSuccess)
1463
+ {
1464
+ printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
1465
+ }
1466
+
1467
+ }
.venv/lib/python3.11/site-packages/transformers/kernels/deta/cuda/ms_deform_attn_cuda.h ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

#pragma once
#include <torch/extension.h>

// Forward pass of multi-scale deformable attention on CUDA.
// Samples `value` at the (normalized) `sampling_loc` positions across the
// feature levels described by `spatial_shapes`/`level_start_index`, and
// combines the samples with `attn_weight`. `im2col_step` controls how many
// batch elements are processed per im2col pass.
at::Tensor ms_deform_attn_cuda_forward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const int im2col_step);

// Backward pass: given `grad_output`, computes the gradients of the forward
// pass with respect to its differentiable inputs.
// NOTE(review): presumably returns {grad_value, grad_sampling_loc,
// grad_attn_weight} — confirm the order against the .cu implementation.
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const at::Tensor &grad_output,
    const int im2col_step);
.venv/lib/python3.11/site-packages/transformers/kernels/deta/cuda/ms_deform_im2col_cuda.cuh ADDED
@@ -0,0 +1,1327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************
7
+ * Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
8
+ * Copyright (c) 2018 Microsoft
9
+ **************************************************************************
10
+ */
11
+
12
+ #include <cstdio>
13
+ #include <algorithm>
14
+ #include <cstring>
15
+
16
+ #include <ATen/ATen.h>
17
+ #include <ATen/cuda/CUDAContext.h>
18
+
19
+ #include <THC/THCAtomics.cuh>
20
+
21
// Grid-stride loop: thread (blockIdx.x, threadIdx.x) handles indices
// i, i + blockDim.x * gridDim.x, ... so the kernel covers all `n` work items
// regardless of how many blocks were actually launched.
#define CUDA_KERNEL_LOOP(i, n)                          \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x;   \
      i < (n);                                          \
      i += blockDim.x * gridDim.x)

// Default number of threads per block used by the kernel launchers in this file.
const int CUDA_NUM_THREADS = 1024;
27
// Number of thread blocks needed so that blocks of `num_threads` threads
// cover `N` work items (i.e. ceil(N / num_threads) for non-negative N).
inline int GET_BLOCKS(const int N, const int num_threads)
{
  const int padded = N + num_threads - 1;  // round up before the integer divide
  return padded / num_threads;
}
31
+
32
+
33
// Bilinearly sample one (head `m`, channel `c`) scalar from `bottom_data` at
// the fractional, zero-based location (h, w). The value map is laid out as
// (height, width, nheads, channels) — established by the stride computation
// below. Corner taps that fall outside the map contribute zero (zero padding).
template <typename scalar_t>
__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
                                                   const int &height, const int &width, const int &nheads, const int &channels,
                                                   const scalar_t &h, const scalar_t &w, const int &m, const int &c)
{
  // Integer cell enclosing (h, w) ...
  const int h_low = floor(h);
  const int w_low = floor(w);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;

  // ... and the fractional offsets inside that cell.
  const scalar_t lh = h - h_low;
  const scalar_t lw = w - w_low;
  const scalar_t hh = 1 - lh, hw = 1 - lw;

  // Flattened offsets for the (H, W, nheads*channels) layout.
  const int w_stride = nheads * channels;
  const int h_stride = width * w_stride;
  const int h_low_ptr_offset = h_low * h_stride;
  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
  const int w_low_ptr_offset = w_low * w_stride;
  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
  const int base_ptr = m * channels + c;

  // Fetch the four corner values, leaving out-of-bounds corners at 0.
  scalar_t v1 = 0;
  if (h_low >= 0 && w_low >= 0)
  {
    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
    v1 = bottom_data[ptr1];
  }
  scalar_t v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
  {
    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
    v2 = bottom_data[ptr2];
  }
  scalar_t v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
  {
    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
    v3 = bottom_data[ptr3];
  }
  scalar_t v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
  {
    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
    v4 = bottom_data[ptr4];
  }

  // Standard bilinear interpolation weights.
  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;

  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  return val;
}
85
+
86
+
87
// Backward counterpart of ms_deform_attn_im2col_bilinear for a single sample.
// Accumulates into:
//   * `grad_value`        — via atomicAdd (multiple threads may hit the same cell);
//   * `grad_sampling_loc` — (d/dw, d/dh) written with PLAIN stores;
//   * `grad_attn_weight`  — written with a PLAIN store.
// Because the last two are plain stores, the caller must pass destinations
// private to the calling thread (the kernels in this file pass per-thread
// shared-memory scratch slots). The gradient of the bilinear weights follows
// the usual corner-weight derivatives; the final location gradients are scaled
// by width/height because the sampling locations are normalized coordinates.
template <typename scalar_t>
__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
                                               const int &height, const int &width, const int &nheads, const int &channels,
                                               const scalar_t &h, const scalar_t &w, const int &m, const int &c,
                                               const scalar_t &top_grad,
                                               const scalar_t &attn_weight,
                                               scalar_t* &grad_value,
                                               scalar_t* grad_sampling_loc,
                                               scalar_t* grad_attn_weight)
{
  // Enclosing cell and fractional offsets, as in the forward sampler.
  const int h_low = floor(h);
  const int w_low = floor(w);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;

  const scalar_t lh = h - h_low;
  const scalar_t lw = w - w_low;
  const scalar_t hh = 1 - lh, hw = 1 - lw;

  // Flattened offsets for the (H, W, nheads, channels) layout.
  const int w_stride = nheads * channels;
  const int h_stride = width * w_stride;
  const int h_low_ptr_offset = h_low * h_stride;
  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
  const int w_low_ptr_offset = w_low * w_stride;
  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
  const int base_ptr = m * channels + c;

  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
  const scalar_t top_grad_value = top_grad * attn_weight;
  // Accumulators for d(sample)/dh and d(sample)/dw over the in-bounds corners.
  scalar_t grad_h_weight = 0, grad_w_weight = 0;

  scalar_t v1 = 0;
  if (h_low >= 0 && w_low >= 0)
  {
    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
    v1 = bottom_data[ptr1];
    grad_h_weight -= hw * v1;
    grad_w_weight -= hh * v1;
    atomicAdd(grad_value+ptr1, w1*top_grad_value);
  }
  scalar_t v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
  {
    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
    v2 = bottom_data[ptr2];
    grad_h_weight -= lw * v2;
    grad_w_weight += hh * v2;
    atomicAdd(grad_value+ptr2, w2*top_grad_value);
  }
  scalar_t v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
  {
    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
    v3 = bottom_data[ptr3];
    grad_h_weight += hw * v3;
    grad_w_weight -= lh * v3;
    atomicAdd(grad_value+ptr3, w3*top_grad_value);
  }
  scalar_t v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
  {
    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
    v4 = bottom_data[ptr4];
    grad_h_weight += lw * v4;
    grad_w_weight += lh * v4;
    atomicAdd(grad_value+ptr4, w4*top_grad_value);
  }

  // Recomputed forward sample; its product with top_grad is d(loss)/d(attn_weight).
  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  *grad_attn_weight = top_grad * val;
  // Chain rule through h_im = loc_h * H - 0.5 / w_im = loc_w * W - 0.5.
  *grad_sampling_loc = width * grad_w_weight * top_grad_value;
  *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
}
160
+
161
+
162
// Same backward computation as ms_deform_attn_col2im_bilinear, but the
// sampling-location and attention-weight gradients are accumulated with
// atomicAdd instead of plain stores ("_gm" = global memory): safe when
// multiple threads (e.g. different channels of the same query/point) write
// into the same gradient slot directly in global memory.
template <typename scalar_t>
__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
                                                  const int &height, const int &width, const int &nheads, const int &channels,
                                                  const scalar_t &h, const scalar_t &w, const int &m, const int &c,
                                                  const scalar_t &top_grad,
                                                  const scalar_t &attn_weight,
                                                  scalar_t* &grad_value,
                                                  scalar_t* grad_sampling_loc,
                                                  scalar_t* grad_attn_weight)
{
  // Enclosing cell and fractional offsets, as in the forward sampler.
  const int h_low = floor(h);
  const int w_low = floor(w);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;

  const scalar_t lh = h - h_low;
  const scalar_t lw = w - w_low;
  const scalar_t hh = 1 - lh, hw = 1 - lw;

  // Flattened offsets for the (H, W, nheads, channels) layout.
  const int w_stride = nheads * channels;
  const int h_stride = width * w_stride;
  const int h_low_ptr_offset = h_low * h_stride;
  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
  const int w_low_ptr_offset = w_low * w_stride;
  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
  const int base_ptr = m * channels + c;

  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
  const scalar_t top_grad_value = top_grad * attn_weight;
  // Accumulators for d(sample)/dh and d(sample)/dw over the in-bounds corners.
  scalar_t grad_h_weight = 0, grad_w_weight = 0;

  scalar_t v1 = 0;
  if (h_low >= 0 && w_low >= 0)
  {
    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
    v1 = bottom_data[ptr1];
    grad_h_weight -= hw * v1;
    grad_w_weight -= hh * v1;
    atomicAdd(grad_value+ptr1, w1*top_grad_value);
  }
  scalar_t v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
  {
    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
    v2 = bottom_data[ptr2];
    grad_h_weight -= lw * v2;
    grad_w_weight += hh * v2;
    atomicAdd(grad_value+ptr2, w2*top_grad_value);
  }
  scalar_t v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
  {
    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
    v3 = bottom_data[ptr3];
    grad_h_weight += hw * v3;
    grad_w_weight -= lh * v3;
    atomicAdd(grad_value+ptr3, w3*top_grad_value);
  }
  scalar_t v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
  {
    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
    v4 = bottom_data[ptr4];
    grad_h_weight += lw * v4;
    grad_w_weight += lh * v4;
    atomicAdd(grad_value+ptr4, w4*top_grad_value);
  }

  // Recomputed forward sample; atomically folded into the weight/location grads.
  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  atomicAdd(grad_attn_weight, top_grad * val);
  // Chain rule through h_im = loc_h * H - 0.5 / w_im = loc_w * W - 0.5.
  atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
  atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
}
235
+
236
+
237
// Forward im2col kernel for multi-scale deformable attention.
// One output element per linear index: index decodes (from fastest to
// slowest) into channel `c_col`, head `m_col`, query `q_col`, batch `b_col`.
// For each feature level and each sampling point, the normalized location is
// mapped to pixel coordinates (loc * spatial_extent - 0.5), bilinearly
// sampled (zero outside the map), weighted by the attention weight, and the
// weighted samples are summed into the single output scalar.
template <typename scalar_t>
__global__ void ms_deformable_im2col_gpu_kernel(const int n,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *data_col)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Decompose the flat index into (batch, query, head, channel).
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    scalar_t *data_col_ptr = data_col + index;
    // One attention weight per (level, point); locations are (w, h) pairs.
    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
    scalar_t col = 0;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      // Each level occupies rows [level_start_id, ...) of the flattened value.
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Normalized location -> pixel coordinates (0.5-pixel centering).
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;

        // Points entirely outside the map contribute nothing.
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
        }

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
      }
    }
    *data_col_ptr = col;
  }
}
300
+
301
// Backward kernel, statically-sized shared-memory variant with a SERIAL
// block reduction ("v1"): thread 0 sums all per-thread partial gradients.
// Each thread computes the gradient contribution of one (batch, query, head,
// channel) output element into per-thread shared-memory slots; the per-block
// sum is then written to grad_sampling_loc / grad_attn_weight with plain
// stores. Must be launched with blockDim.x == blockSize — the shared caches
// are sized by the template parameter but indexed by threadIdx.x, and the
// reduction assumes one channel per thread within a block (the dispatcher in
// this file instantiates blockSize == channels for small powers of two).
template <typename scalar_t, unsigned int blockSize>
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Per-thread scratch: (dw, dh) pairs and one attention-weight gradient.
    __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
    __shared__ scalar_t cache_grad_attn_weight[blockSize];
    unsigned int tid = threadIdx.x;
    // Decompose the flat index into (batch, query, head, channel).
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    // Advance the output gradient pointers to this sample's first slot.
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Normalized location -> pixel coordinates (0.5-pixel centering).
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Zero the scratch so out-of-bounds points contribute nothing.
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          // grad_value is accumulated atomically inside; loc/weight partials
          // land in this thread's private shared-memory slots.
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();
        if (tid == 0)
        {
          // Serial reduction over all threads' partials, done by thread 0.
          scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
          int sid=2;
          // NOTE(review): this inner `tid` shadows the outer thread id; here
          // it is only a serial loop counter over the block's threads.
          for (unsigned int tid = 1; tid < blockSize; ++tid)
          {
            _grad_w += cache_grad_sampling_loc[sid];
            _grad_h += cache_grad_sampling_loc[sid + 1];
            _grad_a += cache_grad_attn_weight[tid];
            sid += 2;
          }


          *grad_sampling_loc = _grad_w;
          *(grad_sampling_loc + 1) = _grad_h;
          *grad_attn_weight = _grad_a;
        }
        __syncthreads();

        // Step to the next (level, point) sample.
        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
404
+
405
+
406
// Backward kernel, statically-sized shared-memory variant with a TREE
// block reduction ("v2"): partial gradients are pairwise-summed in
// log2(blockSize) synchronized halving steps. The halving loop
// (s = blockSize/2; s >>= 1) only covers every element when blockSize is a
// power of two; the dispatcher in this file instantiates 64..1024. Must be
// launched with blockDim.x == blockSize for the same reasons as v1.
template <typename scalar_t, unsigned int blockSize>
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Per-thread scratch: (dw, dh) pairs and one attention-weight gradient.
    __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
    __shared__ scalar_t cache_grad_attn_weight[blockSize];
    unsigned int tid = threadIdx.x;
    // Decompose the flat index into (batch, query, head, channel).
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    // Advance the output gradient pointers to this sample's first slot.
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Normalized location -> pixel coordinates (0.5-pixel centering).
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Zero the scratch so out-of-bounds points contribute nothing.
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          // grad_value is accumulated atomically inside; loc/weight partials
          // land in this thread's private shared-memory slots.
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();

        // Parallel tree reduction: each step folds the upper half of the
        // active range into the lower half.
        for (unsigned int s=blockSize/2; s>0; s>>=1)
        {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
          }
          __syncthreads();
        }

        if (tid == 0)
        {
          // Slot 0 now holds the block-wide sums.
          *grad_sampling_loc = cache_grad_sampling_loc[0];
          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
          *grad_attn_weight = cache_grad_attn_weight[0];
        }
        __syncthreads();

        // Step to the next (level, point) sample.
        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
511
+
512
+
513
// Backward kernel, DYNAMIC shared-memory variant with a serial reduction.
// Same algorithm as the blocksize-aware v1 kernel, but the scratch caches
// live in extern shared memory sized at launch time — the launcher in this
// file passes num_threads * 3 * sizeof(scalar_t) (2 location slots + 1
// weight slot per thread). Used when channels is not one of the statically
// instantiated block sizes.
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Carve the dynamic shared buffer into (dw, dh) pairs followed by the
    // per-thread attention-weight partials.
    extern __shared__ int _s[];
    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;
    // Decompose the flat index into (batch, query, head, channel).
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    // Advance the output gradient pointers to this sample's first slot.
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Normalized location -> pixel coordinates (0.5-pixel centering).
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Zero the scratch so out-of-bounds points contribute nothing.
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          // grad_value is accumulated atomically inside; loc/weight partials
          // land in this thread's private shared-memory slots.
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();
        if (tid == 0)
        {
          // Serial reduction over all threads' partials, done by thread 0.
          scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
          int sid=2;
          // NOTE(review): this inner `tid` shadows the outer thread id; here
          // it is only a serial loop counter over the block's threads.
          for (unsigned int tid = 1; tid < blockDim.x; ++tid)
          {
            _grad_w += cache_grad_sampling_loc[sid];
            _grad_h += cache_grad_sampling_loc[sid + 1];
            _grad_a += cache_grad_attn_weight[tid];
            sid += 2;
          }


          *grad_sampling_loc = _grad_w;
          *(grad_sampling_loc + 1) = _grad_h;
          *grad_attn_weight = _grad_a;
        }
        __syncthreads();

        // Step to the next (level, point) sample.
        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
617
+
618
// Backward (col2im) kernel for multi-scale deformable attention, shared-memory
// tree-reduction variant. Each thread handles one (batch, query, head, channel)
// element of grad_col; per-point gradients w.r.t. the sampling location and the
// attention weight are accumulated across the channel threads of the block with
// a log2-depth in-shared-memory reduction.
// NOTE(review): assumes all threads of a block work on the same (b, q, head)
// sampling point so the block-wide reduction is meaningful — confirm against
// the launcher's num_threads choice.
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Dynamic shared memory layout: [2*blockDim.x] sampling-loc slots (x,y per
    // thread) followed by [blockDim.x] attention-weight slots.
    extern __shared__ int _s[];
    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;
    // Decode the flat index into (b_col, q_col, m_col, c_col):
    // index = ((b * num_query + q) * num_heads + m) * channels + c.
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    // Per-(b,q,m) read cursors into the sampling-loc / attn-weight tensors and
    // write cursors into their gradients (loc entries come in (w,h) pairs).
    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      // Spatial shapes are stored flat as (h0, w0, h1, w1, ...).
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Normalized [0,1] sampling location -> pixel coordinates (the -0.5
        // centers the sample on the pixel grid).
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Zero this thread's reduction slots before the conditional accumulate
        // so out-of-bounds samples contribute nothing.
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          // Bilinear backward: scatters into grad_value and writes this
          // thread's loc/weight gradient contributions into shared memory.
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();

        // Binary tree reduction over the block. spre tracks the size of the
        // previous (unhalved) span; the extra add under
        // `tid + (s << 1) < spre` folds in the leftover element when that
        // span was odd, so non-power-of-two blockDim.x is handled.
        for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
        {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
            if (tid + (s << 1) < spre)
            {
              cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
              cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
              cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
            }
          }
          __syncthreads();
        }

        // Thread 0 now holds the block-wide sums; plain stores (no atomics —
        // this variant presumably assumes one block per output slot; the
        // _multi_blocks variant uses atomicAdd instead).
        if (tid == 0)
        {
          *grad_sampling_loc = cache_grad_sampling_loc[0];
          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
          *grad_attn_weight = cache_grad_attn_weight[0];
        }
        __syncthreads();

        // Advance to the next sampling point.
        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
730
+
731
// Backward (col2im) kernel, shared-memory tree-reduction variant that commits
// the reduced per-point gradients with atomicAdd. Identical to
// ms_deformable_col2im_gpu_kernel_shm_reduce_v2 except for the final store —
// atomics are used so concurrent writers do not race; presumably several
// blocks can map to the same output slot here (hence "_multi_blocks").
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Shared-memory layout: 2*blockDim.x loc slots then blockDim.x weight slots.
    extern __shared__ int _s[];
    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;
    // Decode flat index into (b_col, q_col, m_col, c_col).
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    // Read/write cursors for this (b,q,m)'s sampling points.
    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Normalized location -> pixel coordinates.
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Clear this thread's slots so out-of-bounds samples add zero.
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();

        // Tree reduction; the `tid + (s << 1) < spre` branch folds in the
        // leftover element when the previous span size (spre) was odd.
        for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
        {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
            if (tid + (s << 1) < spre)
            {
              cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
              cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
              cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
            }
          }
          __syncthreads();
        }

        // Commit with atomics — the only difference from the non-multi-blocks
        // variant, which uses plain stores.
        if (tid == 0)
        {
          atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
          atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
          atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
        }
        __syncthreads();

        // Next sampling point.
        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
843
+
844
+
845
// Backward (col2im) kernel, global-memory fallback: no shared memory and no
// block-level reduction. Each thread passes the gradient output pointers
// straight to ms_deform_attn_col2im_bilinear_gm, which accumulates directly
// in global memory (the "_gm" suffix). Used by the launcher when channels is
// too large / irregular for the shared-memory variants.
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Decode flat index into (b_col, q_col, m_col, c_col).
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    // Cursors into sampling locations / attention weights and their gradients.
    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      // Flat (h, w) pairs per level.
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Normalized location -> pixel coordinates.
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          // _gm helper accumulates loc/weight gradients directly in global
          // memory — no per-block caching or reduction here.
          ms_deform_attn_col2im_bilinear_gm(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            grad_sampling_loc, grad_attn_weight);
        }
        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
921
+
922
+
923
// Host-side launcher for the forward (im2col) pass of multi-scale deformable
// attention. One GPU thread is assigned per output element of data_col, i.e.
// per (batch, query, head, channel) tuple; the kernel runs on `stream` and any
// launch error is reported (not thrown) via printf.
template <typename scalar_t>
void ms_deformable_im2col_cuda(cudaStream_t stream,
                              const scalar_t* data_value,
                              const int64_t* data_spatial_shapes,
                              const int64_t* data_level_start_index,
                              const scalar_t* data_sampling_loc,
                              const scalar_t* data_attn_weight,
                              const int batch_size,
                              const int spatial_size,
                              const int num_heads,
                              const int channels,
                              const int num_levels,
                              const int num_query,
                              const int num_point,
                              scalar_t* data_col)
{
  // Total number of output elements = one logical thread each.
  const int total = batch_size * num_query * num_heads * channels;
  const int threads = CUDA_NUM_THREADS;
  const int blocks = GET_BLOCKS(total, threads);

  // No dynamic shared memory needed for the forward pass.
  ms_deformable_im2col_gpu_kernel<scalar_t>
      <<<blocks, threads, 0, stream>>>(
      total, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
      batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);

  // Surface launch-time failures; best-effort diagnostics only.
  const cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
  }
}
955
+
956
// Host-side launcher for the backward (col2im) pass. Picks a kernel variant
// based on `channels`:
//   * channels is a power of two <= 1024  -> fully-unrolled blocksize-aware
//     shared-memory reduction (v1 for <=32, v2 for 64..1024), chosen via the
//     switch so the block size is a compile-time template argument;
//   * other channels < 64                 -> generic shm_reduce_v1;
//   * other channels in [64, 1024]        -> generic shm_reduce_v2;
//   * channels > 1024, multiple of 1024   -> shm_reduce_v2_multi_blocks
//     (atomicAdd commit);
//   * channels > 1024 otherwise           -> global-memory fallback (_gm).
// The generic shm variants get 3 scalars of dynamic shared memory per thread
// (2 for the sampling location, 1 for the attention weight).
template <typename scalar_t>
void ms_deformable_col2im_cuda(cudaStream_t stream,
                              const scalar_t* grad_col,
                              const scalar_t* data_value,
                              const int64_t * data_spatial_shapes,
                              const int64_t * data_level_start_index,
                              const scalar_t * data_sampling_loc,
                              const scalar_t * data_attn_weight,
                              const int batch_size,
                              const int spatial_size,
                              const int num_heads,
                              const int channels,
                              const int num_levels,
                              const int num_query,
                              const int num_point,
                              scalar_t* grad_value,
                              scalar_t* grad_sampling_loc,
                              scalar_t* grad_attn_weight)
{
  // Block size tracks the channel count (capped) so one block covers the
  // channels of a sampling point and the in-block reductions are valid.
  const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
  const int num_kernels = batch_size * num_query * num_heads * channels;
  const int num_actual_kernels = batch_size * num_query * num_heads * channels;
  if (channels > 1024)
  {
    if ((channels & 1023) == 0)
    {
      // channels is a multiple of 1024: shared-memory reduction across
      // multiple blocks, committed with atomicAdd.
      ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
              num_threads*3*sizeof(scalar_t), stream>>>(
                        num_kernels,
                        grad_col,
                        data_value,
                        data_spatial_shapes,
                        data_level_start_index,
                        data_sampling_loc,
                        data_attn_weight,
                        batch_size,
                        spatial_size,
                        num_heads,
                        channels,
                        num_levels,
                        num_query,
                        num_point,
                        grad_value,
                        grad_sampling_loc,
                        grad_attn_weight);
    }
    else
    {
      // Irregular large channel count: global-memory atomics fallback.
      ms_deformable_col2im_gpu_kernel_gm<scalar_t>
          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
              0, stream>>>(
                        num_kernels,
                        grad_col,
                        data_value,
                        data_spatial_shapes,
                        data_level_start_index,
                        data_sampling_loc,
                        data_attn_weight,
                        batch_size,
                        spatial_size,
                        num_heads,
                        channels,
                        num_levels,
                        num_query,
                        num_point,
                        grad_value,
                        grad_sampling_loc,
                        grad_attn_weight);
    }
  }
  else{
    // Power-of-two channel counts get a kernel specialized on the block size
    // (template argument must be a compile-time constant, hence the switch).
    switch(channels)
    {
      case 1:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
            0, stream>>>(
                  num_kernels,
                  grad_col,
                  data_value,
                  data_spatial_shapes,
                  data_level_start_index,
                  data_sampling_loc,
                  data_attn_weight,
                  batch_size,
                  spatial_size,
                  num_heads,
                  channels,
                  num_levels,
                  num_query,
                  num_point,
                  grad_value,
                  grad_sampling_loc,
                  grad_attn_weight);
        break;
      case 2:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
            0, stream>>>(
                  num_kernels,
                  grad_col,
                  data_value,
                  data_spatial_shapes,
                  data_level_start_index,
                  data_sampling_loc,
                  data_attn_weight,
                  batch_size,
                  spatial_size,
                  num_heads,
                  channels,
                  num_levels,
                  num_query,
                  num_point,
                  grad_value,
                  grad_sampling_loc,
                  grad_attn_weight);
        break;
      case 4:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
            0, stream>>>(
                  num_kernels,
                  grad_col,
                  data_value,
                  data_spatial_shapes,
                  data_level_start_index,
                  data_sampling_loc,
                  data_attn_weight,
                  batch_size,
                  spatial_size,
                  num_heads,
                  channels,
                  num_levels,
                  num_query,
                  num_point,
                  grad_value,
                  grad_sampling_loc,
                  grad_attn_weight);
        break;
      case 8:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
            0, stream>>>(
                  num_kernels,
                  grad_col,
                  data_value,
                  data_spatial_shapes,
                  data_level_start_index,
                  data_sampling_loc,
                  data_attn_weight,
                  batch_size,
                  spatial_size,
                  num_heads,
                  channels,
                  num_levels,
                  num_query,
                  num_point,
                  grad_value,
                  grad_sampling_loc,
                  grad_attn_weight);
        break;
      case 16:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
            0, stream>>>(
                  num_kernels,
                  grad_col,
                  data_value,
                  data_spatial_shapes,
                  data_level_start_index,
                  data_sampling_loc,
                  data_attn_weight,
                  batch_size,
                  spatial_size,
                  num_heads,
                  channels,
                  num_levels,
                  num_query,
                  num_point,
                  grad_value,
                  grad_sampling_loc,
                  grad_attn_weight);
        break;
      case 32:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
            0, stream>>>(
                  num_kernels,
                  grad_col,
                  data_value,
                  data_spatial_shapes,
                  data_level_start_index,
                  data_sampling_loc,
                  data_attn_weight,
                  batch_size,
                  spatial_size,
                  num_heads,
                  channels,
                  num_levels,
                  num_query,
                  num_point,
                  grad_value,
                  grad_sampling_loc,
                  grad_attn_weight);
        break;
      case 64:
        // 64 and up use the v2 reduction.
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
            0, stream>>>(
                  num_kernels,
                  grad_col,
                  data_value,
                  data_spatial_shapes,
                  data_level_start_index,
                  data_sampling_loc,
                  data_attn_weight,
                  batch_size,
                  spatial_size,
                  num_heads,
                  channels,
                  num_levels,
                  num_query,
                  num_point,
                  grad_value,
                  grad_sampling_loc,
                  grad_attn_weight);
        break;
      case 128:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
            0, stream>>>(
                  num_kernels,
                  grad_col,
                  data_value,
                  data_spatial_shapes,
                  data_level_start_index,
                  data_sampling_loc,
                  data_attn_weight,
                  batch_size,
                  spatial_size,
                  num_heads,
                  channels,
                  num_levels,
                  num_query,
                  num_point,
                  grad_value,
                  grad_sampling_loc,
                  grad_attn_weight);
        break;
      case 256:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
            0, stream>>>(
                  num_kernels,
                  grad_col,
                  data_value,
                  data_spatial_shapes,
                  data_level_start_index,
                  data_sampling_loc,
                  data_attn_weight,
                  batch_size,
                  spatial_size,
                  num_heads,
                  channels,
                  num_levels,
                  num_query,
                  num_point,
                  grad_value,
                  grad_sampling_loc,
                  grad_attn_weight);
        break;
      case 512:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
            0, stream>>>(
                  num_kernels,
                  grad_col,
                  data_value,
                  data_spatial_shapes,
                  data_level_start_index,
                  data_sampling_loc,
                  data_attn_weight,
                  batch_size,
                  spatial_size,
                  num_heads,
                  channels,
                  num_levels,
                  num_query,
                  num_point,
                  grad_value,
                  grad_sampling_loc,
                  grad_attn_weight);
        break;
      case 1024:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
            0, stream>>>(
                  num_kernels,
                  grad_col,
                  data_value,
                  data_spatial_shapes,
                  data_level_start_index,
                  data_sampling_loc,
                  data_attn_weight,
                  batch_size,
                  spatial_size,
                  num_heads,
                  channels,
                  num_levels,
                  num_query,
                  num_point,
                  grad_value,
                  grad_sampling_loc,
                  grad_attn_weight);
        break;
      default:
        // Non-power-of-two channels <= 1024: generic runtime-blockDim
        // variants with 3 scalars of dynamic shared memory per thread.
        if (channels < 64)
        {
          ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
              num_threads*3*sizeof(scalar_t), stream>>>(
                        num_kernels,
                        grad_col,
                        data_value,
                        data_spatial_shapes,
                        data_level_start_index,
                        data_sampling_loc,
                        data_attn_weight,
                        batch_size,
                        spatial_size,
                        num_heads,
                        channels,
                        num_levels,
                        num_query,
                        num_point,
                        grad_value,
                        grad_sampling_loc,
                        grad_attn_weight);
        }
        else
        {
          ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
              num_threads*3*sizeof(scalar_t), stream>>>(
                        num_kernels,
                        grad_col,
                        data_value,
                        data_spatial_shapes,
                        data_level_start_index,
                        data_sampling_loc,
                        data_attn_weight,
                        batch_size,
                        spatial_size,
                        num_heads,
                        channels,
                        num_levels,
                        num_query,
                        num_point,
                        grad_value,
                        grad_sampling_loc,
                        grad_attn_weight);
        }
    }
  }
  // Report (don't throw) launch failures.
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
  }

}
.venv/lib/python3.11/site-packages/transformers/kernels/deta/ms_deform_attn.h ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #pragma once
12
+
13
+ #include "cpu/ms_deform_attn_cpu.h"
14
+
15
+ #ifdef WITH_CUDA
16
+ #include "cuda/ms_deform_attn_cuda.h"
17
+ #endif
18
+
19
+
20
+ at::Tensor
21
+ ms_deform_attn_forward(
22
+ const at::Tensor &value,
23
+ const at::Tensor &spatial_shapes,
24
+ const at::Tensor &level_start_index,
25
+ const at::Tensor &sampling_loc,
26
+ const at::Tensor &attn_weight,
27
+ const int im2col_step)
28
+ {
29
+ if (value.type().is_cuda())
30
+ {
31
+ #ifdef WITH_CUDA
32
+ return ms_deform_attn_cuda_forward(
33
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
34
+ #else
35
+ AT_ERROR("Not compiled with GPU support");
36
+ #endif
37
+ }
38
+ AT_ERROR("Not implemented on the CPU");
39
+ }
40
+
41
+ std::vector<at::Tensor>
42
+ ms_deform_attn_backward(
43
+ const at::Tensor &value,
44
+ const at::Tensor &spatial_shapes,
45
+ const at::Tensor &level_start_index,
46
+ const at::Tensor &sampling_loc,
47
+ const at::Tensor &attn_weight,
48
+ const at::Tensor &grad_output,
49
+ const int im2col_step)
50
+ {
51
+ if (value.type().is_cuda())
52
+ {
53
+ #ifdef WITH_CUDA
54
+ return ms_deform_attn_cuda_backward(
55
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
56
+ #else
57
+ AT_ERROR("Not compiled with GPU support");
58
+ #endif
59
+ }
60
+ AT_ERROR("Not implemented on the CPU");
61
+ }
.venv/lib/python3.11/site-packages/transformers/kernels/deta/vision.cpp ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #include "ms_deform_attn.h"
12
+
13
// Python bindings: exposes the forward/backward entry points of the
// multi-scale deformable attention op as a torch extension module
// (module name is supplied at build time via TORCH_EXTENSION_NAME).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
  m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
}
.venv/lib/python3.11/site-packages/transformers/kernels/falcon_mamba/__pycache__/selective_scan_with_ln_interface.cpython-311.pyc ADDED
Binary file (21.1 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/kernels/rwkv/wkv_cuda.cu ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <stdio.h>
2
+ #include <assert.h>
3
+
4
+ #define MIN_VALUE (-1e38)
5
+
6
// RWKV WKV forward kernel. One thread per (batch, channel) pair scans the T
// time steps sequentially, producing y[t] = (sum_i exp(w*(t-1-i)+k_i)*v_i +
// exp(u+k_t)*v_t) / (same with v=1) in a numerically safe streaming form:
// all exponentials are kept relative to a running maximum `pp`, so aa/bb are
// the numerator/denominator sums divided by exp(pp).
// NOTE(review): layout indexing assumes k/v/y are contiguous (B, T, C) and
// w/u are per-channel (C,) — confirm against the Python caller.
template <typename F>
__global__ void kernel_forward(
    const int B, const int T, const int C, const F *__restrict__ const _w, const F *__restrict__ const _u,
    const F *__restrict__ const _k, const F *__restrict__ const _v, F *__restrict__ const _y
) {
    // Flat thread id -> (batch _b, channel _c); _offset points at t=0 of this
    // batch/channel, and steps of C move one time step.
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;

    F u = _u[_c];
    F w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    F *__restrict__ const y = _y + _offset;

    // aa and bb are running sums divided by exp(pp) (to avoid overflow);
    // pp starts at MIN_VALUE, i.e. exp(pp) ~ 0 (empty sums).
    F aa = 0, bb = 0, pp = MIN_VALUE;
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        const F kk = k[ii];
        const F vv = v[ii];

        // Output at step i: combine the history (scale pp) with the current
        // token's u+k bonus term under a shared max `p` (log-sum-exp trick).
        F ww = u + kk;
        F p = max(pp, ww);
        F e1 = exp(pp - p);
        F e2 = exp(ww - p);
        y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);

        // State update: decay the history by w and absorb the current token,
        // again rebasing onto a new running max.
        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
}
44
+
45
// RWKV WKV forward kernel with externally carried state. Identical recurrence
// to kernel_forward, but the (aa, bb, pp) triple is loaded from and stored
// back to _s, letting the caller continue generation across chunk boundaries.
// _s holds 3 floats per (batch, channel): [aa, bb, pp].
template <typename F>
__global__ void kernel_forward_with_state(
    const int B, const int T, const int C, const F *__restrict__ const _w, const F *__restrict__ const _u,
    const F *__restrict__ const _k, const F *__restrict__ const _v, F *__restrict__ const _y, F *__restrict__ const _s
) {
    // Flat thread id -> (batch _b, channel _c); _offset_s addresses this
    // thread's 3-element state slot.
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset_s = _b * C * 3 + _c * 3;
    const int _offset = _b * T * C + _c;

    F u = _u[_c];
    F w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    F *__restrict__ const y = _y + _offset;
    F *__restrict__ const s = _s + _offset_s;

    // aa and bb are running sums divided by exp(pp) (to avoid overflow);
    // unlike kernel_forward, they resume from the caller-provided state.
    F aa = s[0], bb = s[1], pp = s[2];
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        const F kk = k[ii];
        const F vv = v[ii];

        // Output: log-sum-exp merge of history and the u+k bonus term.
        F ww = u + kk;
        F p = max(pp, ww);
        F e1 = exp(pp - p);
        F e2 = exp(ww - p);
        y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);

        // State update: decay by w, absorb the current token.
        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
    // Persist the final state for the next chunk.
    s[0] = aa;
    s[1] = bb;
    s[2] = pp;
}
88
+
89
// RWKV WKV backward kernel. One thread per (batch, channel). Two passes over
// time: a forward pass replaying the recurrence to accumulate gw/gu and stash
// per-step quantities (q, r), then a reverse pass producing gk/gv. All
// exponentials stay rebased on running maxima, mirroring the forward kernel.
// NOTE(review): q/r are Tmax-sized stack arrays — Tmax is a compile-time
// macro supplied by the build (must satisfy T <= Tmax); confirm against the
// extension's build flags.
template <typename F>
__global__ void kernel_backward(
    const int B, const int T, const int C, const F *__restrict__ const _w, const F *__restrict__ const _u,
    const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _y,
    const F *__restrict__ const _gy, F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk,
    F *__restrict__ const _gv
) {
    // Flat thread id -> (batch _b, channel _c); _offset points at t=0.
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;

    F u = _u[_c];
    F w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    const F *__restrict__ const y = _y + _offset;
    const F *__restrict__ const gy = _gy + _offset;
    F *__restrict__ const gk = _gk + _offset;
    F *__restrict__ const gv = _gv + _offset;

    // Per-step saves for the reverse pass: qq (scaled upstream grad) and
    // rr (ww - p rebasing term).
    F q[Tmax], r[Tmax];

    // Forward replay. aa/bb are the forward state; ga/gb accumulate the
    // w-sensitivity of that state (for gw); pp is the running max.
    F gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        const F kk = k[ii];
        const F vv = v[ii];
        const F yy = y[ii];

        F ww = u + kk;
        F p = max(pp, ww);
        F e1 = exp(pp - p);
        F e2 = exp(ww - p);
        // qq = gy / denominator of y[i]; reused by the reverse pass.
        const F qq = gy[ii] / (e1 * bb + e2);
        gw += (ga - gb * yy) * e1 * qq;
        gu += (vv - yy) * e2 * qq;
        q[i] = qq;
        r[i] = ww - p;

        // Same state update as the forward kernel, plus ga/gb tracking.
        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        ga = e1 * (aa + ga);
        gb = e1 * (bb + gb);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
    // gw/gu are reduced over time: one value per (batch, channel).
    const int _offsetBC = _b * C + _c;
    _gw[_offsetBC] = gw * _w[_c]; // multiply by w because of w -> -exp(w) in python forward()
    _gu[_offsetBC] = gu;

    // Reverse pass: aa/bb now accumulate suffix sums (reusing the names),
    // restarting from the empty state.
    aa = 0, bb = 0, pp = MIN_VALUE;
    for (int i = T - 1; i >= 0; i--) {
        const int ii = i * C;
        const F kk = k[ii];
        const F vv = v[ii];
        const F yy = y[ii];
        const F qq = q[i];
        const F rr = r[i];

        // Direct term (via this step's output) + suffix term (via later
        // steps' states).
        F e1 = qq * exp(rr);
        F e2 = exp(kk + pp);
        gk[ii] = e1 * (vv - yy) + e2 * (aa * vv + bb);
        gv[ii] = e1 + e2 * aa;

        // Suffix-state update, rebased on a fresh max.
        const F ww = w + pp;
        const F www = rr - u - kk;
        const F p = max(ww, www);
        e1 = exp(ww - p);
        e2 = qq * exp(www - p);
        aa = e1 * aa + e2;
        bb = e1 * bb - e2 * yy;
        pp = p;
    }
}
167
+
168
// Host launcher for kernel_forward<float>: one CUDA thread per
// (batch, channel) pair, scanning all T steps sequentially.
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
    const int block_threads = min(C, 32); // requires --maxrregcount 60 for optimal performance
    assert(B * C % block_threads == 0);   // grid must tile B*C exactly
    const int block_count = B * C / block_threads;
    kernel_forward<<<dim3(block_count), dim3(block_threads)>>>(B, T, C, w, u, k, v, y);
}
174
+
175
// Host launcher for kernel_forward_with_state<float>: same grid as
// cuda_forward, with `s` carrying the (aa, bb, pp) state across calls.
void cuda_forward_with_state(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *s) {
    const int block_threads = min(C, 32); // requires --maxrregcount 60 for optimal performance
    assert(B * C % block_threads == 0);   // grid must tile B*C exactly
    const int block_count = B * C / block_threads;
    kernel_forward_with_state<<<dim3(block_count), dim3(block_threads)>>>(B, T, C, w, u, k, v, y, s);
}
181
+
182
// Host launcher for kernel_backward<float>: one CUDA thread per
// (batch, channel) pair; writes gw, gu, gk, gv.
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv) {
    const int block_threads = min(C, 32); // requires --maxrregcount 60 for optimal performance
    assert(B * C % block_threads == 0);   // grid must tile B*C exactly
    const int block_count = B * C / block_threads;
    kernel_backward<<<dim3(block_count), dim3(block_threads)>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
}
.venv/lib/python3.11/site-packages/transformers/kernels/rwkv/wkv_cuda_bf16.cu ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <stdio.h>
2
+ #include <assert.h>
3
+ #include "ATen/ATen.h"
4
+ #define MIN_VALUE (-1e38)
5
+ typedef at::BFloat16 bf16;
6
+
7
+ __global__ void kernel_forward_bf16(
8
+ const int B, const int T, const int C, const float *__restrict__ const _w, const bf16 *__restrict__ const _u,
9
+ const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, bf16 *__restrict__ const _y
10
+ ) {
11
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
12
+ const int _b = idx / C;
13
+ const int _c = idx % C;
14
+ const int _offset = _b * T * C + _c;
15
+
16
+ float u = float(_u[_c]);
17
+ float w = _w[_c];
18
+ const bf16 *__restrict__ const k = _k + _offset;
19
+ const bf16 *__restrict__ const v = _v + _offset;
20
+ bf16 *__restrict__ const y = _y + _offset;
21
+
22
+ // aa and bb are running sums divided by exp(pp) (to avoid overflow)
23
+ float aa = 0, bb = 0, pp = MIN_VALUE;
24
+ for (int i = 0; i < T; i++) {
25
+ const int ii = i * C;
26
+ const float kk = float(k[ii]);
27
+ const float vv = float(v[ii]);
28
+
29
+ float ww = u + kk;
30
+ float p = max(pp, ww);
31
+ float e1 = exp(pp - p);
32
+ float e2 = exp(ww - p);
33
+ y[ii] = bf16((e1 * aa + e2 * vv) / (e1 * bb + e2));
34
+
35
+ ww = w + pp;
36
+ p = max(ww, kk);
37
+ e1 = exp(ww - p);
38
+ e2 = exp(kk - p);
39
+ aa = e1 * aa + e2 * vv;
40
+ bb = e1 * bb + e2;
41
+ pp = p;
42
+ }
43
+ }
44
+
45
+ __global__ void kernel_forward_with_state_bf16(
46
+ const int B, const int T, const int C, const float *__restrict__ const _w, const bf16 *__restrict__ const _u,
47
+ const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, bf16 *__restrict__ const _y,
48
+ float *__restrict__ const _s
49
+ ) {
50
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
51
+ const int _b = idx / C;
52
+ const int _c = idx % C;
53
+ const int _offset_s = _b * C * 3 + _c * 3;
54
+ const int _offset = _b * T * C + _c;
55
+
56
+ float u = float(_u[_c]);
57
+ float w = _w[_c];
58
+ const bf16 *__restrict__ const k = _k + _offset;
59
+ const bf16 *__restrict__ const v = _v + _offset;
60
+ bf16 *__restrict__ const y = _y + _offset;
61
+ float *__restrict__ const s = _s + _offset_s;
62
+
63
+ // aa and bb are running sums divided by exp(pp) (to avoid overflow)
64
+ float aa = s[0], bb = s[1], pp = s[2];
65
+ for (int i = 0; i < T; i++) {
66
+ const int ii = i * C;
67
+ const float kk = float(k[ii]);
68
+ const float vv = float(v[ii]);
69
+
70
+ float ww = u + kk;
71
+ float p = max(pp, ww);
72
+ float e1 = exp(pp - p);
73
+ float e2 = exp(ww - p);
74
+ y[ii] = bf16(e1 * aa + e2 * vv) / (e1 * bb + e2);
75
+
76
+ ww = w + pp;
77
+ p = max(ww, kk);
78
+ e1 = exp(ww - p);
79
+ e2 = exp(kk - p);
80
+ aa = e1 * aa + e2 * vv;
81
+ bb = e1 * bb + e2;
82
+ pp = p;
83
+ }
84
+ s[0] = aa;
85
+ s[1] = bb;
86
+ s[2] = pp;
87
+ }
88
+
89
+ __global__ void kernel_backward_bf16(
90
+ const int B, const int T, const int C, const float *__restrict__ const _w, const bf16 *__restrict__ const _u,
91
+ const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, const bf16 *__restrict__ const _y,
92
+ const bf16 *__restrict__ const _gy, bf16 *__restrict__ const _gw, bf16 *__restrict__ const _gu,
93
+ bf16 *__restrict__ const _gk, bf16 *__restrict__ const _gv
94
+ ) {
95
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
96
+ const int _b = idx / C;
97
+ const int _c = idx % C;
98
+ const int _offset = _b * T * C + _c;
99
+
100
+ float u = float(_u[_c]);
101
+ float w = _w[_c];
102
+ const bf16 *__restrict__ const k = _k + _offset;
103
+ const bf16 *__restrict__ const v = _v + _offset;
104
+ const bf16 *__restrict__ const y = _y + _offset;
105
+ const bf16 *__restrict__ const gy = _gy + _offset;
106
+ bf16 *__restrict__ const gk = _gk + _offset;
107
+ bf16 *__restrict__ const gv = _gv + _offset;
108
+
109
+ float q[Tmax], r[Tmax];
110
+
111
+ float gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
112
+ for (int i = 0; i < T; i++) {
113
+ const int ii = i * C;
114
+ const float kk = float(k[ii]);
115
+ const float vv = float(v[ii]);
116
+ const float yy = float(y[ii]);
117
+
118
+ float ww = u + kk;
119
+ float p = max(pp, ww);
120
+ float e1 = exp(pp - p);
121
+ float e2 = exp(ww - p);
122
+ const float qq = float(gy[ii]) / (e1 * bb + e2);
123
+ gw += (ga - gb * yy) * e1 * qq;
124
+ gu += (vv - yy) * e2 * qq;
125
+ q[i] = qq;
126
+ r[i] = ww - p;
127
+
128
+ ww = w + pp;
129
+ p = max(ww, kk);
130
+ e1 = exp(ww - p);
131
+ e2 = exp(kk - p);
132
+ ga = e1 * (aa + ga);
133
+ gb = e1 * (bb + gb);
134
+ aa = e1 * aa + e2 * vv;
135
+ bb = e1 * bb + e2;
136
+ pp = p;
137
+ }
138
+ const int _offsetBC = _b * C + _c;
139
+ _gw[_offsetBC] = bf16(gw * _w[_c]); // multiply by w because of w -> -exp(w) in python forward()
140
+ _gu[_offsetBC] = bf16(gu);
141
+
142
+ aa = 0, bb = 0, pp = MIN_VALUE;
143
+ for (int i = T - 1; i >= 0; i--) {
144
+ const int ii = i * C;
145
+ const float kk = float(k[ii]);
146
+ const float vv = float(v[ii]);
147
+ const float yy = float(y[ii]);
148
+ const float qq = q[i];
149
+ const float rr = r[i];
150
+
151
+ float e1 = qq * exp(rr);
152
+ float e2 = exp(kk + pp);
153
+ gk[ii] = bf16(e1 * (vv - yy) + e2 * (aa * vv + bb));
154
+ gv[ii] = bf16(e1 + e2 * aa);
155
+
156
+ const float ww = w + pp;
157
+ const float www = rr - u - kk;
158
+ const float p = max(ww, www);
159
+ e1 = exp(ww - p);
160
+ e2 = qq * exp(www - p);
161
+ aa = e1 * aa + e2;
162
+ bb = e1 * bb - e2 * yy;
163
+ pp = p;
164
+ }
165
+ }
166
+
167
+ void cuda_forward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y) {
168
+ dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
169
+ assert(B * C % threadsPerBlock.x == 0);
170
+ dim3 numBlocks(B * C / threadsPerBlock.x);
171
+ kernel_forward_bf16<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
172
+ }
173
+
174
+ void cuda_forward_with_state_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, float *s) {
175
+ dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
176
+ assert(B * C % threadsPerBlock.x == 0);
177
+ dim3 numBlocks(B * C / threadsPerBlock.x);
178
+ kernel_forward_with_state_bf16<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, s);
179
+ }
180
+
181
+ void cuda_backward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv) {
182
+ dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
183
+ assert(B * C % threadsPerBlock.x == 0);
184
+ dim3 numBlocks(B * C / threadsPerBlock.x);
185
+ kernel_backward_bf16<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
186
+ }
.venv/lib/python3.11/site-packages/transformers/kernels/rwkv/wkv_op.cpp ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+ #include "ATen/ATen.h"
3
+ typedef at::BFloat16 bf16;
4
+
5
+ void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
6
+ void cuda_forward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y);
7
+ void cuda_forward_with_state(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *s);
8
+ void cuda_forward_with_state_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, float *s);
9
+ void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv);
10
+ void cuda_backward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv);
11
+
12
+ void forward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
13
+ const int B = k.size(0);
14
+ const int T = k.size(1);
15
+ const int C = k.size(2);
16
+ cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
17
+ }
18
+ void forward_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
19
+ const int B = k.size(0);
20
+ const int T = k.size(1);
21
+ const int C = k.size(2);
22
+ cuda_forward_bf16(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>());
23
+ }
24
+ void forward_with_state(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &s) {
25
+ const int B = k.size(0);
26
+ const int T = k.size(1);
27
+ const int C = k.size(2);
28
+ cuda_forward_with_state(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), s.data_ptr<float>());
29
+ }
30
+ void forward_with_state_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &s) {
31
+ const int B = k.size(0);
32
+ const int T = k.size(1);
33
+ const int C = k.size(2);
34
+ cuda_forward_with_state_bf16(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>(), s.data_ptr<float>());
35
+ }
36
+ void backward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
37
+ const int B = k.size(0);
38
+ const int T = k.size(1);
39
+ const int C = k.size(2);
40
+ cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
41
+ }
42
+ void backward_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
43
+ const int B = k.size(0);
44
+ const int T = k.size(1);
45
+ const int C = k.size(2);
46
+ cuda_backward_bf16(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>(),
47
+ gy.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>());
48
+ }
49
+
50
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
51
+ m.def("forward", &forward, "wkv forward");
52
+ m.def("forward_bf16", &forward_bf16, "wkv forward bf16");
53
+ m.def("forward_with_state", &forward_with_state, "wkv forward with state");
54
+ m.def("forward_with_state_bf16", &forward_with_state_bf16, "wkv forward with state bf16");
55
+ m.def("backward", &backward, "wkv backward");
56
+ m.def("backward_bf16", &backward_bf16, "wkv backward bf16");
57
+ }
58
+
59
+ TORCH_LIBRARY(wkv, m) {
60
+ m.def("forward", forward);
61
+ m.def("forward_bf16", forward_bf16);
62
+ m.def("forward_with_state", forward_with_state);
63
+ m.def("forward_with_state_bf16", forward_with_state_bf16);
64
+ m.def("backward", backward);
65
+ m.def("backward_bf16", backward_bf16);
66
+ }
.venv/lib/python3.11/site-packages/transformers/kernels/yoso/common.h ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ #define min(a, b) ((a)<(b)?(a):(b))
3
+ #define max(a, b) ((a)>(b)?(a):(b))
4
+ #define ceil_divide(a, b) ((a)/(b)+((a)%(b)!=0))
5
+ #define select(cond, a, b) ((cond)?(a):(b))
6
+ #define PI 3.141592
7
+ #define EPSILON 1e-8
8
+ #define MAX_VAL 1e12
9
+ #define MIN_VAL -1e12
10
+ #define EMPTY_VALUE -1
.venv/lib/python3.11/site-packages/transformers/kernels/yoso/fast_lsh_cumulation.cu ADDED
@@ -0,0 +1,588 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // File from https://github.com/mlpen/YOSO/blob/main/encoders/backbones/efficient_attentions/yoso/yoso_v1/cuda/fast_lsh_cumulation.cu
2
+
3
+ #include <torch/extension.h>
4
+ #include <ATen/ATen.h>
5
+ #include "fast_lsh_cumulation.h"
6
+ #include "fast_lsh_cumulation_cuda.h"
7
+ #include "common_cuda.h"
8
+ #include "common.h"
9
+ #include <vector>
10
+ //////////////////////////////////////////////////////////////////////////////////////////////////
11
+ //////////////////////////////////////////////////////////////////////////////////////////////////
12
+
13
+ std::vector<at::Tensor> fast_hash_ver1_kernel(
14
+ at::Tensor query_mask,
15
+ at::Tensor query_vector,
16
+ at::Tensor key_mask,
17
+ at::Tensor key_vector,
18
+ int num_hash_f,
19
+ int hash_code_len,
20
+ bool use_cuda
21
+ ) {
22
+
23
+ int batch_size = query_vector.size(0);
24
+ int num_query = query_vector.size(1);
25
+ int num_key = key_vector.size(1);
26
+ int vector_dim = query_vector.size(2);
27
+
28
+ int num_hash_per_part = vector_dim / hash_code_len;
29
+ int num_part = max(1, ceil_divide(num_hash_f, num_hash_per_part));
30
+
31
+ at::Tensor Dmat = 2 * at::randint(0, 2, {batch_size, 3, num_part, vector_dim}, query_mask.options()) - 1;
32
+ at::Tensor query_hash_code = at::zeros({batch_size, num_query, num_hash_f}, query_mask.options());
33
+ at::Tensor key_hash_code = at::zeros({batch_size, num_key, num_hash_f}, key_mask.options());
34
+
35
+ int *query_mask_ptr = query_mask.data_ptr<int>();
36
+ float *query_vector_ptr = query_vector.data_ptr<float>();
37
+ int *key_mask_ptr = key_mask.data_ptr<int>();
38
+ float *key_vector_ptr = key_vector.data_ptr<float>();
39
+
40
+ int *Dmat_ptr = Dmat.data_ptr<int>();
41
+
42
+ int *query_hash_code_ptr = query_hash_code.data_ptr<int>();
43
+ int *key_hash_code_ptr = key_hash_code.data_ptr<int>();
44
+
45
+ if (use_cuda) {
46
+ {
47
+ dim3 threads(vector_dim);
48
+ dim3 blocks(num_part, num_query, batch_size);
49
+ int shared_mem = vector_dim * sizeof(float);
50
+ fast_hash_ver1_cuda_kernel<<<blocks, threads, shared_mem>>>(
51
+ query_mask_ptr,
52
+ query_vector_ptr,
53
+ Dmat_ptr,
54
+ query_hash_code_ptr,
55
+ batch_size,
56
+ num_query,
57
+ vector_dim,
58
+ num_part,
59
+ num_hash_f,
60
+ hash_code_len
61
+ );
62
+ }
63
+ {
64
+ dim3 threads(vector_dim);
65
+ dim3 blocks(num_part, num_key, batch_size);
66
+ int shared_mem = vector_dim * sizeof(float);
67
+ fast_hash_ver1_cuda_kernel<<<blocks, threads, shared_mem>>>(
68
+ key_mask_ptr,
69
+ key_vector_ptr,
70
+ Dmat_ptr,
71
+ key_hash_code_ptr,
72
+ batch_size,
73
+ num_key,
74
+ vector_dim,
75
+ num_part,
76
+ num_hash_f,
77
+ hash_code_len
78
+ );
79
+ }
80
+ }
81
+
82
+ return {query_hash_code, key_hash_code};
83
+
84
+ }
85
+
86
+ at::Tensor lsh_cumulation_ver1_kernel(
87
+ at::Tensor query_mask,
88
+ at::Tensor query_hash_code,
89
+ at::Tensor key_mask,
90
+ at::Tensor key_hash_code,
91
+ at::Tensor value,
92
+ int hashtable_capacity,
93
+ bool use_cuda
94
+ ) {
95
+
96
+ int batch_size = query_hash_code.size(0);
97
+ int num_hash_f = query_hash_code.size(2);
98
+
99
+ int num_query = query_hash_code.size(1);
100
+ int num_key = key_hash_code.size(1);
101
+ int value_dim = value.size(2);
102
+
103
+ at::Tensor hashtable_value = at::empty({batch_size, num_hash_f, hashtable_capacity, WARP_SIZE}, value.options());
104
+ at::Tensor cumulation_value = at::zeros({batch_size, num_query, value_dim}, value.options());
105
+
106
+ if (use_cuda) {
107
+ int threads_x = WARP_SIZE;
108
+ int threads_y = OPTIMAL_THREADS_PER_BLOCK / WARP_SIZE;
109
+ int block_x_step1 = num_key / threads_y;
110
+ int block_x_step2 = num_query / threads_y;
111
+ int block_y = batch_size;
112
+
113
+ dim3 threads(threads_x, threads_y);
114
+ dim3 blocks_step1(block_x_step1, block_y);
115
+ dim3 blocks_step2(block_x_step2, block_y);
116
+
117
+ int *query_mask_ptr = query_mask.data_ptr<int>();
118
+ int *query_hash_code_ptr = query_hash_code.data_ptr<int>();
119
+ int *key_mask_ptr = key_mask.data_ptr<int>();
120
+ int *key_hash_code_ptr = key_hash_code.data_ptr<int>();
121
+ float *value_ptr = value.data_ptr<float>();
122
+ float *hashtable_value_ptr = hashtable_value.data_ptr<float>();
123
+ float *cumulation_value_ptr = cumulation_value.data_ptr<float>();
124
+
125
+ for (int value_offset = 0; value_offset < value_dim; value_offset = value_offset + WARP_SIZE) {
126
+
127
+ cudaMemset(hashtable_value_ptr, 0, (batch_size * num_hash_f * hashtable_capacity * WARP_SIZE) * sizeof(float));
128
+
129
+ lsh_cumulation_ver1_step1_cuda_kernel<<<blocks_step1, threads>>>(
130
+ key_mask_ptr,
131
+ key_hash_code_ptr,
132
+ value_ptr,
133
+ hashtable_value_ptr,
134
+ batch_size,
135
+ num_hash_f,
136
+ hashtable_capacity,
137
+ num_key,
138
+ value_dim,
139
+ value_offset
140
+ );
141
+
142
+ lsh_cumulation_ver1_step2_cuda_kernel<<<blocks_step2, threads>>>(
143
+ query_mask_ptr,
144
+ query_hash_code_ptr,
145
+ hashtable_value_ptr,
146
+ cumulation_value_ptr,
147
+ batch_size,
148
+ num_hash_f,
149
+ hashtable_capacity,
150
+ num_query,
151
+ value_dim,
152
+ value_offset
153
+ );
154
+ }
155
+
156
+ }
157
+
158
+ return cumulation_value;
159
+
160
+ }
161
+
162
+ at::Tensor lsh_weighted_cumulation_ver1_kernel(
163
+ at::Tensor query_mask,
164
+ at::Tensor query_hash_code,
165
+ at::Tensor query_weight,
166
+ at::Tensor key_mask,
167
+ at::Tensor key_hash_code,
168
+ at::Tensor key_weight,
169
+ at::Tensor value,
170
+ int hashtable_capacity,
171
+ bool use_cuda
172
+ ) {
173
+
174
+ int batch_size = query_hash_code.size(0);
175
+ int num_hash_f = query_hash_code.size(2);
176
+
177
+ int num_query = query_hash_code.size(1);
178
+ int num_key = key_hash_code.size(1);
179
+ int value_dim = value.size(2);
180
+ int weight_dim = query_weight.size(2);
181
+
182
+ at::Tensor hashtable_value = at::zeros({batch_size, num_hash_f, hashtable_capacity, WARP_SIZE}, value.options());
183
+ at::Tensor cumulation_value = at::zeros({batch_size, num_query, value_dim}, value.options());
184
+
185
+ if (use_cuda) {
186
+ int threads_x = WARP_SIZE;
187
+ int threads_y = OPTIMAL_THREADS_PER_BLOCK / WARP_SIZE;
188
+ int block_x_step1 = num_key / threads_y;
189
+ int block_x_step2 = num_query / threads_y;
190
+ int block_y = batch_size;
191
+
192
+ dim3 threads(threads_x, threads_y);
193
+ dim3 blocks_step1(block_x_step1, block_y);
194
+ dim3 blocks_step2(block_x_step2, block_y);
195
+
196
+ int *query_mask_ptr = query_mask.data_ptr<int>();
197
+ int *query_hash_code_ptr = query_hash_code.data_ptr<int>();
198
+ float *query_weight_ptr = query_weight.data_ptr<float>();
199
+ int *key_mask_ptr = key_mask.data_ptr<int>();
200
+ int *key_hash_code_ptr = key_hash_code.data_ptr<int>();
201
+ float *key_weight_ptr = key_weight.data_ptr<float>();
202
+ float *value_ptr = value.data_ptr<float>();
203
+ float *hashtable_value_ptr = hashtable_value.data_ptr<float>();
204
+ float *cumulation_value_ptr = cumulation_value.data_ptr<float>();
205
+
206
+ for (int value_offset = 0; value_offset < value_dim; value_offset = value_offset + WARP_SIZE) {
207
+ for (int weight_idx = 0; weight_idx < weight_dim; weight_idx++) {
208
+
209
+ cudaMemset(hashtable_value_ptr, 0, (batch_size * num_hash_f * hashtable_capacity * WARP_SIZE) * sizeof(float));
210
+
211
+ lsh_weighted_cumulation_ver1_step1_cuda_kernel<<<blocks_step1, threads>>>(
212
+ key_mask_ptr,
213
+ key_hash_code_ptr,
214
+ key_weight_ptr,
215
+ value_ptr,
216
+ hashtable_value_ptr,
217
+ batch_size,
218
+ num_hash_f,
219
+ hashtable_capacity,
220
+ num_key,
221
+ value_dim,
222
+ weight_dim,
223
+ value_offset,
224
+ weight_idx
225
+ );
226
+
227
+ lsh_weighted_cumulation_ver1_step2_cuda_kernel<<<blocks_step2, threads>>>(
228
+ query_mask_ptr,
229
+ query_hash_code_ptr,
230
+ query_weight_ptr,
231
+ hashtable_value_ptr,
232
+ cumulation_value_ptr,
233
+ batch_size,
234
+ num_hash_f,
235
+ hashtable_capacity,
236
+ num_query,
237
+ value_dim,
238
+ weight_dim,
239
+ value_offset,
240
+ weight_idx
241
+ );
242
+ }
243
+ }
244
+
245
+ }
246
+
247
+ return cumulation_value;
248
+
249
+ }
250
+
251
+ at::Tensor lsh_weighted_cumulation_ver2_kernel(
252
+ at::Tensor query_mask,
253
+ at::Tensor query_hash_code,
254
+ at::Tensor query_weight,
255
+ at::Tensor key_mask,
256
+ at::Tensor key_hash_code,
257
+ at::Tensor key_weight,
258
+ at::Tensor value,
259
+ int hashtable_capacity,
260
+ bool use_cuda
261
+ ) {
262
+
263
+ int batch_size = query_hash_code.size(0);
264
+ int num_hash_f = query_hash_code.size(2);
265
+
266
+ int num_query = query_hash_code.size(1);
267
+ int num_key = key_hash_code.size(1);
268
+ int value_dim = value.size(2);
269
+ int weight_dim = query_weight.size(2);
270
+
271
+ at::Tensor count_sort_table = at::zeros({batch_size, num_hash_f, hashtable_capacity}, query_hash_code.options());
272
+ at::Tensor key_sorted_idxes = at::zeros({batch_size, num_hash_f, num_key}, query_hash_code.options());
273
+ at::Tensor query_info = at::zeros({batch_size, num_query, 2, num_hash_f}, query_hash_code.options());
274
+ at::Tensor cumulation_value = at::zeros({batch_size, num_query, value_dim}, value.options());
275
+
276
+ if (use_cuda) {
277
+
278
+ int *query_mask_ptr = query_mask.data_ptr<int>();
279
+ int *query_hash_code_ptr = query_hash_code.data_ptr<int>();
280
+ float *query_weight_ptr = query_weight.data_ptr<float>();
281
+ int *key_mask_ptr = key_mask.data_ptr<int>();
282
+ int *key_hash_code_ptr = key_hash_code.data_ptr<int>();
283
+ float *key_weight_ptr = key_weight.data_ptr<float>();
284
+ float *value_ptr = value.data_ptr<float>();
285
+
286
+ int *count_sort_table_ptr = count_sort_table.data_ptr<int>();
287
+ int *key_sorted_idxes_ptr = key_sorted_idxes.data_ptr<int>();
288
+ int *query_info_ptr = query_info.data_ptr<int>();
289
+
290
+ float *cumulation_value_ptr = cumulation_value.data_ptr<float>();
291
+
292
+ {
293
+ dim3 threads_step13(num_hash_f, max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f));
294
+ dim3 blocks_step13(num_key / max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f), batch_size);
295
+ dim3 threads_step2(min(hashtable_capacity, OPTIMAL_THREADS_PER_BLOCK));
296
+ dim3 blocks_step2(num_hash_f, batch_size);
297
+ int shared_mem = hashtable_capacity * sizeof(float);
298
+ count_sort_step1_cuda_kernel<<<blocks_step13, threads_step13>>>(
299
+ key_mask_ptr,
300
+ key_hash_code_ptr,
301
+ count_sort_table_ptr,
302
+ batch_size,
303
+ num_hash_f,
304
+ hashtable_capacity,
305
+ num_key
306
+ );
307
+ count_sort_step2_cuda_kernel<<<blocks_step2, threads_step2, shared_mem>>>(
308
+ count_sort_table_ptr,
309
+ batch_size,
310
+ num_hash_f,
311
+ hashtable_capacity
312
+ );
313
+ count_sort_step3_cuda_kernel<<<blocks_step13, threads_step13>>>(
314
+ key_mask_ptr,
315
+ key_hash_code_ptr,
316
+ count_sort_table_ptr,
317
+ key_sorted_idxes_ptr,
318
+ batch_size,
319
+ num_hash_f,
320
+ hashtable_capacity,
321
+ num_key
322
+ );
323
+ }
324
+ {
325
+ dim3 threads(num_hash_f, max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f));
326
+ dim3 blocks(num_query / max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f), batch_size);
327
+ extract_query_info_cuda_kernel<<<blocks, threads>>>(
328
+ query_mask_ptr,
329
+ query_hash_code_ptr,
330
+ count_sort_table_ptr,
331
+ query_info_ptr,
332
+ batch_size,
333
+ num_hash_f,
334
+ hashtable_capacity,
335
+ num_query
336
+ );
337
+ }
338
+ {
339
+ dim3 threads(WARP_SIZE, OPTIMAL_THREADS_PER_BLOCK / WARP_SIZE);
340
+ dim3 blocks(num_query, num_hash_f, batch_size);
341
+ int shared_mem = (weight_dim + WARP_SIZE) * sizeof(float);
342
+ lsh_weighted_cumulation_ver2_step2_cuda_kernel<<<blocks, threads, shared_mem>>>(
343
+ query_mask_ptr,
344
+ query_info_ptr,
345
+ key_sorted_idxes_ptr,
346
+ query_weight_ptr,
347
+ key_weight_ptr,
348
+ value_ptr,
349
+ cumulation_value_ptr,
350
+ batch_size,
351
+ num_hash_f,
352
+ num_query,
353
+ num_key,
354
+ value_dim,
355
+ weight_dim
356
+ );
357
+ }
358
+ }
359
+
360
+ return cumulation_value;
361
+
362
+ }
363
+
364
+ at::Tensor lsh_weighted_cumulation_ver3_kernel(
365
+ at::Tensor query_mask,
366
+ at::Tensor query_hash_code,
367
+ at::Tensor query_weight,
368
+ at::Tensor key_mask,
369
+ at::Tensor key_hash_code,
370
+ at::Tensor key_weight,
371
+ at::Tensor value,
372
+ int hashtable_capacity,
373
+ bool use_cuda
374
+ ) {
375
+
376
+ int batch_size = query_hash_code.size(0);
377
+ int num_hash_f = query_hash_code.size(2);
378
+
379
+ int num_query = query_hash_code.size(1);
380
+ int num_key = key_hash_code.size(1);
381
+ int value_dim = value.size(2);
382
+ int weight_dim = query_weight.size(2);
383
+
384
+ at::Tensor count_sort_table = at::zeros({batch_size, num_hash_f, hashtable_capacity}, query_hash_code.options());
385
+ at::Tensor query_sorted_idxes = at::zeros({batch_size, num_hash_f, num_query}, query_hash_code.options());
386
+ at::Tensor key_info = at::zeros({batch_size, num_key, 2, num_hash_f}, query_hash_code.options());
387
+ at::Tensor cumulation_value = at::zeros({batch_size, num_query, value_dim}, value.options());
388
+
389
+ if (use_cuda) {
390
+
391
+ int *query_mask_ptr = query_mask.data_ptr<int>();
392
+ int *query_hash_code_ptr = query_hash_code.data_ptr<int>();
393
+ float *query_weight_ptr = query_weight.data_ptr<float>();
394
+ int *key_mask_ptr = key_mask.data_ptr<int>();
395
+ int *key_hash_code_ptr = key_hash_code.data_ptr<int>();
396
+ float *key_weight_ptr = key_weight.data_ptr<float>();
397
+ float *value_ptr = value.data_ptr<float>();
398
+
399
+ int *count_sort_table_ptr = count_sort_table.data_ptr<int>();
400
+ int *query_sorted_idxes_ptr = query_sorted_idxes.data_ptr<int>();
401
+ int *key_info_ptr = key_info.data_ptr<int>();
402
+
403
+ float *cumulation_value_ptr = cumulation_value.data_ptr<float>();
404
+
405
+ {
406
+ dim3 threads_step13(num_hash_f, max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f));
407
+ dim3 blocks_step13(num_query / max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f), batch_size);
408
+ dim3 threads_step2(min(hashtable_capacity, OPTIMAL_THREADS_PER_BLOCK));
409
+ dim3 blocks_step2(num_hash_f, batch_size);
410
+ int shared_mem = hashtable_capacity * sizeof(float);
411
+ count_sort_step1_cuda_kernel<<<blocks_step13, threads_step13>>>(
412
+ query_mask_ptr,
413
+ query_hash_code_ptr,
414
+ count_sort_table_ptr,
415
+ batch_size,
416
+ num_hash_f,
417
+ hashtable_capacity,
418
+ num_query
419
+ );
420
+ count_sort_step2_cuda_kernel<<<blocks_step2, threads_step2, shared_mem>>>(
421
+ count_sort_table_ptr,
422
+ batch_size,
423
+ num_hash_f,
424
+ hashtable_capacity
425
+ );
426
+ count_sort_step3_cuda_kernel<<<blocks_step13, threads_step13>>>(
427
+ query_mask_ptr,
428
+ query_hash_code_ptr,
429
+ count_sort_table_ptr,
430
+ query_sorted_idxes_ptr,
431
+ batch_size,
432
+ num_hash_f,
433
+ hashtable_capacity,
434
+ num_query
435
+ );
436
+ }
437
+ {
438
+ dim3 threads(num_hash_f, max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f));
439
+ dim3 blocks(num_key / max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f), batch_size);
440
+ extract_query_info_cuda_kernel<<<blocks, threads>>>(
441
+ key_mask_ptr,
442
+ key_hash_code_ptr,
443
+ count_sort_table_ptr,
444
+ key_info_ptr,
445
+ batch_size,
446
+ num_hash_f,
447
+ hashtable_capacity,
448
+ num_key
449
+ );
450
+ }
451
+ {
452
+ dim3 threads(WARP_SIZE, OPTIMAL_THREADS_PER_BLOCK / WARP_SIZE);
453
+ dim3 blocks(num_key, num_hash_f, batch_size);
454
+ int shared_mem = (weight_dim + value_dim + WARP_SIZE) * sizeof(float);
455
+ lsh_weighted_cumulation_ver3_step2_cuda_kernel<<<blocks, threads, shared_mem>>>(
456
+ query_sorted_idxes_ptr,
457
+ key_mask_ptr,
458
+ key_info_ptr,
459
+ query_weight_ptr,
460
+ key_weight_ptr,
461
+ value_ptr,
462
+ cumulation_value_ptr,
463
+ batch_size,
464
+ num_hash_f,
465
+ num_query,
466
+ num_key,
467
+ value_dim,
468
+ weight_dim
469
+ );
470
+ }
471
+ }
472
+
473
+ return cumulation_value;
474
+
475
+ }
476
+
477
+ at::Tensor lsh_weighted_cumulation_ver4_kernel(
478
+ at::Tensor query_mask,
479
+ at::Tensor query_hash_code,
480
+ at::Tensor query_weight,
481
+ at::Tensor key_mask,
482
+ at::Tensor key_hash_code,
483
+ at::Tensor key_weight,
484
+ at::Tensor value,
485
+ int hashtable_capacity,
486
+ bool use_cuda
487
+ ) {
488
+
489
+ int batch_size = query_hash_code.size(0);
490
+ int num_hash_f = query_hash_code.size(2);
491
+
492
+ int num_query = query_hash_code.size(1);
493
+ int num_key = key_hash_code.size(1);
494
+ int value_dim = value.size(2);
495
+ int weight_dim = query_weight.size(2);
496
+
497
+ at::Tensor count_sort_table = at::zeros({batch_size, num_hash_f, hashtable_capacity}, query_hash_code.options());
498
+ at::Tensor query_sorted_idxes = at::zeros({batch_size, num_hash_f, num_query}, query_hash_code.options());
499
+ at::Tensor key_info = at::zeros({batch_size, num_key, 2, num_hash_f}, query_hash_code.options());
500
+ at::Tensor cumulation_value = at::zeros({batch_size, num_query, value_dim}, value.options());
501
+
502
+ if (use_cuda) {
503
+
504
+ int *query_mask_ptr = query_mask.data_ptr<int>();
505
+ int *query_hash_code_ptr = query_hash_code.data_ptr<int>();
506
+ float *query_weight_ptr = query_weight.data_ptr<float>();
507
+ int *key_mask_ptr = key_mask.data_ptr<int>();
508
+ int *key_hash_code_ptr = key_hash_code.data_ptr<int>();
509
+ float *key_weight_ptr = key_weight.data_ptr<float>();
510
+ float *value_ptr = value.data_ptr<float>();
511
+
512
+ int *count_sort_table_ptr = count_sort_table.data_ptr<int>();
513
+ int *query_sorted_idxes_ptr = query_sorted_idxes.data_ptr<int>();
514
+ int *key_info_ptr = key_info.data_ptr<int>();
515
+
516
+ float *cumulation_value_ptr = cumulation_value.data_ptr<float>();
517
+
518
+ {
519
+ dim3 threads_step13(num_hash_f, max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f));
520
+ dim3 blocks_step13(num_query / max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f), batch_size);
521
+ dim3 threads_step2(min(hashtable_capacity, OPTIMAL_THREADS_PER_BLOCK));
522
+ dim3 blocks_step2(num_hash_f, batch_size);
523
+ int shared_mem = hashtable_capacity * sizeof(float);
524
+ count_sort_step1_cuda_kernel<<<blocks_step13, threads_step13>>>(
525
+ query_mask_ptr,
526
+ query_hash_code_ptr,
527
+ count_sort_table_ptr,
528
+ batch_size,
529
+ num_hash_f,
530
+ hashtable_capacity,
531
+ num_query
532
+ );
533
+ count_sort_step2_cuda_kernel<<<blocks_step2, threads_step2, shared_mem>>>(
534
+ count_sort_table_ptr,
535
+ batch_size,
536
+ num_hash_f,
537
+ hashtable_capacity
538
+ );
539
+ count_sort_step3_cuda_kernel<<<blocks_step13, threads_step13>>>(
540
+ query_mask_ptr,
541
+ query_hash_code_ptr,
542
+ count_sort_table_ptr,
543
+ query_sorted_idxes_ptr,
544
+ batch_size,
545
+ num_hash_f,
546
+ hashtable_capacity,
547
+ num_query
548
+ );
549
+ }
550
+ {
551
+ dim3 threads(num_hash_f, max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f));
552
+ dim3 blocks(num_key / max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f), batch_size);
553
+ extract_query_info_cuda_kernel<<<blocks, threads>>>(
554
+ key_mask_ptr,
555
+ key_hash_code_ptr,
556
+ count_sort_table_ptr,
557
+ key_info_ptr,
558
+ batch_size,
559
+ num_hash_f,
560
+ hashtable_capacity,
561
+ num_key
562
+ );
563
+ }
564
+ {
565
+ dim3 threads(WARP_SIZE, OPTIMAL_THREADS_PER_BLOCK / WARP_SIZE);
566
+ dim3 blocks(num_key, batch_size);
567
+ int shared_mem = (weight_dim + value_dim + 2 * num_hash_f) * sizeof(float);
568
+ lsh_weighted_cumulation_ver4_step2_cuda_kernel<<<blocks, threads, shared_mem>>>(
569
+ query_sorted_idxes_ptr,
570
+ key_mask_ptr,
571
+ key_info_ptr,
572
+ query_weight_ptr,
573
+ key_weight_ptr,
574
+ value_ptr,
575
+ cumulation_value_ptr,
576
+ batch_size,
577
+ num_hash_f,
578
+ num_query,
579
+ num_key,
580
+ value_dim,
581
+ weight_dim
582
+ );
583
+ }
584
+ }
585
+
586
+ return cumulation_value;
587
+
588
+ }
.venv/lib/python3.11/site-packages/transformers/kernels/yoso/fast_lsh_cumulation.h ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
// Host-side entry points for the YOSO fast-LSH cumulation kernels.
// Declarations only; the definitions live in fast_lsh_cumulation.cu.
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <vector>

// Compute LSH hash codes for query and key vectors using num_hash_f hash
// functions of hash_code_len bits each. Returns a vector of tensors
// (presumably {query_hash_code, key_hash_code} -- confirm against the
// definition, which is not visible from this header). use_cuda selects the
// CUDA implementation over the fallback path.
std::vector<at::Tensor> fast_hash_ver1_kernel(
    at::Tensor query_mask,
    at::Tensor query_vector,
    at::Tensor key_mask,
    at::Tensor key_vector,
    int num_hash_f,
    int hash_code_len,
    bool use_cuda
);

// Unweighted LSH cumulation: keys' value rows are accumulated into
// hashtables of size hashtable_capacity (one per hash function) and gathered
// back per query. Masks mark valid (non-padding) positions.
at::Tensor lsh_cumulation_ver1_kernel(
    at::Tensor query_mask,
    at::Tensor query_hash_code,
    at::Tensor key_mask,
    at::Tensor key_hash_code,
    at::Tensor value,
    int hashtable_capacity,
    bool use_cuda
);

// Weighted LSH cumulation, variant 1. Variants 1-4 share this signature and
// differ only in internal work scheduling (see the kernels in the .cu files).
at::Tensor lsh_weighted_cumulation_ver1_kernel(
    at::Tensor query_mask,
    at::Tensor query_hash_code,
    at::Tensor query_weight,
    at::Tensor key_mask,
    at::Tensor key_hash_code,
    at::Tensor key_weight,
    at::Tensor value,
    int hashtable_capacity,
    bool use_cuda
);

// Weighted LSH cumulation, variant 2 (query-centric scheduling).
at::Tensor lsh_weighted_cumulation_ver2_kernel(
    at::Tensor query_mask,
    at::Tensor query_hash_code,
    at::Tensor query_weight,
    at::Tensor key_mask,
    at::Tensor key_hash_code,
    at::Tensor key_weight,
    at::Tensor value,
    int hashtable_capacity,
    bool use_cuda
);

// Weighted LSH cumulation, variant 3 (key-centric scheduling).
at::Tensor lsh_weighted_cumulation_ver3_kernel(
    at::Tensor query_mask,
    at::Tensor query_hash_code,
    at::Tensor query_weight,
    at::Tensor key_mask,
    at::Tensor key_hash_code,
    at::Tensor key_weight,
    at::Tensor value,
    int hashtable_capacity,
    bool use_cuda
);

// Weighted LSH cumulation, variant 4 (key-centric, all hash functions per block).
at::Tensor lsh_weighted_cumulation_ver4_kernel(
    at::Tensor query_mask,
    at::Tensor query_hash_code,
    at::Tensor query_weight,
    at::Tensor key_mask,
    at::Tensor key_hash_code,
    at::Tensor key_weight,
    at::Tensor value,
    int hashtable_capacity,
    bool use_cuda
);
.venv/lib/python3.11/site-packages/transformers/kernels/yoso/fast_lsh_cumulation_cuda.cu ADDED
@@ -0,0 +1,825 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // File from https://github.com/mlpen/YOSO/blob/main/encoders/backbones/efficient_attentions/yoso/yoso_v1/cuda/fast_lsh_cumulation_cuda.cu
2
+
3
+ #include "fast_lsh_cumulation_cuda.h"
4
+ #include "common_cuda_device.h"
5
+ #include "common_cuda.h"
6
+ #include "common.h"
7
+ #include <stdio.h>
8
+ //////////////////////////////////////////////////////////////////////////////////////////////////
9
+ //////////////////////////////////////////////////////////////////////////////////////////////////
10
+
11
// In-place fast Walsh-Hadamard transform of a shared-memory vector of length
// vector_dim. Each thread owns element dim_idx. Assumes vector_dim is a
// power of two and >= WARP_SIZE -- confirm at call sites. Butterfly stages
// wider than a warp go through shared memory with barriers; the final
// log2(WARP_SIZE) stages stay in registers via warp shuffles.
inline __device__ void fast_hadamard_transform(float *vector_buffer, int vector_dim, int dim_idx) {
    int stride = vector_dim / 2;
    while (stride > (WARP_SIZE / 2)) {
        __syncthreads();
        // sign is +1 for the upper element of the butterfly pair, -1 for the lower
        int sign = 1 - ((dim_idx / stride) % 2) * 2;
        float val1 = vector_buffer[dim_idx];
        float val2 = vector_buffer[dim_idx + sign * stride];
        __syncthreads();  // both reads must complete before any thread overwrites
        vector_buffer[dim_idx] = float(sign) * val1 + val2;
        stride = stride / 2;
    }

    // Remaining stages fit inside one warp: exchange partners via register
    // shuffles instead of shared memory (no barrier needed within a warp).
    float val = vector_buffer[dim_idx];
    #pragma unroll
    for (stride = (WARP_SIZE / 2); stride > 0; stride = stride / 2) {
        int sign = 1 - ((dim_idx / stride) % 2) * 2;
        val = float(sign) * val + __shfl_xor_sync(FULL_MASK, val, stride);
    }
    vector_buffer[dim_idx] = val;
}
31
+
32
// One block hashes one (batch, vector, part) triple: the vector is staged in
// shared memory, multiplied by three random +/-1 diagonals (Dmat) with a fast
// Hadamard transform after each (the structured random projection H*D3*H*D2*H*D1),
// and the signs of the transformed coordinates are packed into hash codes of
// hash_code_len bits each.
__global__ void fast_hash_ver1_cuda_kernel(
    int *mask,        // [batch_size, num_vector]
    float *vector,    // [batch_size, num_vector, vector_dim]
    int *Dmat,        // [batch_size, 3, num_part, vector_dim]
    int *hash_code,   // [batch_size, num_vector, num_hash_f]
    int batch_size,
    int num_vector,
    int vector_dim,
    int num_part,
    int num_hash_f,
    int hash_code_len
) {

    int batch_idx = blockIdx.z;
    int vector_idx = blockIdx.y;
    int part_idx = blockIdx.x;

    int dim_idx = threadIdx.x;  // one thread per vector dimension

    int batch_idx__vector_idx = batch_idx * num_vector + vector_idx;
    if (mask[batch_idx__vector_idx] == 0) {
        return;  // padding vector: no hash code written
    }

    extern __shared__ float buffer[];
    float *vector_buffer = buffer;

    vector_buffer[dim_idx] = vector[batch_idx__vector_idx * vector_dim + dim_idx];

    // Three rounds of (random sign flip by Dmat) followed by a Hadamard transform.
    vector_buffer[dim_idx] = vector_buffer[dim_idx] * (float)Dmat[((batch_idx * 3 + 0) * num_part + part_idx) * vector_dim + dim_idx];
    fast_hadamard_transform(vector_buffer, vector_dim, dim_idx);
    vector_buffer[dim_idx] = vector_buffer[dim_idx] * (float)Dmat[((batch_idx * 3 + 1) * num_part + part_idx) * vector_dim + dim_idx];
    fast_hadamard_transform(vector_buffer, vector_dim, dim_idx);
    vector_buffer[dim_idx] = vector_buffer[dim_idx] * (float)Dmat[((batch_idx * 3 + 2) * num_part + part_idx) * vector_dim + dim_idx];
    fast_hadamard_transform(vector_buffer, vector_dim, dim_idx);

    int num_hash_per_part = vector_dim / hash_code_len;
    if (hash_code_len == 8 || hash_code_len == 16) {
        // Fast path: a whole hash code fits inside a warp, so the per-dimension
        // sign bits can be combined with warp shuffles. select() yields the bit
        // when the coordinate is positive, 0 otherwise; summing disjoint bits
        // is equivalent to bitwise OR.
        int code = select(vector_buffer[dim_idx] > 0, 1 << (dim_idx % hash_code_len), 0);
        for (int offset = 1; offset < hash_code_len; offset = offset * 2) {
            code += __shfl_xor_sync(FULL_MASK, code, offset);
        }
        if (dim_idx % hash_code_len == 0) {
            // The first lane of each hash_code_len group writes the assembled code.
            int hash_f_idx = part_idx * num_hash_per_part + dim_idx / hash_code_len;
            if (hash_f_idx < num_hash_f) {
                hash_code[batch_idx__vector_idx * num_hash_f + hash_f_idx] = code;
            }
        }
    } else {
        // General path: stash each dimension's bit in shared memory, then one
        // thread per hash function sums its hash_code_len bits.
        vector_buffer[dim_idx] = select(vector_buffer[dim_idx] > 0, 1 << (dim_idx % hash_code_len), 0);
        __syncthreads();
        if (dim_idx < num_hash_per_part) {
            int code = 0;
            for (int i = 0; i < hash_code_len; i++) {
                code += vector_buffer[dim_idx * hash_code_len + i];
            }
            int hash_f_idx = part_idx * num_hash_per_part + dim_idx;
            if (hash_f_idx < num_hash_f) {
                hash_code[batch_idx__vector_idx * num_hash_f + hash_f_idx] = code;
            }
        }
    }
}
95
+
96
// Step 1 of unweighted LSH cumulation: scatter each key's value slice into a
// per-hash-function hashtable. One warp handles one key; each lane owns one
// value dimension of the WARP_SIZE-wide slice starting at offset_warp (the
// host presumably launches once per slice -- launch code not in view).
__global__ void lsh_cumulation_ver1_step1_cuda_kernel(
    int *key_mask,           // [batch_size, num_key]
    int *key_hash_code,      // [batch_size, num_key, num_hash_f]
    float *value,            // [batch_size, num_key, value_dim]
    float *hashtable_value,  // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE]
    int batch_size,
    int num_hash_f,
    int hashtable_capacity,
    int num_key,
    int value_dim,
    int offset_warp
) {

    int warp_thread_idx = threadIdx.x;  // lane id == value-dim offset within the slice

    int batch_idx = blockIdx.y;
    int key_idx = blockIdx.x * blockDim.y + threadIdx.y;

    int batch_idx__key_idx = batch_idx * num_key + key_idx;
    if (key_mask[batch_idx__key_idx] == 0) {
        return;  // masked (padding) key contributes nothing
    }

    if (num_hash_f > WARP_SIZE) {
        // More hash functions than lanes: read hash codes WARP_SIZE at a time,
        // then broadcast each lane's code to the whole warp with __shfl_sync so
        // every lane can deposit its value dimension into that code's bucket.
        float warp_value = value[batch_idx__key_idx * value_dim + offset_warp + warp_thread_idx];
        for (int hash_f_start = 0; hash_f_start < num_hash_f; hash_f_start = hash_f_start + WARP_SIZE) {
            int warp_hashcode = key_hash_code[batch_idx__key_idx * num_hash_f + hash_f_start + warp_thread_idx];
            #pragma unroll
            for (int hash_f_offset = 0; hash_f_offset < WARP_SIZE; hash_f_offset++) {
                int current_hashcode = warp_hashcode;
                current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_offset);
                int hashtable_idx = (batch_idx * num_hash_f + (hash_f_start + hash_f_offset)) * hashtable_capacity + current_hashcode;
                // atomicAdd: multiple keys may hash to the same bucket.
                atomicAdd(&hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx], warp_value);
            }
        }
    } else {
        // All hash codes fit in a single warp-wide read; lanes past num_hash_f
        // carry a dummy 0 that is never broadcast (the loop stops at num_hash_f).
        float warp_value = value[batch_idx__key_idx * value_dim + offset_warp + warp_thread_idx];
        int warp_hashcode = 0;
        if (warp_thread_idx < num_hash_f) {
            warp_hashcode = key_hash_code[batch_idx__key_idx * num_hash_f + warp_thread_idx];
        }
        for (int hash_f_idx = 0; hash_f_idx < num_hash_f; hash_f_idx++) {
            int current_hashcode = warp_hashcode;
            current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_idx);
            int hashtable_idx = (batch_idx * num_hash_f + hash_f_idx) * hashtable_capacity + current_hashcode;
            atomicAdd(&hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx], warp_value);
        }
    }

}
146
+
147
// Step 2 of unweighted LSH cumulation: for each query, gather the bucket it
// hashes into under every hash function and average the accumulated values.
// Mirrors step 1's layout: one warp per query, one lane per value dimension
// of the slice at offset_warp.
__global__ void lsh_cumulation_ver1_step2_cuda_kernel(
    int *query_mask,         // [batch_size, num_query]
    int *query_hash_code,    // [batch_size, num_query, num_hash_f]
    float *hashtable_value,  // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE]
    float *cumulation_value, // [batch_size, num_query, value_dim]
    int batch_size,
    int num_hash_f,
    int hashtable_capacity,
    int num_query,
    int value_dim,
    int offset_warp
) {

    int warp_thread_idx = threadIdx.x;  // lane id == value-dim offset within the slice

    int batch_idx = blockIdx.y;
    int query_idx = blockIdx.x * blockDim.y + threadIdx.y;

    int batch_idx__query_idx = batch_idx * num_query + query_idx;
    if (query_mask[batch_idx__query_idx] == 0) {
        return;  // masked query: output slice left untouched
    }

    if (num_hash_f > WARP_SIZE) {
        // Same broadcast pattern as step 1, but reading from the hashtable
        // instead of writing: accumulate the bucket row for every hash function.
        float warp_value = 0;
        for (int hash_f_start = 0; hash_f_start < num_hash_f; hash_f_start = hash_f_start + WARP_SIZE) {
            int warp_hashcode = query_hash_code[batch_idx__query_idx * num_hash_f + hash_f_start + warp_thread_idx];
            #pragma unroll
            for (int hash_f_offset = 0; hash_f_offset < WARP_SIZE; hash_f_offset++) {
                int current_hashcode = warp_hashcode;
                current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_offset);
                int hashtable_idx = (batch_idx * num_hash_f + (hash_f_start + hash_f_offset)) * hashtable_capacity + current_hashcode;
                warp_value = warp_value + hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx];
            }
        }
        // Average across hash functions (LSH estimate).
        cumulation_value[batch_idx__query_idx * value_dim + offset_warp + warp_thread_idx] = warp_value / float(num_hash_f);
    } else {
        float warp_value = 0;
        int warp_hashcode = 0;
        if (warp_thread_idx < num_hash_f) {
            warp_hashcode = query_hash_code[batch_idx__query_idx * num_hash_f + warp_thread_idx];
        }
        for (int hash_f_idx = 0; hash_f_idx < num_hash_f; hash_f_idx++) {
            int current_hashcode = warp_hashcode;
            current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_idx);
            int hashtable_idx = (batch_idx * num_hash_f + hash_f_idx) * hashtable_capacity + current_hashcode;
            warp_value = warp_value + hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx];
        }
        cumulation_value[batch_idx__query_idx * value_dim + offset_warp + warp_thread_idx] = warp_value / float(num_hash_f);
    }

}
199
+
200
// Weighted variant of step 1: identical scatter pattern to
// lsh_cumulation_ver1_step1_cuda_kernel, except each key's value slice is
// pre-scaled by one component of its weight vector (key_weight[..., weight_idx]).
// The host presumably launches once per (offset_warp, weight_idx) pair --
// launch code not in view.
__global__ void lsh_weighted_cumulation_ver1_step1_cuda_kernel(
    int *key_mask,           // [batch_size, num_key]
    int *key_hash_code,      // [batch_size, num_key, num_hash_f]
    float *key_weight,       // [batch_size, num_key, weight_dim]
    float *value,            // [batch_size, num_key, value_dim]
    float *hashtable_value,  // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE]
    int batch_size,
    int num_hash_f,
    int hashtable_capacity,
    int num_key,
    int value_dim,
    int weight_dim,
    int offset_warp,
    int weight_idx
) {

    int warp_thread_idx = threadIdx.x;  // lane id == value-dim offset within the slice

    int batch_idx = blockIdx.y;
    int key_idx = blockIdx.x * blockDim.y + threadIdx.y;

    int batch_idx__key_idx = batch_idx * num_key + key_idx;
    if (key_mask[batch_idx__key_idx] == 0) {
        return;  // masked key contributes nothing
    }

    if (num_hash_f > WARP_SIZE) {
        // Weighted contribution: key_weight component times the value element.
        float warp_value = key_weight[batch_idx__key_idx * weight_dim + weight_idx] * value[batch_idx__key_idx * value_dim + offset_warp + warp_thread_idx];
        for (int hash_f_start = 0; hash_f_start < num_hash_f; hash_f_start = hash_f_start + WARP_SIZE) {
            int warp_hashcode = key_hash_code[batch_idx__key_idx * num_hash_f + hash_f_start + warp_thread_idx];
            #pragma unroll
            for (int hash_f_offset = 0; hash_f_offset < WARP_SIZE; hash_f_offset++) {
                int current_hashcode = warp_hashcode;
                // Broadcast lane hash_f_offset's code so all lanes write the same bucket row.
                current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_offset);
                int hashtable_idx = (batch_idx * num_hash_f + (hash_f_start + hash_f_offset)) * hashtable_capacity + current_hashcode;
                atomicAdd(&hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx], warp_value);
            }
        }
    } else {
        float warp_value = key_weight[batch_idx__key_idx * weight_dim + weight_idx] * value[batch_idx__key_idx * value_dim + offset_warp + warp_thread_idx];
        int warp_hashcode = 0;
        if (warp_thread_idx < num_hash_f) {
            warp_hashcode = key_hash_code[batch_idx__key_idx * num_hash_f + warp_thread_idx];
        }
        for (int hash_f_idx = 0; hash_f_idx < num_hash_f; hash_f_idx++) {
            int current_hashcode = warp_hashcode;
            current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_idx);
            int hashtable_idx = (batch_idx * num_hash_f + hash_f_idx) * hashtable_capacity + current_hashcode;
            atomicAdd(&hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx], warp_value);
        }
    }

}
253
+
254
// Weighted variant of step 2: gather the averaged bucket values for each
// query and accumulate (+=) them scaled by the query's weight component
// query_weight[..., weight_idx]. Accumulating across launches composes the
// per-weight-index contributions into the full weighted sum -- assumes
// cumulation_value was zero-initialized by the host (confirm at call site).
__global__ void lsh_weighted_cumulation_ver1_step2_cuda_kernel(
    int *query_mask,         // [batch_size, num_query]
    int *query_hash_code,    // [batch_size, num_query, num_hash_f]
    float *query_weight,     // [batch_size, num_query, weight_dim]
    float *hashtable_value,  // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE]
    float *cumulation_value, // [batch_size, num_query, value_dim]
    int batch_size,
    int num_hash_f,
    int hashtable_capacity,
    int num_query,
    int value_dim,
    int weight_dim,
    int offset_warp,
    int weight_idx
) {

    int warp_thread_idx = threadIdx.x;  // lane id == value-dim offset within the slice

    int batch_idx = blockIdx.y;
    int query_idx = blockIdx.x * blockDim.y + threadIdx.y;

    int batch_idx__query_idx = batch_idx * num_query + query_idx;
    if (query_mask[batch_idx__query_idx] == 0) {
        return;  // masked query: output left untouched
    }

    if (num_hash_f > WARP_SIZE) {
        // Broadcast-and-gather over all hash functions, WARP_SIZE codes at a time.
        float warp_value = 0;
        for (int hash_f_start = 0; hash_f_start < num_hash_f; hash_f_start = hash_f_start + WARP_SIZE) {
            int warp_hashcode = query_hash_code[batch_idx__query_idx * num_hash_f + hash_f_start + warp_thread_idx];
            #pragma unroll
            for (int hash_f_offset = 0; hash_f_offset < WARP_SIZE; hash_f_offset++) {
                int current_hashcode = warp_hashcode;
                current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_offset);
                int hashtable_idx = (batch_idx * num_hash_f + (hash_f_start + hash_f_offset)) * hashtable_capacity + current_hashcode;
                warp_value = warp_value + hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx];
            }
        }
        float warp_weight = query_weight[batch_idx__query_idx * weight_dim + weight_idx];
        // += : contributions from other weight_idx launches are accumulated.
        cumulation_value[batch_idx__query_idx * value_dim + offset_warp + warp_thread_idx] += warp_weight * warp_value / float(num_hash_f);
    } else {
        float warp_value = 0;
        int warp_hashcode = 0;
        if (warp_thread_idx < num_hash_f) {
            warp_hashcode = query_hash_code[batch_idx__query_idx * num_hash_f + warp_thread_idx];
        }
        for (int hash_f_idx = 0; hash_f_idx < num_hash_f; hash_f_idx++) {
            int current_hashcode = warp_hashcode;
            current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_idx);
            int hashtable_idx = (batch_idx * num_hash_f + hash_f_idx) * hashtable_capacity + current_hashcode;
            warp_value = warp_value + hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx];
        }
        float warp_weight = query_weight[batch_idx__query_idx * weight_dim + weight_idx];
        cumulation_value[batch_idx__query_idx * value_dim + offset_warp + warp_thread_idx] += warp_weight * warp_value / float(num_hash_f);
    }

}
311
+
312
// Counting-sort, histogram pass: for every unmasked (key, hash function)
// pair, bump the count of the bucket that key hashes into. Buckets are
// shared across keys, so the increment must be atomic.
__global__ void count_sort_step1_cuda_kernel(
    int *key_mask,         // [batch_size, num_key]
    int *key_hash_code,    // [batch_size, num_key, num_hash_f]
    int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity]
    int batch_size,
    int num_hash_f,
    int hashtable_capacity,
    int num_key
) {
    // Grid layout: y = batch, x covers keys; threadIdx.x = hash function.
    const int b = blockIdx.y;
    const int key = blockIdx.x * blockDim.y + threadIdx.y;
    const int h = threadIdx.x;

    const int key_flat = b * num_key + key;
    if (key_mask[key_flat] == 0) {
        return;  // padding key: not counted
    }

    const int bucket = key_hash_code[key_flat * num_hash_f + h];
    atomicAdd(&count_sort_table[(b * num_hash_f + h) * hashtable_capacity + bucket], 1);
}
335
+
336
// Counting-sort, scan pass: convert each (batch, hash function) histogram
// row of count_sort_table into its exclusive prefix sum, in place. One block
// scans one row; dynamic shared memory holds the row (hashtable_capacity ints).
// NOTE(review): the loops assume hashtable_capacity is a multiple of the
// block width / WARP_SIZE -- confirm at launch sites.
__global__ void count_sort_step2_cuda_kernel(
    int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity]
    int batch_size,
    int num_hash_f,
    int hashtable_capacity
) {

    int batch_idx = blockIdx.y;
    int hash_f_idx = blockIdx.x;

    int num_threads = blockDim.x;
    int thread_id = threadIdx.x;

    int batch_idx__hash_f_idx = batch_idx * num_hash_f + hash_f_idx;

    extern __shared__ float buffer[];
    int *table_buffer = (int*)buffer;

    // Stage the row shifted right by one with a leading zero, so an
    // INCLUSIVE scan of the buffer equals the EXCLUSIVE prefix sum of the
    // original counts.
    if (thread_id == 0) {
        table_buffer[0] = 0;
    }
    copy_data<int>(&count_sort_table[batch_idx__hash_f_idx * hashtable_capacity], &table_buffer[1], hashtable_capacity - 1, num_threads, thread_id);

    // Warp-level inclusive scan (Hillis-Steele via __shfl_up_sync) of each
    // WARP_SIZE-aligned segment of the row.
    for (int table_idx_start = 0; table_idx_start < hashtable_capacity; table_idx_start = table_idx_start + num_threads) {
        int thread_value = table_buffer[table_idx_start + thread_id];
        int next_thread_value = 0;
        for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) {
            next_thread_value = __shfl_up_sync(FULL_MASK, thread_value, offset);
            if (thread_id % WARP_SIZE >= offset) {
                thread_value = thread_value + next_thread_value;
            }
        }
        table_buffer[table_idx_start + thread_id] = thread_value;
    }
    __syncthreads();

    // Cross-segment fix-up: warp 0 serially adds each segment's running
    // total (last element of the previous segment) to the next segment,
    // completing the scan over the full row.
    if (hashtable_capacity > WARP_SIZE) {
        if (thread_id < WARP_SIZE) {
            for (int table_idx_start = WARP_SIZE; table_idx_start < hashtable_capacity; table_idx_start = table_idx_start + WARP_SIZE) {
                table_buffer[table_idx_start + thread_id] += table_buffer[table_idx_start - 1];
            }
        }
    }

    // copy_data (blocking variant) writes the scanned row back to global memory.
    copy_data<int>(table_buffer, &count_sort_table[batch_idx__hash_f_idx * hashtable_capacity], hashtable_capacity, num_threads, thread_id);

}
383
+
384
+
385
// Counting-sort, scatter pass: using the exclusive prefix sums from step 2
// as running write cursors, place each unmasked key's index into its hash
// bucket's region of key_sorted_idxes. atomicAdd both claims a slot and
// advances the cursor, so after this pass each table entry holds its
// bucket's END offset.
__global__ void count_sort_step3_cuda_kernel(
    int *key_mask,         // [batch_size, num_key]
    int *key_hash_code,    // [batch_size, num_key, num_hash_f]
    int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity]
    int *key_sorted_idxes, // [batch_size, num_hash_f, num_key]
    int batch_size,
    int num_hash_f,
    int hashtable_capacity,
    int num_key
) {
    // Grid layout mirrors step 1: y = batch, x covers keys; threadIdx.x = hash function.
    const int b = blockIdx.y;
    const int key = blockIdx.x * blockDim.y + threadIdx.y;
    const int h = threadIdx.x;

    const int key_flat = b * num_key + key;
    if (key_mask[key_flat] == 0) {
        return;  // padding keys never appear in the sorted order
    }

    const int row = b * num_hash_f + h;
    const int bucket = key_hash_code[key_flat * num_hash_f + h];
    const int slot = atomicAdd(&count_sort_table[row * hashtable_capacity + bucket], 1);
    key_sorted_idxes[row * num_key + slot] = key;
}
412
+
413
// For each unmasked (query, hash function) pair, look up where the query's
// bucket begins in the sorted key array and how many keys it contains, and
// store the (offset, count) pair into query_info. After count_sort_step3,
// count_sort_table[bucket] is the bucket's end offset, so the start is the
// previous bucket's entry (or 0 for bucket 0).
__global__ void extract_query_info_cuda_kernel(
    int *query_mask,       // [batch_size, num_query]
    int *query_hash_code,  // [batch_size, num_query, num_hash_f]
    int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity]
    int *query_info,       // [batch_size, num_query, 2, num_hash_f]
    int batch_size,
    int num_hash_f,
    int hashtable_capacity,
    int num_query
) {
    const int b = blockIdx.y;
    const int q = blockIdx.x * blockDim.y + threadIdx.y;
    const int h = threadIdx.x;

    const int query_flat = b * num_query + q;
    if (query_mask[query_flat] == 0) {
        return;  // masked query: info entries left untouched
    }

    const int bucket = query_hash_code[query_flat * num_hash_f + h];
    const int table_pos = (b * num_hash_f + h) * hashtable_capacity + bucket;

    const int key_offset = select(bucket == 0, 0, count_sort_table[table_pos - 1]);
    const int key_count = count_sort_table[table_pos] - key_offset;

    // Layout: [..., 0, h] = start offset, [..., 1, h] = count.
    query_info[query_flat * 2 * num_hash_f + h] = key_offset;
    query_info[(query_flat * 2 + 1) * num_hash_f + h] = key_count;
}
443
+
444
// Variant-2 cumulation, query-centric: one block per (query, hash function,
// batch). The block walks every key that shares the query's bucket (located
// via query_info offset/count into key_sorted_idxes), computes the weight
// dot-product <query_weight, key_weight> / num_hash_f with a warp reduction,
// and atomically accumulates weight * value into the query's output row.
// Dynamic shared memory layout: [weight_dim floats | WARP_SIZE ints].
__global__ void lsh_weighted_cumulation_ver2_step2_cuda_kernel(
    int *query_mask,         // [batch_size, num_query]
    int *query_info,         // [batch_size, num_query, 2, num_hash_f]
    int *key_sorted_idxes,   // [batch_size, num_hash_f, num_key]
    float *query_weight,     // [batch_size, num_query, weight_dim]
    float *key_weight,       // [batch_size, num_key, weight_dim]
    float *value,            // [batch_size, num_key, value_dim]
    float *cumulation_value, // [batch_size, num_query, value_dim]
    int batch_size,
    int num_hash_f,
    int num_query,
    int num_key,
    int value_dim,
    int weight_dim
) {

    int batch_idx = blockIdx.z;
    int hash_f_idx = blockIdx.y;
    int query_idx = blockIdx.x;

    int num_threads = blockDim.y * blockDim.x;
    int thread_id = threadIdx.y * blockDim.x + threadIdx.x;

    // Block is organized as num_warps warps of WARP_SIZE lanes; each warp
    // processes one candidate key at a time.
    int num_warps = blockDim.y;
    int warp_idx = threadIdx.y;
    int warp_thread_idx = threadIdx.x;

    int batch_idx__query_idx = batch_idx * num_query + query_idx;
    if (query_mask[batch_idx__query_idx] == 0) {
        return;  // masked query: nothing accumulated
    }

    int key_offset = query_info[batch_idx__query_idx * 2 * num_hash_f + hash_f_idx];
    int key_count = query_info[(batch_idx__query_idx * 2 + 1) * num_hash_f + hash_f_idx];

    if (key_count == 0) {
        return;  // empty bucket under this hash function
    }

    extern __shared__ float buffer[];

    if (key_count == 1) {
        // Fast path: a single colliding key -- warp 0 handles it directly
        // from global memory, skipping the shared-memory staging.
        if (warp_idx == 0) {
            int key_idx = key_sorted_idxes[(batch_idx * num_hash_f + hash_f_idx) * num_key + key_offset];
            int batch_idx__key_idx = batch_idx * num_key + key_idx;
            float weight = 0;
            // Warp-reduced dot product over weight_dim, WARP_SIZE lanes at a time.
            for (int weight_offset = 0; weight_offset < weight_dim; weight_offset = weight_offset + WARP_SIZE) {
                int weight_dim_idx = weight_offset + warp_thread_idx;
                float val = query_weight[batch_idx__query_idx * weight_dim + weight_dim_idx] * key_weight[batch_idx__key_idx * weight_dim + weight_dim_idx];
                #pragma unroll
                for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) {
                    val += __shfl_xor_sync(FULL_MASK, val, offset);
                }
                weight = weight + val;
            }
            weight = weight / float(num_hash_f);
            for (int value_offset = 0; value_offset < value_dim; value_offset = value_offset + WARP_SIZE) {
                int value_dim_idx = value_offset + warp_thread_idx;
                float val = value[batch_idx__key_idx * value_dim + value_dim_idx];
                // atomicAdd: other (hash_f, block) combinations write the same row.
                atomicAdd(&cumulation_value[batch_idx__query_idx * value_dim + value_dim_idx], weight * val);
            }
        }
    } else {
        // General path: stage the query's weights once, then process keys in
        // tiles of up to WARP_SIZE, one key per warp per round.
        float *weight_buffer = buffer;
        int *key_idxes_buffer = (int*)&buffer[weight_dim];

        copy_data_nonblocking<float>(&query_weight[batch_idx__query_idx * weight_dim], weight_buffer, weight_dim, num_threads, thread_id);

        while (key_count > 0) {
            int work_size = min(WARP_SIZE, key_count);
            copy_data_nonblocking<int>(&key_sorted_idxes[(batch_idx * num_hash_f + hash_f_idx) * num_key + key_offset], key_idxes_buffer, work_size, num_threads, thread_id);
            __syncthreads();  // staging must complete before warps consume the tile
            for (int work_offset = 0; work_offset < WARP_SIZE; work_offset = work_offset + num_warps) {
                int work_idx = work_offset + warp_idx;
                if (work_idx < key_count) {
                    int key_idx = key_idxes_buffer[work_idx];
                    int batch_idx__key_idx = batch_idx * num_key + key_idx;
                    float weight = 0;
                    for (int weight_offset = 0; weight_offset < weight_dim; weight_offset = weight_offset + WARP_SIZE) {
                        int weight_dim_idx = weight_offset + warp_thread_idx;
                        float val = weight_buffer[weight_dim_idx] * key_weight[batch_idx__key_idx * weight_dim + weight_dim_idx];
                        #pragma unroll
                        for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) {
                            val += __shfl_xor_sync(FULL_MASK, val, offset);
                        }
                        weight = weight + val;
                    }
                    weight = weight / float(num_hash_f);
                    for (int value_offset = 0; value_offset < value_dim; value_offset = value_offset + WARP_SIZE) {
                        int value_dim_idx = value_offset + warp_thread_idx;
                        float val = value[batch_idx__key_idx * value_dim + value_dim_idx];
                        atomicAdd(&cumulation_value[batch_idx__query_idx * value_dim + value_dim_idx], weight * val);
                    }
                }
            }
            // Advance to the next tile of colliding keys.
            key_count = key_count - work_size;
            key_offset = key_offset + work_size;
        }
    }

}
545
+
546
// Variant-3 cumulation, key-centric mirror of variant 2: one block per
// (key, hash function, batch). The roles of query and key are swapped --
// queries are the sorted/bucketed side (query_sorted_idxes, key_info) and
// the block pushes this key's weighted value into every matching query's
// output row. Additionally stages this key's value vector in shared memory,
// since it is reused for every matching query.
// Dynamic shared memory layout: [weight_dim floats | value_dim floats | WARP_SIZE ints].
__global__ void lsh_weighted_cumulation_ver3_step2_cuda_kernel(
    int *query_sorted_idxes, // [batch_size, num_hash_f, num_query]
    int *key_mask,           // [batch_size, num_key]
    int *key_info,           // [batch_size, num_key, 2, num_hash_f]
    float *query_weight,     // [batch_size, num_query, weight_dim]
    float *key_weight,       // [batch_size, num_key, weight_dim]
    float *value,            // [batch_size, num_key, value_dim]
    float *cumulation_value, // [batch_size, num_query, value_dim]
    int batch_size,
    int num_hash_f,
    int num_query,
    int num_key,
    int value_dim,
    int weight_dim
) {

    int batch_idx = blockIdx.z;
    int hash_f_idx = blockIdx.y;
    int key_idx = blockIdx.x;

    int num_threads = blockDim.y * blockDim.x;
    int thread_id = threadIdx.y * blockDim.x + threadIdx.x;

    // num_warps warps of WARP_SIZE lanes; each warp handles one matching query.
    int num_warps = blockDim.y;
    int warp_idx = threadIdx.y;
    int warp_thread_idx = threadIdx.x;

    int batch_idx__key_idx = batch_idx * num_key + key_idx;
    if (key_mask[batch_idx__key_idx] == 0) {
        return;  // masked key contributes nothing
    }

    int query_offset = key_info[batch_idx__key_idx * 2 * num_hash_f + hash_f_idx];
    int query_count = key_info[(batch_idx__key_idx * 2 + 1) * num_hash_f + hash_f_idx];

    if (query_count == 0) {
        return;  // no queries share this key's bucket under this hash function
    }

    extern __shared__ float buffer[];

    if (query_count == 1) {
        // Fast path: single matching query -- warp 0 works from global memory.
        if (warp_idx == 0) {
            int query_idx = query_sorted_idxes[(batch_idx * num_hash_f + hash_f_idx) * num_query + query_offset];
            int batch_idx__query_idx = batch_idx * num_query + query_idx;
            float weight = 0;
            // Warp-reduced dot product <key_weight, query_weight>.
            for (int weight_offset = 0; weight_offset < weight_dim; weight_offset = weight_offset + WARP_SIZE) {
                int weight_dim_idx = weight_offset + warp_thread_idx;
                float val = key_weight[batch_idx__key_idx * weight_dim + weight_dim_idx] * query_weight[batch_idx__query_idx * weight_dim + weight_dim_idx];
                #pragma unroll
                for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) {
                    val += __shfl_xor_sync(FULL_MASK, val, offset);
                }
                weight = weight + val;
            }
            weight = weight / float(num_hash_f);
            for (int value_offset = 0; value_offset < value_dim; value_offset = value_offset + WARP_SIZE) {
                int value_dim_idx = value_offset + warp_thread_idx;
                float val = value[batch_idx__key_idx * value_dim + value_dim_idx];
                atomicAdd(&cumulation_value[batch_idx__query_idx * value_dim + value_dim_idx], weight * val);
            }
        }
    } else {
        // General path: stage this key's weights AND values once (both reused
        // for every matching query), then process queries in WARP_SIZE tiles.
        float *weight_buffer = buffer;
        float *value_buffer = &buffer[weight_dim];
        int *query_idxes_buffer = (int*)&buffer[weight_dim + value_dim];

        copy_data_nonblocking<float>(&key_weight[batch_idx__key_idx * weight_dim], weight_buffer, weight_dim, num_threads, thread_id);
        copy_data_nonblocking<float>(&value[batch_idx__key_idx * value_dim], value_buffer, value_dim, num_threads, thread_id);

        while (query_count > 0) {
            int work_size = min(WARP_SIZE, query_count);
            copy_data_nonblocking<int>(&query_sorted_idxes[(batch_idx * num_hash_f + hash_f_idx) * num_query + query_offset], query_idxes_buffer, work_size, num_threads, thread_id);
            __syncthreads();  // staging must complete before warps consume the tile
            for (int work_offset = 0; work_offset < WARP_SIZE; work_offset = work_offset + num_warps) {
                int work_idx = work_offset + warp_idx;
                if (work_idx < query_count) {
                    int query_idx = query_idxes_buffer[work_idx];
                    int batch_idx__query_idx = batch_idx * num_query + query_idx;
                    float weight = 0;
                    for (int weight_offset = 0; weight_offset < weight_dim; weight_offset = weight_offset + WARP_SIZE) {
                        int weight_dim_idx = weight_offset + warp_thread_idx;
                        float val = weight_buffer[weight_dim_idx] * query_weight[batch_idx__query_idx * weight_dim + weight_dim_idx];
                        #pragma unroll
                        for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) {
                            val += __shfl_xor_sync(FULL_MASK, val, offset);
                        }
                        weight = weight + val;
                    }
                    weight = weight / float(num_hash_f);
                    for (int value_offset = 0; value_offset < value_dim; value_offset = value_offset + WARP_SIZE) {
                        int value_dim_idx = value_offset + warp_thread_idx;
                        float val = value_buffer[value_dim_idx];
                        atomicAdd(&cumulation_value[batch_idx__query_idx * value_dim + value_dim_idx], weight * val);
                    }
                }
            }
            // Advance to the next tile of matching queries.
            query_count = query_count - work_size;
            query_offset = query_offset + work_size;
        }
    }

}
649
+
650
// Step 2 of the version-4 weighted LSH cumulation: one thread block per
// (batch, key) pair. For every hash function under which this key collides
// with queries, the block walks those queries, de-duplicates them through a
// shared-memory hash set (counting how often each distinct query collided),
// and atomically accumulates duplicate_count-scaled weight * value into
// cumulation_value. Work is flushed in passes whenever the set fills up.
__global__ void lsh_weighted_cumulation_ver4_step2_cuda_kernel(
  int *query_sorted_idxes,       // [batch_size, num_hash_f, num_query]
  int *key_mask,                 // [batch_size, num_key]
  int *key_info,                 // [batch_size, num_key, 2, num_hash_f]
  float *query_weight,           // [batch_size, num_query, weight_dim]
  float *key_weight,             // [batch_size, num_key, weight_dim]
  float *value,                  // [batch_size, num_key, value_dim]
  float *cumulation_value,       // [batch_size, num_query, value_dim]
  int batch_size,
  int num_hash_f,
  int num_query,
  int num_key,
  int value_dim,
  int weight_dim
) {

  // Grid mapping: blockIdx.y selects the batch, blockIdx.x selects the key.
  int batch_idx = blockIdx.y;
  int key_idx = blockIdx.x;

  int num_threads = blockDim.y * blockDim.x;
  int thread_id = threadIdx.y * blockDim.x + threadIdx.x;

  // The block is organized as blockDim.y warps of blockDim.x lanes each.
  int num_warps = blockDim.y;
  int warp_idx = threadIdx.y;
  int warp_thread_idx = threadIdx.x;

  // Masked-out keys contribute nothing. The whole block takes this branch
  // together and it precedes every __syncthreads(), so the early return is safe.
  int batch_idx__key_idx = batch_idx * num_key + key_idx;
  if (key_mask[batch_idx__key_idx] == 0) {
    return;
  }

  // Dynamic shared memory layout:
  //   [0, weight_dim)                      -> this key's weight vector
  //   [weight_dim, weight_dim + value_dim) -> this key's value vector
  //   after that                           -> key_info (2 * num_hash_f ints)
  extern __shared__ float buffer[];
  float *weight_buffer = buffer;
  float *value_buffer = &buffer[weight_dim];
  int *key_info_buffer = (int*)&buffer[weight_dim + value_dim];

  copy_data_nonblocking<float>(&key_weight[batch_idx__key_idx * weight_dim], weight_buffer, weight_dim, num_threads, thread_id);
  copy_data_nonblocking<float>(&value[batch_idx__key_idx * value_dim], value_buffer, value_dim, num_threads, thread_id);
  copy_data_nonblocking<int>(&key_info[batch_idx__key_idx * 2 * num_hash_f], key_info_buffer, 2 * num_hash_f, num_threads, thread_id);

  // key_info stores, per hash function, [offset, count]: where this key's
  // colliding queries start inside query_sorted_idxes and how many there are.
  int *query_offset_buffer = key_info_buffer;
  int *query_count_buffer = &key_info_buffer[num_hash_f];

  // Open-addressing hash set used to de-duplicate query indices, a parallel
  // per-slot collision counter, and a compact list of distinct queries found
  // in the current pass.
  const int hashtable_size = 1024 + OPTIMAL_THREADS_PER_BLOCK;
  __shared__ int hashtable_query[hashtable_size];
  __shared__ int hashtable_count[hashtable_size];
  __shared__ int inserted_query[hashtable_size];
  __shared__ int query_counter[1];  // number of distinct queries collected so far

  int hash_f_idx_base = 0;

  // Outer loop: one pass per iteration. Each pass fills the hash set until it
  // is nearly full (or all hash functions are exhausted), flushes the
  // accumulated contributions, then repeats with a cleared set.
  while (true) {

    // Reset the de-duplication structures for this pass.
    init_buffer_nonblocking<int>(EMPTY_VALUE, hashtable_query, hashtable_size, num_threads, thread_id);
    init_buffer_nonblocking<int>(0, hashtable_count, hashtable_size, num_threads, thread_id);
    init_buffer_nonblocking<int>(EMPTY_VALUE, inserted_query, hashtable_size, num_threads, thread_id);
    init_buffer_nonblocking<int>(0, query_counter, 1, num_threads, thread_id);
    __syncthreads();

    // Each warp owns hash function (hash_f_idx_base + warp_idx).
    // NOTE(review): if num_hash_f is not a multiple of num_warps, the last
    // round gives some warps hash_f_idx >= num_hash_f and the buffer reads
    // below index past query_count_buffer — confirm the launch configuration
    // guarantees divisibility.
    while (hash_f_idx_base < num_hash_f) {

      int hash_f_idx = hash_f_idx_base + warp_idx;
      int batch_idx__hash_f_idx = batch_idx * num_hash_f + hash_f_idx;

      int stop_flag = 0;

      int query_offset = query_offset_buffer[hash_f_idx];
      int query_count = query_count_buffer[hash_f_idx];

      // Consume this hash function's colliding queries, one warp-sized chunk
      // at a time.
      while (query_count > 0) {

        int work_size = min(query_count, WARP_SIZE);

        // try inserting query to set and check whether the query is new
        int found_new_query = 0;
        int query_idx = -1;
        if (warp_thread_idx < work_size) {
          query_idx = query_sorted_idxes[batch_idx__hash_f_idx * num_query + query_offset + warp_thread_idx];
          int slot = set_insert<int>(hashtable_query, hashtable_size, query_idx);
          if (slot >= 0) {
            // The thread that takes the slot's count from 0 to 1 is the one
            // that "discovered" the query; later hits only bump the counter.
            found_new_query = atomicAdd(&hashtable_count[slot], 1) == 0;
          }
        }

        // compute cumulative offset
        // Warp-inclusive prefix sum of found_new_query via up-shuffles; after
        // the loop, position_offset is each new query's 1-based rank within
        // this warp's batch of insertions.
        int position_offset = found_new_query;
        int next_position_offset = 0;
        #pragma unroll
        for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) {
          next_position_offset = __shfl_up_sync(FULL_MASK, position_offset, offset);
          if (thread_id % WARP_SIZE >= offset) {
            position_offset = position_offset + next_position_offset;
          }
        }

        // get the inserted query list end index
        // The last lane holds the warp's total and reserves a contiguous
        // range in inserted_query; the base index is broadcast to all lanes.
        int inserted_query_base = 0;
        if (thread_id % WARP_SIZE == WARP_SIZE - 1) {
          inserted_query_base = atomicAdd(query_counter, position_offset);
        }
        inserted_query_base = __shfl_sync(FULL_MASK, inserted_query_base, WARP_SIZE - 1);

        // insert new queries to list
        int insert_idx = inserted_query_base + position_offset - 1;
        if (found_new_query) {
          inserted_query[insert_idx] = query_idx;
        }

        // remove inserted queries from list
        query_offset_buffer[hash_f_idx] += work_size;
        query_count_buffer[hash_f_idx] -= work_size;
        query_offset += work_size;
        query_count -= work_size;

        // if list is almost full, stop inserting
        if (inserted_query_base + OPTIMAL_THREADS_PER_BLOCK > hashtable_size) {
          stop_flag = 1;
          break;
        }

      }

      if (stop_flag) {
        break;
      }

      hash_f_idx_base = hash_f_idx_base + num_warps;

    }

    __syncthreads();

    // Flush phase: number of distinct queries collected this pass.
    // ("distint" is a historical typo for "distinct"; kept to leave code
    // byte-identical.)
    int num_distint_query = query_counter[0];

    if (num_distint_query > 0) {
      // Each warp takes one distinct query at a time and adds
      // duplicate_count * <key_weight, query_weight> / num_hash_f * value
      // into that query's cumulation row.
      for (int idx_base = 0; idx_base < num_distint_query; idx_base = idx_base + num_warps) {
        int idx = idx_base + warp_idx;
        if (idx < num_distint_query) {
          int query_idx = inserted_query[idx];
          int batch_idx__query_idx = batch_idx * num_query + query_idx;

          int slot = set_lookup<int>(hashtable_query, hashtable_size, query_idx);
          int duplicate_count = hashtable_count[slot];  // collisions across hash functions

          // Dot product of key weight and query weight, reduced across the
          // warp with XOR (butterfly) shuffles so every lane ends up with the
          // full partial sum.
          // NOTE(review): no bounds check on weight_dim_idx — appears to
          // assume weight_dim is a multiple of WARP_SIZE; confirm.
          float weight = 0;
          for (int weight_idx_base = 0; weight_idx_base < weight_dim; weight_idx_base = weight_idx_base + WARP_SIZE) {
            int weight_dim_idx = weight_idx_base + warp_thread_idx;
            float val = weight_buffer[weight_dim_idx] * query_weight[batch_idx__query_idx * weight_dim + weight_dim_idx];
            #pragma unroll
            for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) {
              val += __shfl_xor_sync(FULL_MASK, val, offset);
            }
            weight = weight + val;
          }

          // Scale by how many hash functions produced this (key, query)
          // collision, averaged over the total number of hash functions.
          weight = (float)duplicate_count * weight / float(num_hash_f);

          for (int value_idx_base = 0; value_idx_base < value_dim; value_idx_base = value_idx_base + WARP_SIZE) {
            int value_dim_idx = value_idx_base + warp_thread_idx;
            float val = value_buffer[value_dim_idx];
            atomicAdd(&cumulation_value[batch_idx__query_idx * value_dim + value_dim_idx], weight * val);
          }
        }
      }
    } else {

      // all computation is completed if num_distint_query == 0
      break;

    }

    __syncthreads();

  }

}
.venv/lib/python3.11/site-packages/transformers/kernels/yoso/fast_lsh_cumulation_cuda.h ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
// Declarations for the CUDA kernels backing the YOSO fast LSH cumulation
// extension. Definitions live in the accompanying .cu file. Tensor shapes are
// noted next to each pointer parameter.

// Version-1 hashing kernel: reads masked input vectors and the Dmat
// projection data and writes integer hash codes into hash_code.
__global__ void fast_hash_ver1_cuda_kernel(
  int *mask, // [batch_size, num_vector]
  float *vector, // [batch_size, num_vector, vector_dim]
  int *Dmat, // [3, num_part, vector_dim]
  int *hash_code, // [batch_size, num_vector, num_hash_f]
  int batch_size,
  int num_vector,
  int vector_dim,
  int num_part,
  int num_hash_f,
  int hash_code_len
);

// Unweighted cumulation, step 1: consumes masked key values/hash codes and
// writes into the per-hash-function hashtable_value buffers.
__global__ void lsh_cumulation_ver1_step1_cuda_kernel(
  int *key_mask, // [batch_size, num_key]
  int *key_hash_code, // [batch_size, num_key, num_hash_f]
  float *value, // [batch_size, num_key, value_dim]
  float *hashtable_value, // [batch_size, num_hash_f, hashtable_capacity, value_dim]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_key,
  int value_dim,
  int offset_warp
);

// Unweighted cumulation, step 2: reads hashtable_value back out for each
// masked query and writes the result into cumulation_value.
__global__ void lsh_cumulation_ver1_step2_cuda_kernel(
  int *query_mask, // [batch_size, num_query]
  int *query_hash_code, // [batch_size, num_query, num_hash_f]
  float *hashtable_value, // [batch_size, num_hash_f, hashtable_capacity, value_dim]
  float *cumulation_value, // [batch_size, num_query, value_dim]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_query,
  int value_dim,
  int offset_warp
);

// Weighted cumulation (version 1), step 1: like the unweighted step 1 but
// also consumes key_weight; processes one weight slice selected by weight_idx.
__global__ void lsh_weighted_cumulation_ver1_step1_cuda_kernel(
  int *key_mask, // [batch_size, num_key]
  int *key_hash_code, // [batch_size, num_key, num_hash_f]
  float *key_weight, // [batch_size, num_key, weight_dim]
  float *value, // [batch_size, num_key, value_dim]
  float *hashtable_value, // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_key,
  int value_dim,
  int weight_dim,
  int offset_warp,
  int weight_idx
);

// Weighted cumulation (version 1), step 2: reads hashtable_value weighted by
// query_weight and accumulates into cumulation_value.
__global__ void lsh_weighted_cumulation_ver1_step2_cuda_kernel(
  int *query_mask, // [batch_size, num_query]
  int *query_hash_code, // [batch_size, num_query, num_hash_f]
  float *query_weight, // [batch_size, num_query, weight_dim]
  float *hashtable_value, // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE]
  float *cumulation_value, // [batch_size, num_query, value_dim]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_query,
  int value_dim,
  int weight_dim,
  int offset_warp,
  int weight_idx
);

// Counting-sort helper, step 1: builds the per-(batch, hash_f) histogram of
// key hash codes in count_sort_table.
__global__ void count_sort_step1_cuda_kernel(
  int *key_mask, // [batch_size, num_key]
  int *key_hash_code, // [batch_size, num_key, num_hash_f]
  int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_key
);

// Counting-sort helper, step 2: turns the histogram in count_sort_table into
// offsets (in place).
__global__ void count_sort_step2_cuda_kernel(
  int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity
);

// Counting-sort helper, step 3: scatters key indices into key_sorted_idxes
// using the offsets from step 2.
__global__ void count_sort_step3_cuda_kernel(
  int *key_mask, // [batch_size, num_key]
  int *key_hash_code, // [batch_size, num_key, num_hash_f]
  int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity]
  int *key_sorted_idxes, // [batch_size, num_hash_f, num_key]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_key
);

// Extracts, per masked query and hash function, bookkeeping info (from
// count_sort_table) into query_info.
__global__ void extract_query_info_cuda_kernel(
  int *query_mask, // [batch_size, num_query]
  int *query_hash_code, // [batch_size, num_query, num_hash_f]
  int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity]
  int *query_info, // [batch_size, num_query, 2, num_hash_f]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_query
);

// Weighted cumulation (version 2), step 2: query-centric accumulation using
// sorted key indices.
__global__ void lsh_weighted_cumulation_ver2_step2_cuda_kernel(
  int *query_mask, // [batch_size, num_query]
  int *query_info, // [batch_size, num_query, 2, num_hash_f]
  int *key_sorted_idxes, // [batch_size, num_hash_f, num_key]
  float *query_weight, // [batch_size, num_query, weight_dim]
  float *key_weight, // [batch_size, num_key, weight_dim]
  float *value, // [batch_size, num_key, value_dim]
  float *cumulation_value, // [batch_size, num_query, value_dim]
  int batch_size,
  int num_hash_f,
  int num_query,
  int num_key,
  int value_dim,
  int weight_dim
);

// Weighted cumulation (version 3), step 2: key-centric accumulation using
// sorted query indices.
__global__ void lsh_weighted_cumulation_ver3_step2_cuda_kernel(
  int *query_sorted_idxes, // [batch_size, num_hash_f, num_query]
  int *key_mask, // [batch_size, num_key]
  int *key_info, // [batch_size, num_key, 2, num_hash_f]
  float *query_weight, // [batch_size, num_query, weight_dim]
  float *key_weight, // [batch_size, num_key, weight_dim]
  float *value, // [batch_size, num_key, value_dim]
  float *cumulation_value, // [batch_size, num_query, value_dim]
  int batch_size,
  int num_hash_f,
  int num_query,
  int num_key,
  int value_dim,
  int weight_dim
);

// Weighted cumulation (version 4), step 2: key-centric accumulation with
// shared-memory de-duplication of colliding queries (same signature as ver3).
__global__ void lsh_weighted_cumulation_ver4_step2_cuda_kernel(
  int *query_sorted_idxes, // [batch_size, num_hash_f, num_query]
  int *key_mask, // [batch_size, num_key]
  int *key_info, // [batch_size, num_key, 2, num_hash_f]
  float *query_weight, // [batch_size, num_query, weight_dim]
  float *key_weight, // [batch_size, num_key, weight_dim]
  float *value, // [batch_size, num_key, value_dim]
  float *cumulation_value, // [batch_size, num_query, value_dim]
  int batch_size,
  int num_hash_f,
  int num_query,
  int num_key,
  int value_dim,
  int weight_dim
);
.venv/lib/python3.11/site-packages/transformers/kernels/yoso/fast_lsh_cumulation_torch.cpp ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+ #include <ATen/ATen.h>
3
+ #include "fast_lsh_cumulation.h"
4
+ #include "common_cuda.h"
5
+ #include <vector>
6
+
7
// Python-facing wrapper around the LSH hashing kernel.
// NOTE(review): `version` is accepted (for API symmetry with
// lsh_weighted_cumulation) but ignored — only the ver1 hashing path is
// dispatched here. Confirm that is intentional before relying on it.
std::vector<at::Tensor> fast_hash(
  at::Tensor query_mask,    // [batch_size, num_query]
  at::Tensor query_vector,  // [batch_size, num_query, vector_dim]
  at::Tensor key_mask,      // [batch_size, num_key]
  at::Tensor key_vector,    // [batch_size, num_key, vector_dim]
  int num_hash_f,
  int hash_code_len,
  bool use_cuda,
  int version               // currently unused
) {
  return fast_hash_ver1_kernel(
    query_mask,
    query_vector,
    key_mask,
    key_vector,
    num_hash_f,
    hash_code_len,
    use_cuda
  );
}
27
+
28
// Python-facing wrapper around the unweighted LSH cumulation kernel.
// NOTE(review): `version` is accepted but ignored — only the ver1
// implementation is dispatched here.
at::Tensor lsh_cumulation(
  at::Tensor query_mask, // [batch_size, num_query]
  at::Tensor query_hash_code, // [batch_size, num_query, num_hash_f]
  at::Tensor key_mask, // [batch_size, num_key]
  at::Tensor key_hash_code, // [batch_size, num_key, num_hash_f]
  at::Tensor value, // [batch_size, num_key, value_dim]
  int hashtable_capacity,
  bool use_cuda,
  int version // currently unused
) {
  return lsh_cumulation_ver1_kernel(
    query_mask,
    query_hash_code,
    key_mask,
    key_hash_code,
    value,
    hashtable_capacity,
    use_cuda
  );
}
48
+
49
+ at::Tensor lsh_weighted_cumulation(
50
+ at::Tensor query_mask, // [batch_size, num_query]
51
+ at::Tensor query_hash_code, // [batch_size, num_query, num_hash_f]
52
+ at::Tensor query_weight, // [batch_size, num_query, weight_dim]
53
+ at::Tensor key_mask, // [batch_size, num_key]
54
+ at::Tensor key_hash_code, // [batch_size, num_key, num_hash_f]
55
+ at::Tensor key_weight, // [batch_size, num_key, weight_dim]
56
+ at::Tensor value, // [batch_size, num_key, value_dim]
57
+ int hashtable_capacity,
58
+ bool use_cuda,
59
+ int version
60
+ ) {
61
+ if (version == 1) {
62
+ return lsh_weighted_cumulation_ver1_kernel(
63
+ query_mask,
64
+ query_hash_code,
65
+ query_weight,
66
+ key_mask,
67
+ key_hash_code,
68
+ key_weight,
69
+ value,
70
+ hashtable_capacity,
71
+ use_cuda
72
+ );
73
+ } else if (version == 2) {
74
+ return lsh_weighted_cumulation_ver2_kernel(
75
+ query_mask,
76
+ query_hash_code,
77
+ query_weight,
78
+ key_mask,
79
+ key_hash_code,
80
+ key_weight,
81
+ value,
82
+ hashtable_capacity,
83
+ use_cuda
84
+ );
85
+ } else if (version == 3) {
86
+ return lsh_weighted_cumulation_ver3_kernel(
87
+ query_mask,
88
+ query_hash_code,
89
+ query_weight,
90
+ key_mask,
91
+ key_hash_code,
92
+ key_weight,
93
+ value,
94
+ hashtable_capacity,
95
+ use_cuda
96
+ );
97
+ } else if (version == 4) {
98
+ return lsh_weighted_cumulation_ver4_kernel(
99
+ query_mask,
100
+ query_hash_code,
101
+ query_weight,
102
+ key_mask,
103
+ key_hash_code,
104
+ key_weight,
105
+ value,
106
+ hashtable_capacity,
107
+ use_cuda
108
+ );
109
+ } else {
110
+ return lsh_weighted_cumulation_ver3_kernel(
111
+ query_mask,
112
+ query_hash_code,
113
+ query_weight,
114
+ key_mask,
115
+ key_hash_code,
116
+ key_weight,
117
+ value,
118
+ hashtable_capacity,
119
+ use_cuda
120
+ );
121
+ }
122
+ }
123
+
124
// Python bindings: exposes the wrapper functions defined in this file as the
// compiled extension module's `fast_hash`, `lsh_cumulation` and
// `lsh_weighted_cumulation` entry points.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("fast_hash", &fast_hash, "Fast Hash (CUDA)");
  m.def("lsh_cumulation", &lsh_cumulation, "LSH Cumulation (CUDA)");
  m.def("lsh_weighted_cumulation", &lsh_weighted_cumulation, "LSH Weighted Cumulation (CUDA)");
}
.venv/lib/python3.11/site-packages/transformers/modelcard.py ADDED
@@ -0,0 +1,908 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Configuration base class and utilities."""
16
+
17
+ import copy
18
+ import json
19
+ import os
20
+ import warnings
21
+ from dataclasses import dataclass
22
+ from pathlib import Path
23
+ from typing import Any, Dict, List, Optional, Union
24
+
25
+ import requests
26
+ import yaml
27
+ from huggingface_hub import model_info
28
+ from huggingface_hub.utils import HFValidationError
29
+
30
+ from . import __version__
31
+ from .models.auto.modeling_auto import (
32
+ MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
33
+ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
34
+ MODEL_FOR_CTC_MAPPING_NAMES,
35
+ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
36
+ MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES,
37
+ MODEL_FOR_MASKED_LM_MAPPING_NAMES,
38
+ MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES,
39
+ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
40
+ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
41
+ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
42
+ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
43
+ MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
44
+ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
45
+ MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
46
+ )
47
+ from .training_args import ParallelMode
48
+ from .utils import (
49
+ MODEL_CARD_NAME,
50
+ cached_file,
51
+ is_datasets_available,
52
+ is_offline_mode,
53
+ is_tf_available,
54
+ is_tokenizers_available,
55
+ is_torch_available,
56
+ logging,
57
+ )
58
+
59
+
60
# Maps pipeline task tags to the auto-model class-name mappings that implement
# them; "automatic-speech-recognition" merges the CTC and seq2seq mappings.
TASK_MAPPING = {
    "text-generation": MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
    "image-segmentation": MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES,
    "fill-mask": MODEL_FOR_MASKED_LM_MAPPING_NAMES,
    "object-detection": MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES,
    "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
    "text2text-generation": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
    "text-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
    "table-question-answering": MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
    "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
    "audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
    "automatic-speech-recognition": {**MODEL_FOR_CTC_MAPPING_NAMES, **MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES},
    "zero-shot-image-classification": MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
}

logger = logging.get_logger(__name__)
77
+
78
+
79
class ModelCard:
    r"""
    Structured Model Card class. Store model card as well as methods for loading/downloading/saving model cards.

    Please read the following paper for details and explanation on the sections: "Model Cards for Model Reporting" by
    Margaret Mitchell, Simone Wu, Andrew Zaldivar, Parker Barnes, Lucy Vasserman, Ben Hutchinson, Elena Spitzer,
    Inioluwa Deborah Raji and Timnit Gebru for the proposal behind model cards. Link: https://arxiv.org/abs/1810.03993

    Note: A model card can be loaded and saved to disk.
    """

    def __init__(self, **kwargs):
        warnings.warn(
            "The class `ModelCard` is deprecated and will be removed in version 5 of Transformers", FutureWarning
        )
        # Recommended attributes from https://arxiv.org/abs/1810.03993 (see papers)
        self.model_details = kwargs.pop("model_details", {})
        self.intended_use = kwargs.pop("intended_use", {})
        self.factors = kwargs.pop("factors", {})
        self.metrics = kwargs.pop("metrics", {})
        self.evaluation_data = kwargs.pop("evaluation_data", {})
        self.training_data = kwargs.pop("training_data", {})
        self.quantitative_analyses = kwargs.pop("quantitative_analyses", {})
        self.ethical_considerations = kwargs.pop("ethical_considerations", {})
        self.caveats_and_recommendations = kwargs.pop("caveats_and_recommendations", {})

        # Open additional attributes: any leftover keyword becomes an attribute.
        for key, value in kwargs.items():
            try:
                setattr(self, key, value)
            except AttributeError as err:
                logger.error(f"Can't set {key} with value {value} for {self}")
                raise err

    def save_pretrained(self, save_directory_or_file):
        """Save a model card object to the directory or file `save_directory_or_file`."""
        if os.path.isdir(save_directory_or_file):
            # If we save using the predefined names, we can load using `from_pretrained`
            output_model_card_file = os.path.join(save_directory_or_file, MODEL_CARD_NAME)
        else:
            output_model_card_file = save_directory_or_file

        self.to_json_file(output_model_card_file)
        logger.info(f"Model card saved in {output_model_card_file}")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        r"""
        Instantiate a [`ModelCard`] from a pre-trained model model card.

        Parameters:
            pretrained_model_name_or_path: either:

                - a string, the *model id* of a pretrained model card hosted inside a model repo on huggingface.co.
                - a path to a *directory* containing a model card file saved using the [`~ModelCard.save_pretrained`]
                  method, e.g.: `./my_model_directory/`.
                - a path or url to a saved model card JSON *file*, e.g.: `./my_model_directory/modelcard.json`.

            cache_dir: (*optional*) string:
                Path to a directory in which a downloaded pre-trained model card should be cached if the standard cache
                should not be used.

            kwargs: (*optional*) dict: key/value pairs with which to update the ModelCard object after loading.

                - The values in kwargs of any keys which are model card attributes will be used to override the loaded
                  values.
                - Behavior concerning key/value pairs whose keys are *not* model card attributes is controlled by the
                  *return_unused_kwargs* keyword parameter.

            proxies: (*optional*) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.

            return_unused_kwargs: (*optional*) bool:

                - If False, then this function returns just the final model card object.
                - If True, then this functions returns a tuple *(model card, unused_kwargs)* where *unused_kwargs* is a
                  dictionary consisting of the key/value pairs whose keys are not model card attributes: ie the part of
                  kwargs which has not been used to update *ModelCard* and is otherwise ignored.

        Examples:

        ```python
        # Download model card from huggingface.co and cache.
        modelcard = ModelCard.from_pretrained("google-bert/bert-base-uncased")
        # Model card was saved using *save_pretrained('./test/saved_model/')*
        modelcard = ModelCard.from_pretrained("./test/saved_model/")
        modelcard = ModelCard.from_pretrained("./test/saved_model/modelcard.json")
        modelcard = ModelCard.from_pretrained("google-bert/bert-base-uncased", output_attentions=True, foo=False)
        ```"""
        cache_dir = kwargs.pop("cache_dir", None)
        proxies = kwargs.pop("proxies", None)
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
        from_pipeline = kwargs.pop("_from_pipeline", None)

        user_agent = {"file_type": "model_card"}
        if from_pipeline is not None:
            user_agent["using_pipeline"] = from_pipeline

        is_local = os.path.isdir(pretrained_model_name_or_path)
        if os.path.isfile(pretrained_model_name_or_path):
            resolved_model_card_file = pretrained_model_name_or_path
            is_local = True
        else:
            try:
                # Load from URL or cache if already cached
                resolved_model_card_file = cached_file(
                    pretrained_model_name_or_path,
                    filename=MODEL_CARD_NAME,
                    cache_dir=cache_dir,
                    proxies=proxies,
                    user_agent=user_agent,
                )
                if is_local:
                    logger.info(f"loading model card file {resolved_model_card_file}")
                else:
                    logger.info(f"loading model card file {MODEL_CARD_NAME} from cache at {resolved_model_card_file}")
                # Load model card
                modelcard = cls.from_json_file(resolved_model_card_file)

            except (EnvironmentError, json.JSONDecodeError):
                # We fall back on creating an empty model card
                modelcard = cls()

        # Update model card with kwargs if needed
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(modelcard, key):
                setattr(modelcard, key, value)
                to_remove.append(key)
        for key in to_remove:
            kwargs.pop(key, None)

        logger.info(f"Model card: {modelcard}")
        if return_unused_kwargs:
            return modelcard, kwargs
        else:
            return modelcard

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a `ModelCard` from a Python dictionary of parameters."""
        return cls(**json_object)

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `ModelCard` from a json file of parameters."""
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        dict_obj = json.loads(text)
        return cls(**dict_obj)

    def __eq__(self, other):
        # Bug fix: the previous implementation accessed `other.__dict__`
        # unconditionally, which raised AttributeError when compared against
        # objects without a __dict__ (e.g. strings or ints). Returning
        # NotImplemented lets Python fall back to its default handling.
        if not isinstance(other, ModelCard):
            return NotImplemented
        return self.__dict__ == other.__dict__

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path):
        """Save this instance to a json file."""
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string())
250
+
251
+
252
# Boilerplate HTML comment prepended to model cards generated by the Trainer,
# warning readers that the card has not been human-reviewed.
AUTOGENERATED_TRAINER_COMMENT = """
<!-- This model card has been generated automatically according to the information the Trainer had access to. You
should probably proofread and complete it, then remove this comment. -->
"""

# Same warning, for model cards generated from Keras training.
AUTOGENERATED_KERAS_COMMENT = """
<!-- This model card has been generated automatically according to the information Keras had access to. You should
probably proofread and complete it, then remove this comment. -->
"""


# Maps pipeline task tags to the human-readable task names written into
# generated model cards.
TASK_TAG_TO_NAME_MAPPING = {
    "fill-mask": "Masked Language Modeling",
    "image-classification": "Image Classification",
    "image-segmentation": "Image Segmentation",
    "multiple-choice": "Multiple Choice",
    "object-detection": "Object Detection",
    "question-answering": "Question Answering",
    "summarization": "Summarization",
    "table-question-answering": "Table Question Answering",
    "text-classification": "Text Classification",
    "text-generation": "Causal Language Modeling",
    "text2text-generation": "Sequence-to-sequence Language Modeling",
    "token-classification": "Token Classification",
    "translation": "Translation",
    "zero-shot-classification": "Zero Shot Classification",
    "automatic-speech-recognition": "Automatic Speech Recognition",
    "audio-classification": "Audio Classification",
}


# Metric keys recognized by infer_metric_tags_from_eval_results (after
# lowercasing and replacing spaces with underscores).
METRIC_TAGS = [
    "accuracy",
    "bleu",
    "f1",
    "matthews_correlation",
    "pearsonr",
    "precision",
    "recall",
    "rouge",
    "sacrebleu",
    "spearmanr",
    "wer",
]
296
+
297
+
298
+ def _listify(obj):
299
+ if obj is None:
300
+ return []
301
+ elif isinstance(obj, str):
302
+ return [obj]
303
+ else:
304
+ return obj
305
+
306
+
307
+ def _insert_values_as_list(metadata, name, values):
308
+ if values is None:
309
+ return metadata
310
+ if isinstance(values, str):
311
+ values = [values]
312
+ values = [v for v in values if v is not None]
313
+ if len(values) == 0:
314
+ return metadata
315
+ metadata[name] = values
316
+ return metadata
317
+
318
+
319
def infer_metric_tags_from_eval_results(eval_results):
    """Map recognized metric tags to the original keys of *eval_results*.

    A key matches when its lowercased, underscore-normalized form is in
    METRIC_TAGS; "rouge1" (any case) is special-cased to the "rouge" tag.
    Returns an empty dict when *eval_results* is None.
    """
    if eval_results is None:
        return {}
    result = {}
    for key in eval_results.keys():
        # Normalize once instead of twice per key (was computed in both the
        # membership test and the assignment).
        normalized = key.lower().replace(" ", "_")
        if normalized in METRIC_TAGS:
            result[normalized] = key
        elif key.lower() == "rouge1":
            result["rouge"] = key
    return result
329
+
330
+
331
+ def _insert_value(metadata, name, value):
332
+ if value is None:
333
+ return metadata
334
+ metadata[name] = value
335
+ return metadata
336
+
337
+
338
def is_hf_dataset(dataset):
    """Return True if *dataset* is a `datasets.Dataset` or `datasets.IterableDataset`.

    Returns False (without importing `datasets`) when the library is not
    installed, so this is safe to call in environments without `datasets`.
    """
    if not is_datasets_available():
        return False

    # Imported lazily so the module loads even when `datasets` is absent.
    from datasets import Dataset, IterableDataset

    return isinstance(dataset, (Dataset, IterableDataset))
345
+
346
+
347
+ def _get_mapping_values(mapping):
348
+ result = []
349
+ for v in mapping.values():
350
+ if isinstance(v, (tuple, list)):
351
+ result += list(v)
352
+ else:
353
+ result.append(v)
354
+ return result
355
+
356
+
357
+ @dataclass
358
+ class TrainingSummary:
359
+ model_name: str
360
+ language: Optional[Union[str, List[str]]] = None
361
+ license: Optional[str] = None
362
+ tags: Optional[Union[str, List[str]]] = None
363
+ finetuned_from: Optional[str] = None
364
+ tasks: Optional[Union[str, List[str]]] = None
365
+ dataset: Optional[Union[str, List[str]]] = None
366
+ dataset_tags: Optional[Union[str, List[str]]] = None
367
+ dataset_args: Optional[Union[str, List[str]]] = None
368
+ dataset_metadata: Optional[Dict[str, Any]] = None
369
+ eval_results: Optional[Dict[str, float]] = None
370
+ eval_lines: Optional[List[str]] = None
371
+ hyperparameters: Optional[Dict[str, Any]] = None
372
+ source: Optional[str] = "trainer"
373
+
374
+ def __post_init__(self):
375
+ # Infer default license from the checkpoint used, if possible.
376
+ if (
377
+ self.license is None
378
+ and not is_offline_mode()
379
+ and self.finetuned_from is not None
380
+ and len(self.finetuned_from) > 0
381
+ ):
382
+ try:
383
+ info = model_info(self.finetuned_from)
384
+ for tag in info.tags:
385
+ if tag.startswith("license:"):
386
+ self.license = tag[8:]
387
+ except (requests.exceptions.HTTPError, requests.exceptions.ConnectionError, HFValidationError):
388
+ pass
389
+
390
+ def create_model_index(self, metric_mapping):
391
+ model_index = {"name": self.model_name}
392
+
393
+ # Dataset mapping tag -> name
394
+ dataset_names = _listify(self.dataset)
395
+ dataset_tags = _listify(self.dataset_tags)
396
+ dataset_args = _listify(self.dataset_args)
397
+ dataset_metadata = _listify(self.dataset_metadata)
398
+ if len(dataset_args) < len(dataset_tags):
399
+ dataset_args = dataset_args + [None] * (len(dataset_tags) - len(dataset_args))
400
+ dataset_mapping = dict(zip(dataset_tags, dataset_names))
401
+ dataset_arg_mapping = dict(zip(dataset_tags, dataset_args))
402
+ dataset_metadata_mapping = dict(zip(dataset_tags, dataset_metadata))
403
+
404
+ task_mapping = {
405
+ task: TASK_TAG_TO_NAME_MAPPING[task] for task in _listify(self.tasks) if task in TASK_TAG_TO_NAME_MAPPING
406
+ }
407
+
408
+ model_index["results"] = []
409
+
410
+ if len(task_mapping) == 0 and len(dataset_mapping) == 0:
411
+ return [model_index]
412
+ if len(task_mapping) == 0:
413
+ task_mapping = {None: None}
414
+ if len(dataset_mapping) == 0:
415
+ dataset_mapping = {None: None}
416
+
417
+ # One entry per dataset and per task
418
+ all_possibilities = [(task_tag, ds_tag) for task_tag in task_mapping for ds_tag in dataset_mapping]
419
+ for task_tag, ds_tag in all_possibilities:
420
+ result = {}
421
+ if task_tag is not None:
422
+ result["task"] = {"name": task_mapping[task_tag], "type": task_tag}
423
+
424
+ if ds_tag is not None:
425
+ metadata = dataset_metadata_mapping.get(ds_tag, {})
426
+ result["dataset"] = {
427
+ "name": dataset_mapping[ds_tag],
428
+ "type": ds_tag,
429
+ **metadata,
430
+ }
431
+ if dataset_arg_mapping[ds_tag] is not None:
432
+ result["dataset"]["args"] = dataset_arg_mapping[ds_tag]
433
+
434
+ if len(metric_mapping) > 0:
435
+ result["metrics"] = []
436
+ for metric_tag, metric_name in metric_mapping.items():
437
+ result["metrics"].append(
438
+ {
439
+ "name": metric_name,
440
+ "type": metric_tag,
441
+ "value": self.eval_results[metric_name],
442
+ }
443
+ )
444
+
445
+ # Remove partial results to avoid the model card being rejected.
446
+ if "task" in result and "dataset" in result and "metrics" in result:
447
+ model_index["results"].append(result)
448
+ else:
449
+ logger.info(f"Dropping the following result as it does not have all the necessary fields:\n{result}")
450
+
451
+ return [model_index]
452
+
453
+ def create_metadata(self):
454
+ metric_mapping = infer_metric_tags_from_eval_results(self.eval_results)
455
+
456
+ metadata = {}
457
+ metadata = _insert_value(metadata, "library_name", "transformers")
458
+ metadata = _insert_values_as_list(metadata, "language", self.language)
459
+ metadata = _insert_value(metadata, "license", self.license)
460
+ if self.finetuned_from is not None and isinstance(self.finetuned_from, str) and len(self.finetuned_from) > 0:
461
+ metadata = _insert_value(metadata, "base_model", self.finetuned_from)
462
+ metadata = _insert_values_as_list(metadata, "tags", self.tags)
463
+ metadata = _insert_values_as_list(metadata, "datasets", self.dataset_tags)
464
+ metadata = _insert_values_as_list(metadata, "metrics", list(metric_mapping.keys()))
465
+ metadata["model-index"] = self.create_model_index(metric_mapping)
466
+
467
+ return metadata
468
+
469
+ def to_model_card(self):
470
+ model_card = ""
471
+
472
+ metadata = yaml.dump(self.create_metadata(), sort_keys=False)
473
+ if len(metadata) > 0:
474
+ model_card = f"---\n{metadata}---\n"
475
+
476
+ # Now the model card for realsies.
477
+ if self.source == "trainer":
478
+ model_card += AUTOGENERATED_TRAINER_COMMENT
479
+ else:
480
+ model_card += AUTOGENERATED_KERAS_COMMENT
481
+
482
+ model_card += f"\n# {self.model_name}\n\n"
483
+
484
+ if self.finetuned_from is None:
485
+ model_card += "This model was trained from scratch on "
486
+ else:
487
+ model_card += (
488
+ "This model is a fine-tuned version of"
489
+ f" [{self.finetuned_from}](https://huggingface.co/{self.finetuned_from}) on "
490
+ )
491
+
492
+ if self.dataset is None:
493
+ model_card += "an unknown dataset."
494
+ else:
495
+ if isinstance(self.dataset, str):
496
+ model_card += f"the {self.dataset} dataset."
497
+ elif isinstance(self.dataset, (tuple, list)) and len(self.dataset) == 1:
498
+ model_card += f"the {self.dataset[0]} dataset."
499
+ else:
500
+ model_card += (
501
+ ", ".join([f"the {ds}" for ds in self.dataset[:-1]]) + f" and the {self.dataset[-1]} datasets."
502
+ )
503
+
504
+ if self.eval_results is not None:
505
+ model_card += "\nIt achieves the following results on the evaluation set:\n"
506
+ model_card += "\n".join([f"- {name}: {_maybe_round(value)}" for name, value in self.eval_results.items()])
507
+ model_card += "\n"
508
+
509
+ model_card += "\n## Model description\n\nMore information needed\n"
510
+ model_card += "\n## Intended uses & limitations\n\nMore information needed\n"
511
+ model_card += "\n## Training and evaluation data\n\nMore information needed\n"
512
+
513
+ model_card += "\n## Training procedure\n"
514
+ model_card += "\n### Training hyperparameters\n"
515
+ if self.hyperparameters is not None:
516
+ model_card += "\nThe following hyperparameters were used during training:\n"
517
+ model_card += "\n".join([f"- {name}: {value}" for name, value in self.hyperparameters.items()])
518
+ model_card += "\n"
519
+ else:
520
+ model_card += "\nMore information needed\n"
521
+
522
+ if self.eval_lines is not None:
523
+ model_card += "\n### Training results\n\n"
524
+ model_card += make_markdown_table(self.eval_lines)
525
+ model_card += "\n"
526
+
527
+ model_card += "\n### Framework versions\n\n"
528
+ model_card += f"- Transformers {__version__}\n"
529
+
530
+ if self.source == "trainer" and is_torch_available():
531
+ import torch
532
+
533
+ model_card += f"- Pytorch {torch.__version__}\n"
534
+ elif self.source == "keras" and is_tf_available():
535
+ import tensorflow as tf
536
+
537
+ model_card += f"- TensorFlow {tf.__version__}\n"
538
+ if is_datasets_available():
539
+ import datasets
540
+
541
+ model_card += f"- Datasets {datasets.__version__}\n"
542
+ if is_tokenizers_available():
543
+ import tokenizers
544
+
545
+ model_card += f"- Tokenizers {tokenizers.__version__}\n"
546
+
547
+ return model_card
548
+
549
+ @classmethod
550
+ def from_trainer(
551
+ cls,
552
+ trainer,
553
+ language=None,
554
+ license=None,
555
+ tags=None,
556
+ model_name=None,
557
+ finetuned_from=None,
558
+ tasks=None,
559
+ dataset_tags=None,
560
+ dataset_metadata=None,
561
+ dataset=None,
562
+ dataset_args=None,
563
+ ):
564
+ # Infer default from dataset
565
+ one_dataset = trainer.eval_dataset if trainer.eval_dataset is not None else trainer.train_dataset
566
+ if is_hf_dataset(one_dataset) and (dataset_tags is None or dataset_args is None or dataset_metadata is None):
567
+ default_tag = one_dataset.builder_name
568
+ # Those are not real datasets from the Hub so we exclude them.
569
+ if default_tag not in ["csv", "json", "pandas", "parquet", "text"]:
570
+ if dataset_metadata is None:
571
+ dataset_metadata = [{"config": one_dataset.config_name, "split": str(one_dataset.split)}]
572
+ if dataset_tags is None:
573
+ dataset_tags = [default_tag]
574
+ if dataset_args is None:
575
+ dataset_args = [one_dataset.config_name]
576
+
577
+ if dataset is None and dataset_tags is not None:
578
+ dataset = dataset_tags
579
+
580
+ # Infer default finetuned_from
581
+ if (
582
+ finetuned_from is None
583
+ and hasattr(trainer.model.config, "_name_or_path")
584
+ and not os.path.isdir(trainer.model.config._name_or_path)
585
+ ):
586
+ finetuned_from = trainer.model.config._name_or_path
587
+
588
+ # Infer default task tag:
589
+ if tasks is None:
590
+ model_class_name = trainer.model.__class__.__name__
591
+ for task, mapping in TASK_MAPPING.items():
592
+ if model_class_name in _get_mapping_values(mapping):
593
+ tasks = task
594
+
595
+ if model_name is None:
596
+ model_name = Path(trainer.args.output_dir).name
597
+ if len(model_name) == 0:
598
+ model_name = finetuned_from
599
+
600
+ # Add `generated_from_trainer` to the tags
601
+ if tags is None:
602
+ tags = ["generated_from_trainer"]
603
+ elif isinstance(tags, str) and tags != "generated_from_trainer":
604
+ tags = [tags, "generated_from_trainer"]
605
+ elif "generated_from_trainer" not in tags:
606
+ tags.append("generated_from_trainer")
607
+
608
+ _, eval_lines, eval_results = parse_log_history(trainer.state.log_history)
609
+ hyperparameters = extract_hyperparameters_from_trainer(trainer)
610
+
611
+ return cls(
612
+ language=language,
613
+ license=license,
614
+ tags=tags,
615
+ model_name=model_name,
616
+ finetuned_from=finetuned_from,
617
+ tasks=tasks,
618
+ dataset=dataset,
619
+ dataset_tags=dataset_tags,
620
+ dataset_args=dataset_args,
621
+ dataset_metadata=dataset_metadata,
622
+ eval_results=eval_results,
623
+ eval_lines=eval_lines,
624
+ hyperparameters=hyperparameters,
625
+ )
626
+
627
+ @classmethod
628
+ def from_keras(
629
+ cls,
630
+ model,
631
+ model_name,
632
+ keras_history=None,
633
+ language=None,
634
+ license=None,
635
+ tags=None,
636
+ finetuned_from=None,
637
+ tasks=None,
638
+ dataset_tags=None,
639
+ dataset=None,
640
+ dataset_args=None,
641
+ ):
642
+ # Infer default from dataset
643
+ if dataset is not None:
644
+ if is_hf_dataset(dataset) and (dataset_tags is None or dataset_args is None):
645
+ default_tag = dataset.builder_name
646
+ # Those are not real datasets from the Hub so we exclude them.
647
+ if default_tag not in ["csv", "json", "pandas", "parquet", "text"]:
648
+ if dataset_tags is None:
649
+ dataset_tags = [default_tag]
650
+ if dataset_args is None:
651
+ dataset_args = [dataset.config_name]
652
+
653
+ if dataset is None and dataset_tags is not None:
654
+ dataset = dataset_tags
655
+
656
+ # Infer default finetuned_from
657
+ if (
658
+ finetuned_from is None
659
+ and hasattr(model.config, "_name_or_path")
660
+ and not os.path.isdir(model.config._name_or_path)
661
+ ):
662
+ finetuned_from = model.config._name_or_path
663
+
664
+ # Infer default task tag:
665
+ if tasks is None:
666
+ model_class_name = model.__class__.__name__
667
+ for task, mapping in TASK_MAPPING.items():
668
+ if model_class_name in _get_mapping_values(mapping):
669
+ tasks = task
670
+
671
+ # Add `generated_from_keras_callback` to the tags
672
+ if tags is None:
673
+ tags = ["generated_from_keras_callback"]
674
+ elif isinstance(tags, str) and tags != "generated_from_keras_callback":
675
+ tags = [tags, "generated_from_keras_callback"]
676
+ elif "generated_from_keras_callback" not in tags:
677
+ tags.append("generated_from_keras_callback")
678
+
679
+ if keras_history is not None:
680
+ _, eval_lines, eval_results = parse_keras_history(keras_history)
681
+ else:
682
+ eval_lines = []
683
+ eval_results = {}
684
+ hyperparameters = extract_hyperparameters_from_keras(model)
685
+
686
+ return cls(
687
+ language=language,
688
+ license=license,
689
+ tags=tags,
690
+ model_name=model_name,
691
+ finetuned_from=finetuned_from,
692
+ tasks=tasks,
693
+ dataset_tags=dataset_tags,
694
+ dataset=dataset,
695
+ dataset_args=dataset_args,
696
+ eval_results=eval_results,
697
+ eval_lines=eval_lines,
698
+ hyperparameters=hyperparameters,
699
+ source="keras",
700
+ )
701
+
702
+
703
+ def parse_keras_history(logs):
704
+ """
705
+ Parse the `logs` of either a `keras.History` object returned by `model.fit()` or an accumulated logs `dict`
706
+ passed to the `PushToHubCallback`. Returns lines and logs compatible with those returned by `parse_log_history`.
707
+ """
708
+ if hasattr(logs, "history"):
709
+ # This looks like a `History` object
710
+ if not hasattr(logs, "epoch"):
711
+ # This history looks empty, return empty results
712
+ return None, [], {}
713
+ logs.history["epoch"] = logs.epoch
714
+ logs = logs.history
715
+ else:
716
+ # Training logs is a list of dicts, let's invert it to a dict of lists to match a History object
717
+ logs = {log_key: [single_dict[log_key] for single_dict in logs] for log_key in logs[0]}
718
+
719
+ lines = []
720
+ for i in range(len(logs["epoch"])):
721
+ epoch_dict = {log_key: log_value_list[i] for log_key, log_value_list in logs.items()}
722
+ values = {}
723
+ for k, v in epoch_dict.items():
724
+ if k.startswith("val_"):
725
+ k = "validation_" + k[4:]
726
+ elif k != "epoch":
727
+ k = "train_" + k
728
+ splits = k.split("_")
729
+ name = " ".join([part.capitalize() for part in splits])
730
+ values[name] = v
731
+ lines.append(values)
732
+
733
+ eval_results = lines[-1]
734
+
735
+ return logs, lines, eval_results
736
+
737
+
738
+ def parse_log_history(log_history):
739
+ """
740
+ Parse the `log_history` of a Trainer to get the intermediate and final evaluation results.
741
+ """
742
+ idx = 0
743
+ while idx < len(log_history) and "train_runtime" not in log_history[idx]:
744
+ idx += 1
745
+
746
+ # If there are no training logs
747
+ if idx == len(log_history):
748
+ idx -= 1
749
+ while idx >= 0 and "eval_loss" not in log_history[idx]:
750
+ idx -= 1
751
+
752
+ if idx >= 0:
753
+ return None, None, log_history[idx]
754
+ else:
755
+ return None, None, None
756
+
757
+ # From now one we can assume we have training logs:
758
+ train_log = log_history[idx]
759
+ lines = []
760
+ training_loss = "No log"
761
+ for i in range(idx):
762
+ if "loss" in log_history[i]:
763
+ training_loss = log_history[i]["loss"]
764
+ if "eval_loss" in log_history[i]:
765
+ metrics = log_history[i].copy()
766
+ _ = metrics.pop("total_flos", None)
767
+ epoch = metrics.pop("epoch", None)
768
+ step = metrics.pop("step", None)
769
+ _ = metrics.pop("eval_runtime", None)
770
+ _ = metrics.pop("eval_samples_per_second", None)
771
+ _ = metrics.pop("eval_steps_per_second", None)
772
+ _ = metrics.pop("eval_jit_compilation_time", None)
773
+ values = {"Training Loss": training_loss, "Epoch": epoch, "Step": step}
774
+ for k, v in metrics.items():
775
+ if k == "eval_loss":
776
+ values["Validation Loss"] = v
777
+ else:
778
+ splits = k.split("_")
779
+ name = " ".join([part.capitalize() for part in splits[1:]])
780
+ values[name] = v
781
+ lines.append(values)
782
+
783
+ idx = len(log_history) - 1
784
+ while idx >= 0 and "eval_loss" not in log_history[idx]:
785
+ idx -= 1
786
+
787
+ if idx > 0:
788
+ eval_results = {}
789
+ for key, value in log_history[idx].items():
790
+ if key.startswith("eval_"):
791
+ key = key[5:]
792
+ if key not in ["runtime", "samples_per_second", "steps_per_second", "epoch", "step"]:
793
+ camel_cased_key = " ".join([part.capitalize() for part in key.split("_")])
794
+ eval_results[camel_cased_key] = value
795
+ return train_log, lines, eval_results
796
+ else:
797
+ return train_log, lines, None
798
+
799
+
800
+ def extract_hyperparameters_from_keras(model):
801
+ from .modeling_tf_utils import keras
802
+
803
+ hyperparameters = {}
804
+ if hasattr(model, "optimizer") and model.optimizer is not None:
805
+ hyperparameters["optimizer"] = model.optimizer.get_config()
806
+ else:
807
+ hyperparameters["optimizer"] = None
808
+ hyperparameters["training_precision"] = keras.mixed_precision.global_policy().name
809
+
810
+ return hyperparameters
811
+
812
+
813
+ def _maybe_round(v, decimals=4):
814
+ if isinstance(v, float) and len(str(v).split(".")) > 1 and len(str(v).split(".")[1]) > decimals:
815
+ return f"{v:.{decimals}f}"
816
+ return str(v)
817
+
818
+
819
+ def _regular_table_line(values, col_widths):
820
+ values_with_space = [f"| {v}" + " " * (w - len(v) + 1) for v, w in zip(values, col_widths)]
821
+ return "".join(values_with_space) + "|\n"
822
+
823
+
824
+ def _second_table_line(col_widths):
825
+ values = ["|:" + "-" * w + ":" for w in col_widths]
826
+ return "".join(values) + "|\n"
827
+
828
+
829
+ def make_markdown_table(lines):
830
+ """
831
+ Create a nice Markdown table from the results in `lines`.
832
+ """
833
+ if lines is None or len(lines) == 0:
834
+ return ""
835
+ col_widths = {key: len(str(key)) for key in lines[0].keys()}
836
+ for line in lines:
837
+ for key, value in line.items():
838
+ if col_widths[key] < len(_maybe_round(value)):
839
+ col_widths[key] = len(_maybe_round(value))
840
+
841
+ table = _regular_table_line(list(lines[0].keys()), list(col_widths.values()))
842
+ table += _second_table_line(list(col_widths.values()))
843
+ for line in lines:
844
+ table += _regular_table_line([_maybe_round(v) for v in line.values()], list(col_widths.values()))
845
+ return table
846
+
847
+
848
+ _TRAINING_ARGS_KEYS = [
849
+ "learning_rate",
850
+ "train_batch_size",
851
+ "eval_batch_size",
852
+ "seed",
853
+ ]
854
+
855
+
856
+ def extract_hyperparameters_from_trainer(trainer):
857
+ hyperparameters = {k: getattr(trainer.args, k) for k in _TRAINING_ARGS_KEYS}
858
+
859
+ if trainer.args.parallel_mode not in [ParallelMode.NOT_PARALLEL, ParallelMode.NOT_DISTRIBUTED]:
860
+ hyperparameters["distributed_type"] = (
861
+ "multi-GPU" if trainer.args.parallel_mode == ParallelMode.DISTRIBUTED else trainer.args.parallel_mode.value
862
+ )
863
+ if trainer.args.world_size > 1:
864
+ hyperparameters["num_devices"] = trainer.args.world_size
865
+ if trainer.args.gradient_accumulation_steps > 1:
866
+ hyperparameters["gradient_accumulation_steps"] = trainer.args.gradient_accumulation_steps
867
+
868
+ total_train_batch_size = (
869
+ trainer.args.train_batch_size * trainer.args.world_size * trainer.args.gradient_accumulation_steps
870
+ )
871
+ if total_train_batch_size != hyperparameters["train_batch_size"]:
872
+ hyperparameters["total_train_batch_size"] = total_train_batch_size
873
+ total_eval_batch_size = trainer.args.eval_batch_size * trainer.args.world_size
874
+ if total_eval_batch_size != hyperparameters["eval_batch_size"]:
875
+ hyperparameters["total_eval_batch_size"] = total_eval_batch_size
876
+
877
+ if trainer.args.optim:
878
+ optimizer_name = trainer.args.optim
879
+ optimizer_args = trainer.args.optim_args if trainer.args.optim_args else "No additional optimizer arguments"
880
+
881
+ if "adam" in optimizer_name.lower():
882
+ hyperparameters["optimizer"] = (
883
+ f"Use {optimizer_name} with betas=({trainer.args.adam_beta1},{trainer.args.adam_beta2}) and"
884
+ f" epsilon={trainer.args.adam_epsilon} and optimizer_args={optimizer_args}"
885
+ )
886
+ else:
887
+ hyperparameters["optimizer"] = f"Use {optimizer_name} and the args are:\n{optimizer_args}"
888
+
889
+ hyperparameters["lr_scheduler_type"] = trainer.args.lr_scheduler_type.value
890
+ if trainer.args.warmup_ratio != 0.0:
891
+ hyperparameters["lr_scheduler_warmup_ratio"] = trainer.args.warmup_ratio
892
+ if trainer.args.warmup_steps != 0.0:
893
+ hyperparameters["lr_scheduler_warmup_steps"] = trainer.args.warmup_steps
894
+ if trainer.args.max_steps != -1:
895
+ hyperparameters["training_steps"] = trainer.args.max_steps
896
+ else:
897
+ hyperparameters["num_epochs"] = trainer.args.num_train_epochs
898
+
899
+ if trainer.args.fp16:
900
+ if trainer.use_apex:
901
+ hyperparameters["mixed_precision_training"] = f"Apex, opt level {trainer.args.fp16_opt_level}"
902
+ else:
903
+ hyperparameters["mixed_precision_training"] = "Native AMP"
904
+
905
+ if trainer.args.label_smoothing_factor != 0.0:
906
+ hyperparameters["label_smoothing_factor"] = trainer.args.label_smoothing_factor
907
+
908
+ return hyperparameters
.venv/lib/python3.11/site-packages/transformers/modeling_attn_mask_utils.py ADDED
@@ -0,0 +1,481 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import List, Optional, Tuple, Union
16
+
17
+ import torch
18
+
19
+ from .utils.import_utils import is_torchdynamo_compiling
20
+
21
+
22
+ @dataclass
23
+ class AttentionMaskConverter:
24
+ """
25
+ A utility attention mask class that allows one to:
26
+ - Create a causal 4d mask
27
+ - Create a causal 4d mask with slided window
28
+ - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
29
+ key_value_length) that can be multiplied with attention scores
30
+
31
+ Examples:
32
+
33
+ ```python
34
+ >>> import torch
35
+ >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter
36
+
37
+ >>> converter = AttentionMaskConverter(True)
38
+ >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32)
39
+ tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
40
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
41
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
42
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, -3.4028e+38],
43
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, 0.0000e+00]]]])
44
+ ```
45
+
46
+ Parameters:
47
+ is_causal (`bool`):
48
+ Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
49
+
50
+ sliding_window (`int`, *optional*):
51
+ Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
52
+ """
53
+
54
+ is_causal: bool
55
+ sliding_window: int
56
+
57
+ def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
58
+ self.is_causal = is_causal
59
+ self.sliding_window = sliding_window
60
+
61
+ if self.sliding_window is not None and self.sliding_window <= 0:
62
+ raise ValueError(
63
+ f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
64
+ )
65
+
66
+ def to_causal_4d(
67
+ self,
68
+ batch_size: int,
69
+ query_length: int,
70
+ key_value_length: int,
71
+ dtype: torch.dtype,
72
+ device: Union[torch.device, "str"] = "cpu",
73
+ ) -> Optional[torch.Tensor]:
74
+ """
75
+ Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
76
+ bias to upper right hand triangular matrix (causal mask).
77
+ """
78
+ if not self.is_causal:
79
+ raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")
80
+
81
+ # If shape is not cached, create a new causal mask and cache it
82
+ input_shape = (batch_size, query_length)
83
+ past_key_values_length = key_value_length - query_length
84
+
85
+ # create causal mask
86
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
87
+ causal_4d_mask = None
88
+ if input_shape[-1] > 1 or self.sliding_window is not None:
89
+ causal_4d_mask = self._make_causal_mask(
90
+ input_shape,
91
+ dtype,
92
+ device=device,
93
+ past_key_values_length=past_key_values_length,
94
+ sliding_window=self.sliding_window,
95
+ )
96
+
97
+ return causal_4d_mask
98
+
99
+ def to_4d(
100
+ self,
101
+ attention_mask_2d: torch.Tensor,
102
+ query_length: int,
103
+ dtype: torch.dtype,
104
+ key_value_length: Optional[int] = None,
105
+ ) -> torch.Tensor:
106
+ """
107
+ Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
108
+ key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
109
+ causal, a causal mask will be added.
110
+ """
111
+ input_shape = (attention_mask_2d.shape[0], query_length)
112
+
113
+ # create causal mask
114
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
115
+ causal_4d_mask = None
116
+ if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
117
+ if key_value_length is None:
118
+ raise ValueError(
119
+ "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
120
+ )
121
+
122
+ past_key_values_length = key_value_length - query_length
123
+ causal_4d_mask = self._make_causal_mask(
124
+ input_shape,
125
+ dtype,
126
+ device=attention_mask_2d.device,
127
+ past_key_values_length=past_key_values_length,
128
+ sliding_window=self.sliding_window,
129
+ )
130
+ elif self.sliding_window is not None:
131
+ raise NotImplementedError("Sliding window is currently only implemented for causal masking")
132
+
133
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
134
+ expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
135
+ attention_mask_2d.device
136
+ )
137
+
138
+ if causal_4d_mask is not None:
139
+ expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)
140
+
141
+ # expanded_attn_mask + causal_4d_mask can cause some overflow
142
+ expanded_4d_mask = expanded_attn_mask
143
+
144
+ return expanded_4d_mask
145
+
146
+ @staticmethod
147
+ def _make_causal_mask(
148
+ input_ids_shape: torch.Size,
149
+ dtype: torch.dtype,
150
+ device: torch.device,
151
+ past_key_values_length: int = 0,
152
+ sliding_window: Optional[int] = None,
153
+ ):
154
+ """
155
+ Make causal mask used for bi-directional self-attention.
156
+ """
157
+ bsz, tgt_len = input_ids_shape
158
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
159
+ mask_cond = torch.arange(mask.size(-1), device=device)
160
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
161
+
162
+ mask = mask.to(dtype)
163
+
164
+ if past_key_values_length > 0:
165
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
166
+
167
+ # add lower triangular sliding window mask if necessary
168
+ if sliding_window is not None:
169
+ diagonal = past_key_values_length - sliding_window - 1
170
+
171
+ context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
172
+ # Recent changes in PyTorch prevent mutations on tensors converted with aten::_to_copy
173
+ # See https://github.com/pytorch/pytorch/issues/127571
174
+ if is_torchdynamo_compiling():
175
+ mask = mask.clone()
176
+ mask.masked_fill_(context_mask, torch.finfo(dtype).min)
177
+
178
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
179
+
180
+ @staticmethod
181
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
182
+ """
183
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
184
+ """
185
+ bsz, src_len = mask.size()
186
+ tgt_len = tgt_len if tgt_len is not None else src_len
187
+
188
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
189
+
190
+ inverted_mask = 1.0 - expanded_mask
191
+
192
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
193
+
194
+ @staticmethod
195
+ def _unmask_unattended(
196
+ expanded_mask: torch.FloatTensor,
197
+ min_dtype: float,
198
+ ):
199
+ # fmt: off
200
+ """
201
+ Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
202
+ using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
203
+ Details: https://github.com/pytorch/pytorch/issues/110213
204
+
205
+ `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
206
+ `attention_mask` is [bsz, src_seq_len].
207
+
208
+ The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.
209
+
210
+ For example, if `expanded_mask` is (e.g. here left-padding case)
211
+ ```
212
+ [[[[0, 0, 0],
213
+ [0, 0, 0],
214
+ [0, 0, 1]]],
215
+ [[[1, 0, 0],
216
+ [1, 1, 0],
217
+ [1, 1, 1]]],
218
+ [[[0, 0, 0],
219
+ [0, 1, 0],
220
+ [0, 1, 1]]]]
221
+ ```
222
+ then the modified `expanded_mask` will be
223
+ ```
224
+ [[[[1, 1, 1], <-- modified
225
+ [1, 1, 1], <-- modified
226
+ [0, 0, 1]]],
227
+ [[[1, 0, 0],
228
+ [1, 1, 0],
229
+ [1, 1, 1]]],
230
+ [[[1, 1, 1], <-- modified
231
+ [0, 1, 0],
232
+ [0, 1, 1]]]]
233
+ ```
234
+ """
235
+ # fmt: on
236
+ if expanded_mask.dtype == torch.bool:
237
+ raise ValueError(
238
+ "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
239
+ )
240
+
241
+ return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))
242
+
243
+ @staticmethod
244
+ def _ignore_causal_mask_sdpa(
245
+ attention_mask: Optional[torch.Tensor],
246
+ inputs_embeds: torch.Tensor,
247
+ past_key_values_length: int,
248
+ sliding_window: Optional[int] = None,
249
+ is_training: bool = False,
250
+ ) -> bool:
251
+ """
252
+ Detects whether the optional user-specified attention_mask & the automatically created causal mask can be
253
+ ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument.
254
+
255
+ In case no token is masked in the `attention_mask` argument, if `query_length == 1` or
256
+ `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks,
257
+ allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is
258
+ passed).
259
+ """
260
+
261
+ _, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
262
+ key_value_length = query_length + past_key_values_length
263
+
264
+ is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy) or is_torchdynamo_compiling()
265
+
266
+ ignore_causal_mask = False
267
+
268
+ if attention_mask is None:
269
+ # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input
270
+ # shape, thus SDPA's `is_causal` argument is rightfully updated
271
+ # (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using
272
+ # `torch.export` or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is
273
+ # hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True`
274
+ # which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108).
275
+ # Thus, we only set `ignore_causal_mask = True` if the model is set to training.
276
+ #
277
+ # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal`
278
+ # ("TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor").
279
+ if (
280
+ (is_training or not is_tracing)
281
+ and (query_length == 1 or key_value_length == query_length)
282
+ and (sliding_window is None or key_value_length < sliding_window)
283
+ ):
284
+ ignore_causal_mask = True
285
+ elif sliding_window is None or key_value_length < sliding_window:
286
+ if len(attention_mask.shape) == 4:
287
+ return False
288
+ elif not is_tracing and torch.all(attention_mask == 1):
289
+ if query_length == 1 or key_value_length == query_length:
290
+ # For query_length == 1, causal attention and bi-directional attention are the same.
291
+ ignore_causal_mask = True
292
+
293
+ # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore
294
+ # the attention mask, as SDPA causal mask generation may be wrong. We will set `is_causal=False` in
295
+ # SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
296
+ # Reference: https://github.com/pytorch/pytorch/issues/108108
297
+ # TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3.
298
+
299
+ return ignore_causal_mask
300
+
301
+
302
def _prepare_4d_causal_attention_mask(
    attention_mask: Optional[torch.Tensor],
    input_shape: Union[torch.Size, Tuple, List],
    inputs_embeds: torch.Tensor,
    past_key_values_length: int,
    sliding_window: Optional[int] = None,
):
    """
    Build a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask
    of shape `(batch_size, key_value_length)`.

    Args:
        attention_mask (`torch.Tensor` or `None`):
            A 2D attention mask of shape `(batch_size, key_value_length)`
        input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
            The input shape should be a tuple that defines `(batch_size, query_length)`.
        inputs_embeds (`torch.Tensor`):
            The embedded inputs as a torch Tensor.
        past_key_values_length (`int`):
            The length of the key value cache.
        sliding_window (`int`, *optional*):
            If the model uses windowed attention, a sliding window should be passed.
    """
    converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
    key_value_length = input_shape[-1] + past_key_values_length

    # A 2D mask is expanded to 4D and passed through the layers.
    if attention_mask is not None and len(attention_mask.shape) == 2:
        return converter.to_4d(
            attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype
        )

    # A 4D mask with the correct shape is inverted and filled with the dtype minimum.
    if attention_mask is not None and len(attention_mask.shape) == 4:
        expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
        if tuple(attention_mask.shape) != expected_shape:
            raise ValueError(
                f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
            )
        inverted_mask = 1.0 - attention_mask
        return inverted_mask.masked_fill(
            inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
        )

    # No usable mask was provided: synthesize a plain causal mask.
    return converter.to_causal_4d(
        input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
    )
352
+
353
+
354
# Adapted from _prepare_4d_causal_attention_mask
def _prepare_4d_causal_attention_mask_for_sdpa(
    attention_mask: Optional[torch.Tensor],
    input_shape: Union[torch.Size, Tuple, List],
    inputs_embeds: torch.Tensor,
    past_key_values_length: int,
    sliding_window: Optional[int] = None,
):
    """
    Prepare the `attn_mask` argument for `torch.nn.functional.scaled_dot_product_attention`.

    When no token is masked in `attention_mask`, the mask is set to `None` for the cases
    `query_length == 1` and `key_value_length == query_length`, relying instead on SDPA's `is_causal`
    argument so the flash-attention kernel can be dispatched (it cannot be used with a custom
    `attn_mask`).
    """
    converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
    key_value_length = input_shape[-1] + past_key_values_length

    # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True cannot capture the control
    # flow `is_causal=attention_mask is None and q_len > 1` used as an SDPA argument, so SDPA's
    # `attn_mask` argument is always used while tracing.
    # TODO: For dynamo, rather use a check on fullgraph=True once this is possible
    # (https://github.com/pytorch/pytorch/pull/120400).
    is_tracing = (
        torch.jit.is_tracing()
        or isinstance(inputs_embeds, torch.fx.Proxy)
        or is_torchdynamo_compiling()
    )

    if AttentionMaskConverter._ignore_causal_mask_sdpa(
        attention_mask=attention_mask,
        inputs_embeds=inputs_embeds,
        past_key_values_length=past_key_values_length,
        sliding_window=sliding_window,
    ):
        return None

    if attention_mask is None:
        return converter.to_causal_4d(
            input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
        )

    if attention_mask.dim() == 4:
        expanded_4d_mask = attention_mask
    else:
        expanded_4d_mask = converter.to_4d(
            attention_mask,
            input_shape[-1],
            dtype=inputs_embeds.dtype,
            key_value_length=key_value_length,
        )

    # Attend to all tokens in fully-masked rows of the causal mask (e.g. the relevant first rows when
    # left padding is used). Required by the F.scaled_dot_product_attention memory-efficient path.
    # Details: https://github.com/pytorch/pytorch/issues/110213
    if not is_tracing and expanded_4d_mask.device.type == "cuda":
        expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
            expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min
        )

    return expanded_4d_mask
411
+
412
+
413
def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expand a 2D mask of shape `(batch_size, key_value_length)` into a non-causal 4D mask of shape
    `(batch_size, 1, query_length, key_value_length)`.

    Args:
        mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)`
        dtype (`torch.dtype`):
            The torch dtype the created mask shall have.
        tgt_len (`int`):
            The target length or query length the created mask shall have.
    """
    expanded = AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
    return expanded
427
+
428
+
429
def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expand a 2D mask of shape `(batch_size, key_value_length)` into a non-causal 4D mask of shape
    `(batch_size, 1, query_length, key_value_length)`, or return `None` when the mask is all-ones so
    SDPA can take its unmasked fast path.

    Args:
        mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)`
        dtype (`torch.dtype`):
            The torch dtype the created mask shall have.
        tgt_len (`int`):
            The target length or query length the created mask shall have.
    """
    _, key_value_length = mask.shape
    if tgt_len is None:
        tgt_len = key_value_length

    is_tracing = torch.jit.is_tracing() or isinstance(mask, torch.fx.Proxy) or is_torchdynamo_compiling()

    # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True cannot capture
    # data-dependent control flow, so the mask is always expanded while tracing.
    if is_tracing or not torch.all(mask == 1):
        return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
    return None
452
+
453
+
454
def _create_4d_causal_attention_mask(
    input_shape: Union[torch.Size, Tuple, List],
    dtype: torch.dtype,
    device: torch.device,
    past_key_values_length: int = 0,
    sliding_window: Optional[int] = None,
) -> Optional[torch.Tensor]:
    """
    Build a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`.

    Args:
        input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
            The input shape should be a tuple that defines `(batch_size, query_length)`.
        dtype (`torch.dtype`):
            The torch dtype the created mask shall have.
        device (`int`):
            The torch device the created mask shall have.
        sliding_window (`int`, *optional*):
            If the model uses windowed attention, a sliding window should be passed.
    """
    converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
    key_value_length = past_key_values_length + input_shape[-1]
    return converter.to_causal_4d(
        input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device
    )