Add files using upload-large-folder tool
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -0
- .venv/lib/python3.11/site-packages/pkg_resources/__pycache__/__init__.cpython-311.pyc +3 -0
- .venv/lib/python3.11/site-packages/pkg_resources/_vendor/pyparsing/__pycache__/core.cpython-311.pyc +3 -0
- .venv/lib/python3.11/site-packages/transformers/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/transformers/activations.py +239 -0
- .venv/lib/python3.11/site-packages/transformers/activations_tf.py +147 -0
- .venv/lib/python3.11/site-packages/transformers/audio_utils.py +1123 -0
- .venv/lib/python3.11/site-packages/transformers/cache_utils.py +0 -0
- .venv/lib/python3.11/site-packages/transformers/configuration_utils.py +1187 -0
- .venv/lib/python3.11/site-packages/transformers/convert_graph_to_onnx.py +551 -0
- .venv/lib/python3.11/site-packages/transformers/convert_pytorch_checkpoint_to_tf2.py +446 -0
- .venv/lib/python3.11/site-packages/transformers/convert_slow_tokenizer.py +1642 -0
- .venv/lib/python3.11/site-packages/transformers/convert_slow_tokenizers_checkpoints_to_fast.py +130 -0
- .venv/lib/python3.11/site-packages/transformers/convert_tf_hub_seq_to_seq_bert_to_pytorch.py +87 -0
- .venv/lib/python3.11/site-packages/transformers/debug_utils.py +346 -0
- .venv/lib/python3.11/site-packages/transformers/dependency_versions_check.py +63 -0
- .venv/lib/python3.11/site-packages/transformers/dependency_versions_table.py +102 -0
- .venv/lib/python3.11/site-packages/transformers/dynamic_module_utils.py +685 -0
- .venv/lib/python3.11/site-packages/transformers/feature_extraction_sequence_utils.py +372 -0
- .venv/lib/python3.11/site-packages/transformers/hf_argparser.py +437 -0
- .venv/lib/python3.11/site-packages/transformers/hyperparameter_search.py +141 -0
- .venv/lib/python3.11/site-packages/transformers/image_processing_base.py +559 -0
- .venv/lib/python3.11/site-packages/transformers/image_processing_utils.py +287 -0
- .venv/lib/python3.11/site-packages/transformers/image_processing_utils_fast.py +133 -0
- .venv/lib/python3.11/site-packages/transformers/image_transforms.py +860 -0
- .venv/lib/python3.11/site-packages/transformers/image_utils.py +871 -0
- .venv/lib/python3.11/site-packages/transformers/keras_callbacks.py +413 -0
- .venv/lib/python3.11/site-packages/transformers/kernels/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/transformers/kernels/deformable_detr/cuda/ms_deform_im2col_cuda.cuh +1327 -0
- .venv/lib/python3.11/site-packages/transformers/kernels/deformable_detr/ms_deform_attn.h +61 -0
- .venv/lib/python3.11/site-packages/transformers/kernels/deta/cpu/ms_deform_attn_cpu.cpp +40 -0
- .venv/lib/python3.11/site-packages/transformers/kernels/deta/cpu/ms_deform_attn_cpu.h +32 -0
- .venv/lib/python3.11/site-packages/transformers/kernels/deta/cuda/ms_deform_attn_cuda.cu +156 -0
- .venv/lib/python3.11/site-packages/transformers/kernels/deta/cuda/ms_deform_attn_cuda.cuh +1467 -0
- .venv/lib/python3.11/site-packages/transformers/kernels/deta/cuda/ms_deform_attn_cuda.h +29 -0
- .venv/lib/python3.11/site-packages/transformers/kernels/deta/cuda/ms_deform_im2col_cuda.cuh +1327 -0
- .venv/lib/python3.11/site-packages/transformers/kernels/deta/ms_deform_attn.h +61 -0
- .venv/lib/python3.11/site-packages/transformers/kernels/deta/vision.cpp +16 -0
- .venv/lib/python3.11/site-packages/transformers/kernels/falcon_mamba/__pycache__/selective_scan_with_ln_interface.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/kernels/rwkv/wkv_cuda.cu +187 -0
- .venv/lib/python3.11/site-packages/transformers/kernels/rwkv/wkv_cuda_bf16.cu +186 -0
- .venv/lib/python3.11/site-packages/transformers/kernels/rwkv/wkv_op.cpp +66 -0
- .venv/lib/python3.11/site-packages/transformers/kernels/yoso/common.h +10 -0
- .venv/lib/python3.11/site-packages/transformers/kernels/yoso/fast_lsh_cumulation.cu +588 -0
- .venv/lib/python3.11/site-packages/transformers/kernels/yoso/fast_lsh_cumulation.h +71 -0
- .venv/lib/python3.11/site-packages/transformers/kernels/yoso/fast_lsh_cumulation_cuda.cu +825 -0
- .venv/lib/python3.11/site-packages/transformers/kernels/yoso/fast_lsh_cumulation_cuda.h +157 -0
- .venv/lib/python3.11/site-packages/transformers/kernels/yoso/fast_lsh_cumulation_torch.cpp +128 -0
- .venv/lib/python3.11/site-packages/transformers/modelcard.py +908 -0
- .venv/lib/python3.11/site-packages/transformers/modeling_attn_mask_utils.py +481 -0
.gitattributes
CHANGED
|
@@ -419,3 +419,5 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/
|
|
| 419 |
.venv/lib/python3.11/site-packages/jiter/jiter.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
| 420 |
.venv/lib/python3.11/site-packages/idna/__pycache__/uts46data.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 421 |
.venv/lib/python3.11/site-packages/idna/__pycache__/idnadata.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 419 |
.venv/lib/python3.11/site-packages/jiter/jiter.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
| 420 |
.venv/lib/python3.11/site-packages/idna/__pycache__/uts46data.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 421 |
.venv/lib/python3.11/site-packages/idna/__pycache__/idnadata.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 422 |
+
.venv/lib/python3.11/site-packages/pkg_resources/_vendor/pyparsing/__pycache__/core.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 423 |
+
.venv/lib/python3.11/site-packages/pkg_resources/__pycache__/__init__.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
.venv/lib/python3.11/site-packages/pkg_resources/__pycache__/__init__.cpython-311.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ae8fd4ca816177bb2c6f471e0b6c7334eb9caa4704a71b0935e8abf1ca1a36d2
|
| 3 |
+
size 159566
|
.venv/lib/python3.11/site-packages/pkg_resources/_vendor/pyparsing/__pycache__/core.cpython-311.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c669d4b4c31b91773ae2bd09aa2fd0eb809698068c77850d9ca93283b9acc875
|
| 3 |
+
size 277639
|
.venv/lib/python3.11/site-packages/transformers/__init__.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
.venv/lib/python3.11/site-packages/transformers/activations.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
from collections import OrderedDict
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
from packaging import version
|
| 20 |
+
from torch import Tensor, nn
|
| 21 |
+
|
| 22 |
+
from .utils import logging
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
logger = logging.get_logger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class PytorchGELUTanh(nn.Module):
    """
    A fast C implementation of the tanh approximation of the GeLU activation function. See
    https://arxiv.org/abs/1606.08415.

    This implementation is equivalent to NewGELU and FastGELU but much faster. However, it is not an exact numerical
    match due to rounding errors.
    """

    def __init__(self):
        super().__init__()
        # The fused `approximate="tanh"` kernel only exists from torch 1.12.0 onwards.
        minimum_version = version.parse("1.12.0")
        if version.parse(torch.__version__) < minimum_version:
            raise ImportError(
                f"You are using torch=={torch.__version__}, but torch>=1.12.0 is required to use "
                "PytorchGELUTanh. Please upgrade torch."
            )

    def forward(self, input: Tensor) -> Tensor:
        # Delegate to torch's native tanh-approximated GELU.
        return nn.functional.gelu(input, approximate="tanh")
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class NewGELUActivation(nn.Module):
    """
    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
    the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
    """

    def forward(self, input: Tensor) -> Tensor:
        # tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
        inner = math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))
        return 0.5 * input * (1.0 + torch.tanh(inner))
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class GELUActivation(nn.Module):
    """
    Original Implementation of the GELU activation function in Google BERT repo when initially created. For
    information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
    torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
    Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
    """

    def __init__(self, use_gelu_python: bool = False):
        super().__init__()
        # Use the pure-python erf formulation only when explicitly requested;
        # otherwise rely on the faster C implementation in nn.functional.
        self.act = self._gelu_python if use_gelu_python else nn.functional.gelu

    def _gelu_python(self, input: Tensor) -> Tensor:
        # Exact GELU: x * Phi(x) where Phi is the standard normal CDF.
        return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))

    def forward(self, input: Tensor) -> Tensor:
        return self.act(input)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class FastGELUActivation(nn.Module):
    """
    Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
    """

    def forward(self, input: Tensor) -> Tensor:
        # 0.7978845608 ~= sqrt(2/pi); polynomial form avoids torch.pow.
        scaled = input * 0.7978845608 * (1.0 + 0.044715 * input * input)
        return 0.5 * input * (1.0 + torch.tanh(scaled))
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class QuickGELUActivation(nn.Module):
    """
    Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
    """

    def forward(self, input: Tensor) -> Tensor:
        # Sigmoid-based approximation: x * sigmoid(1.702 * x).
        gate = torch.sigmoid(input * 1.702)
        return input * gate
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class ClippedGELUActivation(nn.Module):
    """
    Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as
    it allows mapping negatives values in the GeLU spectrum. For more information on this trick, please refer to
    https://arxiv.org/abs/2004.09602.

    Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
    initially created.

    For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 +
    torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://arxiv.org/abs/1606.08415
    """

    def __init__(self, min: float, max: float):
        # Validate the clipping window before initializing the module.
        if min > max:
            raise ValueError(f"min should be < max (got min: {min}, max: {max})")

        super().__init__()
        self.min = min
        self.max = max

    def forward(self, x: Tensor) -> Tensor:
        # `gelu` is the module-level GELU singleton defined at the bottom of this file.
        activated = gelu(x)
        return torch.clip(activated, self.min, self.max)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class AccurateGELUActivation(nn.Module):
    """
    Applies GELU approximation that is faster than default and more accurate than QuickGELU. See:
    https://github.com/hendrycks/GELUs

    Implemented along with MEGA (Moving Average Equipped Gated Attention)
    """

    def __init__(self):
        super().__init__()
        # sqrt(2/pi) is hoisted out of forward() so it is computed once.
        self.precomputed_constant = math.sqrt(2 / math.pi)

    def forward(self, input: Tensor) -> Tensor:
        cubic_term = input + 0.044715 * torch.pow(input, 3)
        return 0.5 * input * (1 + torch.tanh(self.precomputed_constant * cubic_term))
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class MishActivation(nn.Module):
    """
    See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
    visit the official repository for the paper: https://github.com/digantamisra98/Mish
    """

    def __init__(self):
        super().__init__()
        # torch ships a native C++ mish from 1.9.0; older versions fall back to python.
        if version.parse(torch.__version__) >= version.parse("1.9.0"):
            self.act = nn.functional.mish
        else:
            self.act = self._mish_python

    def _mish_python(self, input: Tensor) -> Tensor:
        # mish(x) = x * tanh(softplus(x))
        return input * torch.tanh(nn.functional.softplus(input))

    def forward(self, input: Tensor) -> Tensor:
        return self.act(input)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class LinearActivation(nn.Module):
    """
    Applies the linear activation function, i.e. forwarding input directly to output.
    """

    def forward(self, input: Tensor) -> Tensor:
        # Identity: no transformation is applied.
        return input
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class LaplaceActivation(nn.Module):
    """
    Applies elementwise activation based on Laplace function, introduced in MEGA as an attention activation. See
    https://arxiv.org/abs/2209.10655

    Inspired by squared relu, but with bounded range and gradient for better stability
    """

    def forward(self, input, mu=0.707107, sigma=0.282095):
        # Normalize by mu/sigma, then evaluate the Gaussian CDF via erf.
        normalized = (input - mu).div(sigma * math.sqrt(2.0))
        return 0.5 * (1.0 + torch.erf(normalized))
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class ReLUSquaredActivation(nn.Module):
    """
    Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
    """

    def forward(self, input):
        # relu^2(x) = max(x, 0)^2
        return torch.square(nn.functional.relu(input))
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
class ClassInstantier(OrderedDict):
    """
    Ordered mapping whose lookup instantiates the stored class. Values may be either a
    bare class or a (class, kwargs) tuple; each access returns a *fresh* instance.
    """

    def __getitem__(self, key):
        entry = super().__getitem__(key)
        if isinstance(entry, tuple):
            cls, kwargs = entry
        else:
            cls, kwargs = entry, {}
        return cls(**kwargs)
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
# Registry mapping activation-name strings to the implementing class. Entries may be
# a bare class or a (class, kwargs) tuple for constructors that require arguments.
ACT2CLS = {
    "gelu": GELUActivation,
    "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
    "gelu_fast": FastGELUActivation,
    "gelu_new": NewGELUActivation,
    "gelu_python": (GELUActivation, {"use_gelu_python": True}),
    "gelu_pytorch_tanh": PytorchGELUTanh,
    "gelu_accurate": AccurateGELUActivation,
    "laplace": LaplaceActivation,
    "leaky_relu": nn.LeakyReLU,
    "linear": LinearActivation,
    "mish": MishActivation,
    "quick_gelu": QuickGELUActivation,
    "relu": nn.ReLU,
    "relu2": ReLUSquaredActivation,
    "relu6": nn.ReLU6,
    "sigmoid": nn.Sigmoid,
    "silu": nn.SiLU,
    "swish": nn.SiLU,
    "tanh": nn.Tanh,
}
# Dict-like view over ACT2CLS that instantiates the mapped class on every lookup,
# so `ACT2FN["relu"]` yields a fresh nn.Module instance each time.
ACT2FN = ClassInstantier(ACT2CLS)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def get_activation(activation_string):
    """
    Return a fresh instance of the activation registered under `activation_string`.

    Raises:
        KeyError: if the name is not present in the ACT2FN registry.
    """
    if activation_string not in ACT2FN:
        raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
    return ACT2FN[activation_string]
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
# For backwards compatibility with: from activations import gelu_python
# Module-level singleton instances so `from activations import gelu` (etc.) keeps working.
gelu_python = get_activation("gelu_python")
gelu_new = get_activation("gelu_new")
gelu = get_activation("gelu")
gelu_fast = get_activation("gelu_fast")
quick_gelu = get_activation("quick_gelu")
silu = get_activation("silu")
mish = get_activation("mish")
linear_act = get_activation("linear")
|
.venv/lib/python3.11/site-packages/transformers/activations_tf.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
|
| 17 |
+
import tensorflow as tf
|
| 18 |
+
from packaging.version import parse
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# Prefer the standalone `tf_keras` package (which keeps the Keras 2 API) when it is
# installed; otherwise fall back to the `keras` package bundled with TensorFlow.
try:
    import tf_keras as keras
except (ModuleNotFoundError, ImportError):
    import keras

# Keras 3 is API-incompatible with this module; fail fast with an actionable message.
if parse(keras.__version__).major > 2:
    raise ValueError(
        "Your currently installed version of Keras is Keras 3, but this is not yet supported in "
        "Transformers. Please install the backwards-compatible tf-keras package with "
        "`pip install tf-keras`."
    )
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _gelu(x):
    """
    Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
    initially created. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) Also see
    https://arxiv.org/abs/1606.08415
    """
    x = tf.convert_to_tensor(x)
    # Exact GELU: x * Phi(x), with Phi the standard normal CDF expressed via erf.
    sqrt_two = tf.cast(tf.sqrt(2.0), x.dtype)
    cdf = 0.5 * (1.0 + tf.math.erf(x / sqrt_two))
    return x * cdf
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _gelu_new(x):
    """
    Gaussian Error Linear Unit. This is a smoother version of the GELU. Original paper:
    https://arxiv.org/abs/1606.08415

    Args:
        x: float Tensor to perform activation

    Returns:
        `x` with the GELU activation applied.
    """
    x = tf.convert_to_tensor(x)
    pi = tf.cast(math.pi, x.dtype)
    coeff = tf.cast(0.044715, x.dtype)
    # tanh approximation of the normal CDF.
    inner = tf.sqrt(2.0 / pi) * (x + coeff * tf.pow(x, 3))
    cdf = 0.5 * (1.0 + tf.tanh(inner))
    return x * cdf
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def mish(x):
    """Mish activation: x * tanh(softplus(x)). See https://arxiv.org/abs/1908.08681."""
    x = tf.convert_to_tensor(x)
    gate = tf.tanh(tf.math.softplus(x))
    return x * gate
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def gelu_fast(x):
    """Fast tanh-based GELU approximation (see https://github.com/hendrycks/GELUs)."""
    x = tf.convert_to_tensor(x)
    # coeff2 ~= sqrt(2/pi)
    coeff1 = tf.cast(0.044715, x.dtype)
    coeff2 = tf.cast(0.7978845608, x.dtype)
    inner = x * coeff2 * (1.0 + coeff1 * x * x)
    return 0.5 * x * (1.0 + tf.tanh(inner))
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def quick_gelu(x):
    """Sigmoid-based GELU approximation: x * sigmoid(1.702 * x)."""
    x = tf.convert_to_tensor(x)
    coeff = tf.cast(1.702, x.dtype)
    gate = tf.math.sigmoid(coeff * x)
    return x * gate
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def gelu_10(x):
    """
    Clip the range of possible GeLU outputs between [-10, 10]. This is especially useful for quantization purpose, as
    it allows mapping 2 negatives values in the GeLU spectrum. For more information on this trick, please refer to
    https://arxiv.org/abs/2004.09602

    Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
    initially created. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) Also see
    https://arxiv.org/abs/1606.08415 :param x: :return:
    """
    # Apply the exact GELU, then clamp into the quantization-friendly window.
    return tf.clip_by_value(_gelu(x), -10, 10)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def glu(x, axis=-1):
    """
    Gated Linear Unit. Implementation as defined in the original paper (see https://arxiv.org/abs/1612.08083), where
    the input `x` is split in two halves across a dimension (`axis`), A and B, returning A * sigmoid(B).

    Args:
        `x`: float Tensor to perform activation
        `axis`: dimension across which `x` be split in half

    Returns:
        `x` with the GLU activation applied (with its size halved across the dimension `axis`).
    """
    value_half, gate_half = tf.split(x, 2, axis=axis)
    return value_half * tf.math.sigmoid(gate_half)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
# keras.activations.gelu (with the `approximate` kwarg) is only available from TF 2.4;
# use the native implementation when possible, otherwise the local fallbacks above.
if parse(tf.version.VERSION) >= parse("2.4"):

    def approximate_gelu_wrap(x):
        # tanh-approximated GELU via the native keras op.
        return keras.activations.gelu(x, approximate=True)

    gelu = keras.activations.gelu
    gelu_new = approximate_gelu_wrap
else:
    gelu = _gelu
    gelu_new = _gelu_new
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
# Registry mapping activation-name strings to TF activation callables. Unlike the
# torch version, values here are plain functions, not module classes.
ACT2FN = {
    "gelu": gelu,
    "gelu_10": gelu_10,
    "gelu_fast": gelu_fast,
    "gelu_new": gelu_new,
    "glu": glu,
    "mish": mish,
    "quick_gelu": quick_gelu,
    "relu": keras.activations.relu,
    "sigmoid": keras.activations.sigmoid,
    "silu": keras.activations.swish,
    "swish": keras.activations.swish,
    "tanh": keras.activations.tanh,
}
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def get_tf_activation(activation_string):
    """
    Return the TF activation callable registered under `activation_string`.

    Raises:
        KeyError: if the name is not present in the ACT2FN registry.
    """
    if activation_string not in ACT2FN:
        raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
    return ACT2FN[activation_string]
|
.venv/lib/python3.11/site-packages/transformers/audio_utils.py
ADDED
|
@@ -0,0 +1,1123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The HuggingFace Inc. team and the librosa & torchaudio authors.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""
|
| 16 |
+
Audio processing functions to extract features from audio waveforms. This code is pure numpy to support all frameworks
|
| 17 |
+
and remove unnecessary dependencies.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import warnings
|
| 21 |
+
from typing import List, Optional, Tuple, Union
|
| 22 |
+
|
| 23 |
+
import numpy as np
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]:
    """
    Convert frequency from hertz to mels.

    Args:
        freq (`float` or `np.ndarray`):
            The frequency, or multiple frequencies, in hertz (Hz).
        mel_scale (`str`, *optional*, defaults to `"htk"`):
            The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.

    Returns:
        `float` or `np.ndarray`: The frequencies on the mel scale.
    """
    if mel_scale not in ("htk", "kaldi", "slaney"):
        raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".')

    if mel_scale == "htk":
        return 2595.0 * np.log10(1.0 + (freq / 700.0))
    if mel_scale == "kaldi":
        return 1127.0 * np.log(1.0 + (freq / 700.0))

    # Slaney scale: linear below 1 kHz, logarithmic above.
    break_hertz = 1000.0
    break_mel = 15.0
    log_scale = 27.0 / np.log(6.4)

    mels = 3.0 * freq / 200.0
    if isinstance(freq, np.ndarray):
        above_break = freq >= break_hertz
        mels[above_break] = break_mel + np.log(freq[above_break] / break_hertz) * log_scale
    elif freq >= break_hertz:
        mels = break_mel + np.log(freq / break_hertz) * log_scale

    return mels
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def mel_to_hertz(mels: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]:
    """
    Convert frequency from mels to hertz.

    Args:
        mels (`float` or `np.ndarray`):
            The frequency, or multiple frequencies, in mels.
        mel_scale (`str`, *optional*, defaults to `"htk"`):
            The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.

    Returns:
        `float` or `np.ndarray`: The frequencies in hertz.
    """
    if mel_scale not in ("htk", "kaldi", "slaney"):
        raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".')

    if mel_scale == "htk":
        return 700.0 * (np.power(10, mels / 2595.0) - 1.0)
    if mel_scale == "kaldi":
        return 700.0 * (np.exp(mels / 1127.0) - 1.0)

    # Slaney scale: linear below 15 mels (1 kHz), exponential above.
    break_hertz = 1000.0
    break_mel = 15.0
    log_scale = np.log(6.4) / 27.0

    freq = 200.0 * mels / 3.0
    if isinstance(mels, np.ndarray):
        above_break = mels >= break_mel
        freq[above_break] = break_hertz * np.exp(log_scale * (mels[above_break] - break_mel))
    elif mels >= break_mel:
        freq = break_hertz * np.exp(log_scale * (mels - break_mel))

    return freq
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def hertz_to_octave(
    freq: Union[float, np.ndarray], tuning: Optional[float] = 0.0, bins_per_octave: Optional[int] = 12
):
    """
    Convert frequency from hertz to fractional octave numbers.
    Adapted from *librosa*.

    Args:
        freq (`float` or `np.ndarray`):
            The frequency, or multiple frequencies, in hertz (Hz).
        tuning (`float`, defaults to `0.`):
            Tuning deviation from the Stuttgart pitch (A440) in (fractional) bins per octave.
        bins_per_octave (`int`, defaults to `12`):
            Number of bins per octave.

    Returns:
        `float` or `np.ndarray`: The frequencies on the octave scale.
    """
    # A440 adjusted by the tuning deviation, expressed as a fraction of an octave.
    reference_pitch = 440.0 * 2.0 ** (tuning / bins_per_octave)
    # Octave 0 is anchored four octaves below the reference pitch (reference / 16).
    return np.log2(freq / (float(reference_pitch) / 16))
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def _create_triangular_filter_bank(fft_freqs: np.ndarray, filter_freqs: np.ndarray) -> np.ndarray:
|
| 122 |
+
"""
|
| 123 |
+
Creates a triangular filter bank.
|
| 124 |
+
|
| 125 |
+
Adapted from *torchaudio* and *librosa*.
|
| 126 |
+
|
| 127 |
+
Args:
|
| 128 |
+
fft_freqs (`np.ndarray` of shape `(num_frequency_bins,)`):
|
| 129 |
+
Discrete frequencies of the FFT bins in Hz.
|
| 130 |
+
filter_freqs (`np.ndarray` of shape `(num_mel_filters,)`):
|
| 131 |
+
Center frequencies of the triangular filters to create, in Hz.
|
| 132 |
+
|
| 133 |
+
Returns:
|
| 134 |
+
`np.ndarray` of shape `(num_frequency_bins, num_mel_filters)`
|
| 135 |
+
"""
|
| 136 |
+
filter_diff = np.diff(filter_freqs)
|
| 137 |
+
slopes = np.expand_dims(filter_freqs, 0) - np.expand_dims(fft_freqs, 1)
|
| 138 |
+
down_slopes = -slopes[:, :-2] / filter_diff[:-1]
|
| 139 |
+
up_slopes = slopes[:, 2:] / filter_diff[1:]
|
| 140 |
+
return np.maximum(np.zeros(1), np.minimum(down_slopes, up_slopes))
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def chroma_filter_bank(
    num_frequency_bins: int,
    num_chroma: int,
    sampling_rate: int,
    tuning: float = 0.0,
    power: Optional[float] = 2.0,
    weighting_parameters: Optional[Tuple[float, float]] = (5.0, 2),
    start_at_c_chroma: Optional[bool] = True,
):
    """
    Creates a chroma filter bank, i.e a linear transformation to project spectrogram bins onto chroma bins.

    Adapted from *librosa*.

    Args:
        num_frequency_bins (`int`):
            Number of frequencies used to compute the spectrogram (should be the same as in `stft`).
        num_chroma (`int`):
            Number of chroma bins (i.e pitch classes).
        sampling_rate (`float`):
            Sample rate of the audio waveform.
        tuning (`float`):
            Tuning deviation from A440 in fractions of a chroma bin.
        power (`float`, *optional*, defaults to 2.0):
            If 2.0, normalizes each column with their L2 norm. If 1.0, normalizes each column with their L1 norm.
        weighting_parameters (`Tuple[float, float]`, *optional*, defaults to `(5., 2.)`):
            If specified, apply a Gaussian weighting parameterized by the first element of the tuple being the center and
            the second element being the Gaussian half-width.
        start_at_c_chroma (`float`, *optional*, defaults to `True`):
            If True, the filter bank will start at the 'C' pitch class. Otherwise, it will start at 'A'.
    Returns:
        `np.ndarray` of shape `(num_frequency_bins, num_chroma)`
    """
    # Get the FFT bins, not counting the DC component
    frequencies = np.linspace(0, sampling_rate, num_frequency_bins, endpoint=False)[1:]

    # Fractional chroma bin index of every FFT bin.
    freq_bins = num_chroma * hertz_to_octave(frequencies, tuning=tuning, bins_per_octave=num_chroma)

    # make up a value for the 0 Hz bin = 1.5 octaves below bin 1
    # (so chroma is 50% rotated from bin 1, and bin width is broad)
    freq_bins = np.concatenate(([freq_bins[0] - 1.5 * num_chroma], freq_bins))

    bins_width = np.concatenate((np.maximum(freq_bins[1:] - freq_bins[:-1], 1.0), [1]))

    chroma_filters = np.subtract.outer(freq_bins, np.arange(0, num_chroma, dtype="d")).T

    num_chroma2 = np.round(float(num_chroma) / 2)

    # Project into range -num_chroma/2 .. num_chroma/2
    # add on fixed offset of 10*num_chroma to ensure all values passed to
    # rem are positive
    chroma_filters = np.remainder(chroma_filters + num_chroma2 + 10 * num_chroma, num_chroma) - num_chroma2

    # Gaussian bumps - 2*D to make them narrower
    chroma_filters = np.exp(-0.5 * (2 * chroma_filters / np.tile(bins_width, (num_chroma, 1))) ** 2)

    # normalize each column with its Lp norm (p = `power`)
    if power is not None:
        chroma_filters = chroma_filters / np.sum(chroma_filters**power, axis=0, keepdims=True) ** (1.0 / power)

    # Maybe apply scaling for fft bins
    if weighting_parameters is not None:
        center, half_width = weighting_parameters
        chroma_filters *= np.tile(
            np.exp(-0.5 * (((freq_bins / num_chroma - center) / half_width) ** 2)),
            (num_chroma, 1),
        )

    if start_at_c_chroma:
        # Rotate so row 0 corresponds to 'C' (three semitones below 'A').
        chroma_filters = np.roll(chroma_filters, -3 * (num_chroma // 12), axis=0)

    # remove aliasing columns, copy to ensure row-contiguity
    return np.ascontiguousarray(chroma_filters[:, : int(1 + num_frequency_bins / 2)])
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def mel_filter_bank(
    num_frequency_bins: int,
    num_mel_filters: int,
    min_frequency: float,
    max_frequency: float,
    sampling_rate: int,
    norm: Optional[str] = None,
    mel_scale: str = "htk",
    triangularize_in_mel_space: bool = False,
) -> np.ndarray:
    """
    Creates a frequency bin conversion matrix used to obtain a mel spectrogram. This is called a *mel filter bank*, and
    various implementation exist, which differ in the number of filters, the shape of the filters, the way the filters
    are spaced, the bandwidth of the filters, and the manner in which the spectrum is warped. The goal of these
    features is to approximate the non-linear human perception of the variation in pitch with respect to the frequency.

    Different banks of mel filters were introduced in the literature. The following variations are supported:

    - MFCC FB-20: introduced in 1980 by Davis and Mermelstein, it assumes a sampling frequency of 10 kHz and a speech
      bandwidth of `[0, 4600]` Hz.
    - MFCC FB-24 HTK: from the Cambridge HMM Toolkit (HTK) (1995) uses a filter bank of 24 filters for a speech
      bandwidth of `[0, 8000]` Hz. This assumes sampling rate ≥ 16 kHz.
    - MFCC FB-40: from the Auditory Toolbox for MATLAB written by Slaney in 1998, assumes a sampling rate of 16 kHz and
      speech bandwidth of `[133, 6854]` Hz. This version also includes area normalization.
    - HFCC-E FB-29 (Human Factor Cepstral Coefficients) of Skowronski and Harris (2004), assumes a sampling rate of
      12.5 kHz and speech bandwidth of `[0, 6250]` Hz.

    This code is adapted from *torchaudio* and *librosa*. Note that the default parameters of torchaudio's
    `melscale_fbanks` implement the `"htk"` filters while librosa uses the `"slaney"` implementation.

    Args:
        num_frequency_bins (`int`):
            Number of frequencies used to compute the spectrogram (should be the same as in `stft`).
        num_mel_filters (`int`):
            Number of mel filters to generate.
        min_frequency (`float`):
            Lowest frequency of interest in Hz.
        max_frequency (`float`):
            Highest frequency of interest in Hz. This should not exceed `sampling_rate / 2`.
        sampling_rate (`int`):
            Sample rate of the audio waveform.
        norm (`str`, *optional*):
            If `"slaney"`, divide the triangular mel weights by the width of the mel band (area normalization).
        mel_scale (`str`, *optional*, defaults to `"htk"`):
            The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
        triangularize_in_mel_space (`bool`, *optional*, defaults to `False`):
            If this option is enabled, the triangular filter is applied in mel space rather than frequency space. This
            should be set to `true` in order to get the same results as `torchaudio` when computing mel filters.

    Returns:
        `np.ndarray` of shape (`num_frequency_bins`, `num_mel_filters`): Triangular filter bank matrix. This is a
        projection matrix to go from a spectrogram to a mel spectrogram.
    """
    if norm is not None and norm != "slaney":
        raise ValueError('norm must be one of None or "slaney"')

    # center points of the triangular mel filters
    mel_min = hertz_to_mel(min_frequency, mel_scale=mel_scale)
    mel_max = hertz_to_mel(max_frequency, mel_scale=mel_scale)
    mel_freqs = np.linspace(mel_min, mel_max, num_mel_filters + 2)
    filter_freqs = mel_to_hertz(mel_freqs, mel_scale=mel_scale)

    if triangularize_in_mel_space:
        # frequencies of FFT bins in Hz, but filters triangularized in mel space
        fft_bin_width = sampling_rate / (num_frequency_bins * 2)
        fft_freqs = hertz_to_mel(fft_bin_width * np.arange(num_frequency_bins), mel_scale=mel_scale)
        filter_freqs = mel_freqs
    else:
        # frequencies of FFT bins in Hz
        fft_freqs = np.linspace(0, sampling_rate // 2, num_frequency_bins)

    mel_filters = _create_triangular_filter_bank(fft_freqs, filter_freqs)

    # validation above guarantees norm is either None or "slaney" at this point
    if norm == "slaney":
        # Slaney-style mel is scaled to be approx constant energy per channel
        enorm = 2.0 / (filter_freqs[2 : num_mel_filters + 2] - filter_freqs[:num_mel_filters])
        mel_filters *= np.expand_dims(enorm, 0)

    if (mel_filters.max(axis=0) == 0.0).any():
        warnings.warn(
            "At least one mel filter has all zero values. "
            f"The value for `num_mel_filters` ({num_mel_filters}) may be set too high. "
            f"Or, the value for `num_frequency_bins` ({num_frequency_bins}) may be set too low."
        )

    return mel_filters
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def optimal_fft_length(window_length: int) -> int:
    """
    Finds the best FFT input size for a given `window_length`. This function takes a given window length and, if not
    already a power of two, rounds it up to the next power of two.

    The FFT algorithm works fastest when the length of the input is a power of two, which may be larger than the size
    of the window or analysis frame. For example, if the window is 400 samples, using an FFT input size of 512 samples
    is more optimal than an FFT size of 400 samples. Using a larger FFT size does not affect the detected frequencies,
    it simply gives a higher frequency resolution (i.e. the frequency bins are smaller).
    """
    return 2 ** int(np.ceil(np.log2(window_length)))
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def window_function(
    window_length: int,
    name: str = "hann",
    periodic: bool = True,
    frame_length: Optional[int] = None,
    center: bool = True,
) -> np.ndarray:
    """
    Returns an array containing the specified window. This window is intended to be used with `stft`.

    The following window types are supported:

        - `"boxcar"`: a rectangular window
        - `"hamming"`: the Hamming window
        - `"hann"`: the Hann window
        - `"povey"`: the Povey window

    Args:
        window_length (`int`):
            The length of the window in samples.
        name (`str`, *optional*, defaults to `"hann"`):
            The name of the window function.
        periodic (`bool`, *optional*, defaults to `True`):
            Whether the window is periodic or symmetric.
        frame_length (`int`, *optional*):
            The length of the analysis frames in samples. Provide a value for `frame_length` if the window is smaller
            than the frame length, so that it will be zero-padded.
        center (`bool`, *optional*, defaults to `True`):
            Whether to center the window inside the FFT buffer. Only used when `frame_length` is provided.

    Returns:
        `np.ndarray` of shape `(window_length,)` or `(frame_length,)` containing the window.
    """
    # A periodic window is built one sample longer and then truncated, so that
    # overlapping frames tile seamlessly.
    num_samples = window_length + 1 if periodic else window_length

    if name == "boxcar":
        window = np.ones(num_samples)
    elif name in ("hamming", "hamming_window"):
        window = np.hamming(num_samples)
    elif name in ("hann", "hann_window"):
        window = np.hanning(num_samples)
    elif name == "povey":
        # Kaldi's Povey window is a Hann window raised to the power 0.85.
        window = np.power(np.hanning(num_samples), 0.85)
    else:
        raise ValueError(f"Unknown window function '{name}'")

    if periodic:
        window = window[:-1]

    if frame_length is None:
        return window

    if window_length > frame_length:
        raise ValueError(
            f"Length of the window ({window_length}) may not be larger than frame_length ({frame_length})"
        )

    # Zero-pad the window up to the analysis frame size.
    padded = np.zeros(frame_length)
    start = (frame_length - window_length) // 2 if center else 0
    padded[start : start + window_length] = window
    return padded
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
# TODO This method does not support batching yet as we are mainly focused on inference.
def spectrogram(
    waveform: np.ndarray,
    window: np.ndarray,
    frame_length: int,
    hop_length: int,
    fft_length: Optional[int] = None,
    power: Optional[float] = 1.0,
    center: bool = True,
    pad_mode: str = "reflect",
    onesided: bool = True,
    preemphasis: Optional[float] = None,
    mel_filters: Optional[np.ndarray] = None,
    mel_floor: float = 1e-10,
    log_mel: Optional[str] = None,
    reference: float = 1.0,
    min_value: float = 1e-10,
    db_range: Optional[float] = None,
    remove_dc_offset: Optional[bool] = None,
    dtype: np.dtype = np.float32,
) -> np.ndarray:
    """
    Calculates a spectrogram over one waveform using the Short-Time Fourier Transform.

    This function can create the following kinds of spectrograms:

      - amplitude spectrogram (`power = 1.0`)
      - power spectrogram (`power = 2.0`)
      - complex-valued spectrogram (`power = None`)
      - log spectrogram (use `log_mel` argument)
      - mel spectrogram (provide `mel_filters`)
      - log-mel spectrogram (provide `mel_filters` and `log_mel`)

    How this works:

      1. The input waveform is split into frames of size `frame_length` that are partially overlapping by `frame_length
         - hop_length` samples.
      2. Each frame is multiplied by the window and placed into a buffer of size `fft_length`.
      3. The DFT is taken of each windowed frame.
      4. The results are stacked into a spectrogram.

    We make a distinction between the following "blocks" of sample data, each of which may have a different lengths:

      - The analysis frame. This is the size of the time slices that the input waveform is split into.
      - The window. Each analysis frame is multiplied by the window to avoid spectral leakage.
      - The FFT input buffer. The length of this determines how many frequency bins are in the spectrogram.

    In this implementation, the window is assumed to be zero-padded to have the same size as the analysis frame. A
    padded window can be obtained from `window_function()`. The FFT input buffer may be larger than the analysis frame,
    typically the next power of two.

    Note: This function is not optimized for speed yet. It should be mostly compatible with `librosa.stft` and
    `torchaudio.functional.transforms.Spectrogram`, although it is more flexible due to the different ways spectrograms
    can be constructed.

    Args:
        waveform (`np.ndarray` of shape `(length,)`):
            The input waveform. This must be a single real-valued, mono waveform.
        window (`np.ndarray` of shape `(frame_length,)`):
            The windowing function to apply, including zero-padding if necessary. The actual window length may be
            shorter than `frame_length`, but we're assuming the array has already been zero-padded.
        frame_length (`int`):
            The length of the analysis frames in samples. With librosa this is always equal to `fft_length` but we also
            allow smaller sizes.
        hop_length (`int`):
            The stride between successive analysis frames in samples.
        fft_length (`int`, *optional*):
            The size of the FFT buffer in samples. This determines how many frequency bins the spectrogram will have.
            For optimal speed, this should be a power of two. If `None`, uses `frame_length`.
        power (`float`, *optional*, defaults to 1.0):
            If 1.0, returns the amplitude spectrogram. If 2.0, returns the power spectrogram. If `None`, returns
            complex numbers.
        center (`bool`, *optional*, defaults to `True`):
            Whether to pad the waveform so that frame `t` is centered around time `t * hop_length`. If `False`, frame
            `t` will start at time `t * hop_length`.
        pad_mode (`str`, *optional*, defaults to `"reflect"`):
            Padding mode used when `center` is `True`. Possible values are: `"constant"` (pad with zeros), `"edge"`
            (pad with edge values), `"reflect"` (pads with mirrored values).
        onesided (`bool`, *optional*, defaults to `True`):
            If True, only computes the positive frequencies and returns a spectrogram containing `fft_length // 2 + 1`
            frequency bins. If False, also computes the negative frequencies and returns `fft_length` frequency bins.
        preemphasis (`float`, *optional*)
            Coefficient for a low-pass filter that applies pre-emphasis before the DFT.
        mel_filters (`np.ndarray` of shape `(num_freq_bins, num_mel_filters)`, *optional*):
            The mel filter bank. If supplied, applies a this filter bank to create a mel spectrogram.
        mel_floor (`float`, *optional*, defaults to 1e-10):
            Minimum value of mel frequency banks.
        log_mel (`str`, *optional*):
            How to convert the spectrogram to log scale. Possible options are: `None` (don't convert), `"log"` (take
            the natural logarithm) `"log10"` (take the base-10 logarithm), `"dB"` (convert to decibels). Can only be
            used when `power` is not `None`.
        reference (`float`, *optional*, defaults to 1.0):
            Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
            the loudest part to 0 dB. Must be greater than zero.
        min_value (`float`, *optional*, defaults to `1e-10`):
            The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
            `log(0)`. For a power spectrogram, the default of `1e-10` corresponds to a minimum of -100 dB. For an
            amplitude spectrogram, the value `1e-5` corresponds to -100 dB. Must be greater than zero.
        db_range (`float`, *optional*):
            Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
            peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
        remove_dc_offset (`bool`, *optional*):
            Subtract mean from waveform on each frame, applied before pre-emphasis. This should be set to `true` in
            order to get the same results as `torchaudio.compliance.kaldi.fbank` when computing mel filters.
        dtype (`np.dtype`, *optional*, defaults to `np.float32`):
            Data type of the spectrogram tensor. If `power` is None, this argument is ignored and the dtype will be
            `np.complex64`.

    Returns:
        `nd.array` containing a spectrogram of shape `(num_frequency_bins, length)` for a regular spectrogram or shape
        `(num_mel_filters, length)` for a mel spectrogram.
    """
    window_length = len(window)

    if fft_length is None:
        fft_length = frame_length

    if frame_length > fft_length:
        raise ValueError(f"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})")

    if window_length != frame_length:
        raise ValueError(f"Length of the window ({window_length}) must equal frame_length ({frame_length})")

    if hop_length <= 0:
        raise ValueError("hop_length must be greater than zero")

    if waveform.ndim != 1:
        raise ValueError(f"Input waveform must have only one dimension, shape is {waveform.shape}")

    if np.iscomplexobj(waveform):
        raise ValueError("Complex-valued input waveforms are not currently supported")

    if power is None and mel_filters is not None:
        raise ValueError(
            "You have provided `mel_filters` but `power` is `None`. Mel spectrogram computation is not yet supported for complex-valued spectrogram."
            "Specify `power` to fix this issue."
        )

    # center pad the waveform
    if center:
        padding = [(int(frame_length // 2), int(frame_length // 2))]
        waveform = np.pad(waveform, padding, mode=pad_mode)

    # promote to float64, since np.fft uses float64 internally
    waveform = waveform.astype(np.float64)
    window = window.astype(np.float64)

    # split waveform into frames of frame_length size
    num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))

    num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length
    spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)

    # rfft is faster than fft
    fft_func = np.fft.rfft if onesided else np.fft.fft
    # the buffer is fft_length long; samples beyond frame_length stay zero, so
    # each frame is implicitly zero-padded up to the FFT size
    buffer = np.zeros(fft_length)

    timestep = 0
    for frame_idx in range(num_frames):
        buffer[:frame_length] = waveform[timestep : timestep + frame_length]

        # per-frame DC removal happens before pre-emphasis (Kaldi convention)
        if remove_dc_offset:
            buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()

        if preemphasis is not None:
            # y[t] = x[t] - preemphasis * x[t-1]; first sample scaled instead
            buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]
            buffer[0] *= 1 - preemphasis

        buffer[:frame_length] *= window

        spectrogram[frame_idx] = fft_func(buffer)
        timestep += hop_length

    # note: ** is much faster than np.power
    if power is not None:
        spectrogram = np.abs(spectrogram, dtype=np.float64) ** power

    # transpose to (num_frequency_bins, num_frames)
    spectrogram = spectrogram.T

    if mel_filters is not None:
        # project onto the mel filter bank, then floor to avoid log(0) later
        spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))

    if power is not None and log_mel is not None:
        if log_mel == "log":
            spectrogram = np.log(spectrogram)
        elif log_mel == "log10":
            spectrogram = np.log10(spectrogram)
        elif log_mel == "dB":
            # amplitude_to_db / power_to_db are defined elsewhere in this module
            if power == 1.0:
                spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)
            elif power == 2.0:
                spectrogram = power_to_db(spectrogram, reference, min_value, db_range)
            else:
                raise ValueError(f"Cannot use log_mel option '{log_mel}' with power {power}")
        else:
            raise ValueError(f"Unknown log_mel option: {log_mel}")

    spectrogram = np.asarray(spectrogram, dtype)

    return spectrogram
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
def spectrogram_batch(
    waveform_list: List[np.ndarray],
    window: np.ndarray,
    frame_length: int,
    hop_length: int,
    fft_length: Optional[int] = None,
    power: Optional[float] = 1.0,
    center: bool = True,
    pad_mode: str = "reflect",
    onesided: bool = True,
    preemphasis: Optional[float] = None,
    mel_filters: Optional[np.ndarray] = None,
    mel_floor: float = 1e-10,
    log_mel: Optional[str] = None,
    reference: float = 1.0,
    min_value: float = 1e-10,
    db_range: Optional[float] = None,
    remove_dc_offset: Optional[bool] = None,
    dtype: np.dtype = np.float32,
) -> List[np.ndarray]:
    """
    Calculates spectrograms for a list of waveforms using the Short-Time Fourier Transform, optimized for batch
    processing. This function extends the capabilities of the `spectrogram` function to handle multiple waveforms
    efficiently by leveraging broadcasting.

    It supports generating various types of spectrograms:

        - amplitude spectrogram (`power = 1.0`)
        - power spectrogram (`power = 2.0`)
        - complex-valued spectrogram (`power = None`)
        - log spectrogram (use `log_mel` argument)
        - mel spectrogram (provide `mel_filters`)
        - log-mel spectrogram (provide `mel_filters` and `log_mel`)

    How this works:

        1. The input waveform is split into frames of size `frame_length` that are partially overlapping by
           `frame_length - hop_length` samples.
        2. Each frame is multiplied by the window and placed into a buffer of size `fft_length`.
        3. The DFT is taken of each windowed frame.
        4. The results are stacked into a spectrogram.

    We make a distinction between the following "blocks" of sample data, each of which may have different lengths:

        - The analysis frame. This is the size of the time slices that the input waveform is split into.
        - The window. Each analysis frame is multiplied by the window to avoid spectral leakage.
        - The FFT input buffer. The length of this determines how many frequency bins are in the spectrogram.

    In this implementation, the window is assumed to be zero-padded to have the same size as the analysis frame. A
    padded window can be obtained from `window_function()`. The FFT input buffer may be larger than the analysis
    frame, typically the next power of two.

    Note: This function is designed for efficient batch processing of multiple waveforms but retains compatibility
    with individual waveform processing methods like `librosa.stft`.

    Args:
        waveform_list (`List[np.ndarray]` with arrays of shape `(length,)`):
            The list of input waveforms, each a single-channel (mono) signal.
        window (`np.ndarray` of shape `(frame_length,)`):
            The windowing function to apply, including zero-padding if necessary.
        frame_length (`int`):
            The length of each frame for analysis.
        hop_length (`int`):
            The step size between successive frames.
        fft_length (`int`, *optional*):
            The size of the FFT buffer, defining frequency bin resolution.
        power (`float`, *optional*, defaults to 1.0):
            Determines the type of spectrogram: 1.0 for amplitude, 2.0 for power, None for complex.
        center (`bool`, *optional*, defaults to `True`):
            Whether to center-pad the waveform frames.
        pad_mode (`str`, *optional*, defaults to `"reflect"`):
            The padding strategy when `center` is `True`.
        onesided (`bool`, *optional*, defaults to `True`):
            If True, returns a one-sided spectrogram for real input signals.
        preemphasis (`float`, *optional*):
            Applies a pre-emphasis filter to each frame.
        mel_filters (`np.ndarray`, *optional*):
            Mel filter bank for converting to mel spectrogram.
        mel_floor (`float`, *optional*, defaults to 1e-10):
            Floor value for mel spectrogram to avoid log(0).
        log_mel (`str`, *optional*):
            Specifies log scaling strategy; options are None, "log", "log10", "dB".
        reference (`float`, *optional*, defaults to 1.0):
            Reference value for dB conversion in log_mel.
        min_value (`float`, *optional*, defaults to 1e-10):
            Minimum floor value for log scale conversions.
        db_range (`float`, *optional*):
            Dynamic range for dB scale spectrograms.
        remove_dc_offset (`bool`, *optional*):
            Whether to remove the DC offset from each frame.
        dtype (`np.dtype`, *optional*, defaults to `np.float32`):
            Data type of the output spectrogram.

    Returns:
        List[`np.ndarray`]: A list of spectrogram arrays, one for each input waveform.
    """
    window_length = len(window)

    if fft_length is None:
        fft_length = frame_length

    if frame_length > fft_length:
        raise ValueError(f"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})")

    if window_length != frame_length:
        raise ValueError(f"Length of the window ({window_length}) must equal frame_length ({frame_length})")

    if hop_length <= 0:
        raise ValueError("hop_length must be greater than zero")

    # Check the dimensions of the waveform, and reject complex input (only real signals are supported).
    for waveform in waveform_list:
        if waveform.ndim != 1:
            raise ValueError(f"Input waveform must have only one dimension, shape is {waveform.shape}")
        if np.iscomplexobj(waveform):
            raise ValueError("Complex-valued input waveforms are not currently supported")
    # Center pad each waveform with frame_length // 2 samples on both sides, matching librosa's `center=True`.
    if center:
        padding = [(int(frame_length // 2), int(frame_length // 2))]
        waveform_list = [
            np.pad(
                waveform,
                padding,
                mode=pad_mode,
            )
            for waveform in waveform_list
        ]
    original_waveform_lengths = [
        len(waveform) for waveform in waveform_list
    ]  # these lengths will be used to remove padding later

    # Batch pad the waveforms with zeros on the right so they share a common length and can be stacked.
    max_length = max(original_waveform_lengths)
    padded_waveform_batch = np.array(
        [
            np.pad(waveform, (0, max_length - len(waveform)), mode="constant", constant_values=0)
            for waveform in waveform_list
        ],
        dtype=dtype,
    )

    # Promote to float64, since np.fft uses float64 internally
    padded_waveform_batch = padded_waveform_batch.astype(np.float64)
    window = window.astype(np.float64)

    # Split waveform into frames of frame_length size
    num_frames = int(1 + np.floor((padded_waveform_batch.shape[1] - frame_length) / hop_length))
    # Per-waveform frame counts (before batch zero-padding); used below to trim trailing padding frames.
    true_num_frames = [int(1 + np.floor((length - frame_length) / hop_length)) for length in original_waveform_lengths]
    num_batches = padded_waveform_batch.shape[0]

    num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length
    spectrogram = np.empty((num_batches, num_frames, num_frequency_bins), dtype=np.complex64)

    # rfft is faster than fft
    fft_func = np.fft.rfft if onesided else np.fft.fft
    # Reusable FFT input buffer; columns beyond frame_length stay zero (zero-padding up to fft_length).
    buffer = np.zeros((num_batches, fft_length))

    for frame_idx in range(num_frames):
        timestep = frame_idx * hop_length
        buffer[:, :frame_length] = padded_waveform_batch[:, timestep : timestep + frame_length]

        if remove_dc_offset:
            # Subtract each frame's mean so the 0 Hz (DC) bin is zeroed.
            buffer[:, :frame_length] -= buffer[:, :frame_length].mean(axis=1, keepdims=True)

        if preemphasis is not None:
            # First-order high-pass: y[t] = x[t] - preemphasis * x[t-1]; done in reverse order so the
            # original x[t-1] values are still available when computing y[t].
            buffer[:, 1:frame_length] -= preemphasis * buffer[:, : frame_length - 1]
            buffer[:, 0] *= 1 - preemphasis

        buffer[:, :frame_length] *= window

        spectrogram[:, frame_idx] = fft_func(buffer)

    # Note: ** is much faster than np.power
    if power is not None:
        spectrogram = np.abs(spectrogram, dtype=np.float64) ** power

    # Apply mel filters if provided
    if mel_filters is not None:
        result = np.tensordot(spectrogram, mel_filters.T, axes=([2], [1]))
        spectrogram = np.maximum(mel_floor, result)

    # Convert to log scale if specified
    if power is not None and log_mel is not None:
        if log_mel == "log":
            spectrogram = np.log(spectrogram)
        elif log_mel == "log10":
            spectrogram = np.log10(spectrogram)
        elif log_mel == "dB":
            if power == 1.0:
                spectrogram = amplitude_to_db_batch(spectrogram, reference, min_value, db_range)
            elif power == 2.0:
                spectrogram = power_to_db_batch(spectrogram, reference, min_value, db_range)
            else:
                raise ValueError(f"Cannot use log_mel option '{log_mel}' with power {power}")
        else:
            raise ValueError(f"Unknown log_mel option: {log_mel}")

    spectrogram = np.asarray(spectrogram, dtype)

    # Trim the frames introduced by batch zero-padding and transpose to (frequency, time) per waveform.
    spectrogram_list = [spectrogram[i, : true_num_frames[i], :].T for i in range(len(true_num_frames))]

    return spectrogram_list
|
| 785 |
+
|
| 786 |
+
|
| 787 |
+
def power_to_db(
|
| 788 |
+
spectrogram: np.ndarray,
|
| 789 |
+
reference: float = 1.0,
|
| 790 |
+
min_value: float = 1e-10,
|
| 791 |
+
db_range: Optional[float] = None,
|
| 792 |
+
) -> np.ndarray:
|
| 793 |
+
"""
|
| 794 |
+
Converts a power spectrogram to the decibel scale. This computes `10 * log10(spectrogram / reference)`, using basic
|
| 795 |
+
logarithm properties for numerical stability.
|
| 796 |
+
|
| 797 |
+
The motivation behind applying the log function on the (mel) spectrogram is that humans do not hear loudness on a
|
| 798 |
+
linear scale. Generally to double the perceived volume of a sound we need to put 8 times as much energy into it.
|
| 799 |
+
This means that large variations in energy may not sound all that different if the sound is loud to begin with.
|
| 800 |
+
This compression operation makes the (mel) spectrogram features match more closely what humans actually hear.
|
| 801 |
+
|
| 802 |
+
Based on the implementation of `librosa.power_to_db`.
|
| 803 |
+
|
| 804 |
+
Args:
|
| 805 |
+
spectrogram (`np.ndarray`):
|
| 806 |
+
The input power (mel) spectrogram. Note that a power spectrogram has the amplitudes squared!
|
| 807 |
+
reference (`float`, *optional*, defaults to 1.0):
|
| 808 |
+
Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
|
| 809 |
+
the loudest part to 0 dB. Must be greater than zero.
|
| 810 |
+
min_value (`float`, *optional*, defaults to `1e-10`):
|
| 811 |
+
The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
|
| 812 |
+
`log(0)`. The default of `1e-10` corresponds to a minimum of -100 dB. Must be greater than zero.
|
| 813 |
+
db_range (`float`, *optional*):
|
| 814 |
+
Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
|
| 815 |
+
peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
|
| 816 |
+
|
| 817 |
+
Returns:
|
| 818 |
+
`np.ndarray`: the spectrogram in decibels
|
| 819 |
+
"""
|
| 820 |
+
if reference <= 0.0:
|
| 821 |
+
raise ValueError("reference must be greater than zero")
|
| 822 |
+
if min_value <= 0.0:
|
| 823 |
+
raise ValueError("min_value must be greater than zero")
|
| 824 |
+
|
| 825 |
+
reference = max(min_value, reference)
|
| 826 |
+
|
| 827 |
+
spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None)
|
| 828 |
+
spectrogram = 10.0 * (np.log10(spectrogram) - np.log10(reference))
|
| 829 |
+
|
| 830 |
+
if db_range is not None:
|
| 831 |
+
if db_range <= 0.0:
|
| 832 |
+
raise ValueError("db_range must be greater than zero")
|
| 833 |
+
spectrogram = np.clip(spectrogram, a_min=spectrogram.max() - db_range, a_max=None)
|
| 834 |
+
|
| 835 |
+
return spectrogram
|
| 836 |
+
|
| 837 |
+
|
| 838 |
+
def power_to_db_batch(
    spectrogram: np.ndarray,
    reference: float = 1.0,
    min_value: float = 1e-10,
    db_range: Optional[float] = None,
) -> np.ndarray:
    """
    Convert a batch of power (mel) spectrograms to decibels, i.e. `10 * log10(spectrogram / reference)`, computed
    as a difference of logarithms for numerical stability.

    Batch counterpart of `power_to_db`: each item along the first axis is an individual power (mel) spectrogram,
    and the `db_range` clipping is applied per item.

    Args:
        spectrogram (`np.ndarray`):
            The input batch of power (mel) spectrograms, of shape `(batch_size, *spectrogram_shape)`. Note that a
            power spectrogram has the amplitudes squared!
        reference (`float`, *optional*, defaults to 1.0):
            The spectrogram value mapped to 0 dB. For example, use `np.max(spectrogram)` to put the loudest part
            at 0 dB. Must be greater than zero.
        min_value (`float`, *optional*, defaults to `1e-10`):
            Values below this are clipped before taking the logarithm, to avoid `log(0)`. The default `1e-10`
            corresponds to a minimum of -100 dB. Must be greater than zero.
        db_range (`float`, *optional*):
            Maximum dynamic range in decibels. For example, with `db_range = 80` no value ends up more than 80 dB
            below its batch item's peak. Must be greater than zero.

    Returns:
        `np.ndarray`: the batch of spectrograms in decibels
    """
    if reference <= 0.0:
        raise ValueError("reference must be greater than zero")
    if min_value <= 0.0:
        raise ValueError("min_value must be greater than zero")

    # The reference itself must not fall below the clipping floor.
    reference = max(min_value, reference)

    clipped = np.clip(spectrogram, a_min=min_value, a_max=None)
    db_spectrogram = 10.0 * (np.log10(clipped) - np.log10(reference))

    if db_range is not None:
        if db_range <= 0.0:
            raise ValueError("db_range must be greater than zero")
        # Per-item peaks, broadcast back over each spectrogram's two trailing axes.
        per_item_max = db_spectrogram.max(axis=(1, 2), keepdims=True)
        db_spectrogram = np.clip(db_spectrogram, a_min=per_item_max - db_range, a_max=None)

    return db_spectrogram
|
| 885 |
+
|
| 886 |
+
|
| 887 |
+
def amplitude_to_db(
    spectrogram: np.ndarray,
    reference: float = 1.0,
    min_value: float = 1e-5,
    db_range: Optional[float] = None,
) -> np.ndarray:
    """
    Convert an amplitude (mel) spectrogram to decibels, i.e. `20 * log10(spectrogram / reference)`, computed as a
    difference of logarithms for numerical stability.

    Humans perceive loudness on a logarithmic rather than linear scale — roughly 8x the energy is needed to double
    the perceived volume — so compressing the (mel) spectrogram this way yields features that track human hearing
    more closely.

    Args:
        spectrogram (`np.ndarray`):
            The input amplitude (mel) spectrogram.
        reference (`float`, *optional*, defaults to 1.0):
            The spectrogram value mapped to 0 dB. For example, use `np.max(spectrogram)` to put the loudest part
            at 0 dB. Must be greater than zero.
        min_value (`float`, *optional*, defaults to `1e-5`):
            Values below this are clipped before taking the logarithm, to avoid `log(0)`. The default `1e-5`
            corresponds to a minimum of -100 dB. Must be greater than zero.
        db_range (`float`, *optional*):
            Maximum dynamic range in decibels. For example, with `db_range = 80` no value ends up more than 80 dB
            below the peak. Must be greater than zero.

    Returns:
        `np.ndarray`: the spectrogram in decibels
    """
    if reference <= 0.0:
        raise ValueError("reference must be greater than zero")
    if min_value <= 0.0:
        raise ValueError("min_value must be greater than zero")

    # The reference itself must not fall below the clipping floor.
    reference = max(min_value, reference)

    clipped = np.clip(spectrogram, a_min=min_value, a_max=None)
    db_spectrogram = 20.0 * (np.log10(clipped) - np.log10(reference))

    if db_range is not None:
        if db_range <= 0.0:
            raise ValueError("db_range must be greater than zero")
        lower_bound = db_spectrogram.max() - db_range
        db_spectrogram = np.clip(db_spectrogram, a_min=lower_bound, a_max=None)

    return db_spectrogram
|
| 934 |
+
|
| 935 |
+
|
| 936 |
+
def amplitude_to_db_batch(
    spectrogram: np.ndarray, reference: float = 1.0, min_value: float = 1e-5, db_range: Optional[float] = None
) -> np.ndarray:
    """
    Convert a batch of amplitude (mel) spectrograms to decibels, i.e. `20 * log10(spectrogram / reference)`,
    computed as a difference of logarithms for numerical stability.

    Batch counterpart of `amplitude_to_db`: each item along the first axis is an individual amplitude (mel)
    spectrogram, and the `db_range` clipping is applied per item.

    Args:
        spectrogram (`np.ndarray`):
            The input batch of amplitude (mel) spectrograms, of shape `(batch_size, *spectrogram_shape)`.
        reference (`float`, *optional*, defaults to 1.0):
            The spectrogram value mapped to 0 dB. For example, use `np.max(spectrogram)` to put the loudest part
            at 0 dB. Must be greater than zero.
        min_value (`float`, *optional*, defaults to `1e-5`):
            Values below this are clipped before taking the logarithm, to avoid `log(0)`. The default `1e-5`
            corresponds to a minimum of -100 dB. Must be greater than zero.
        db_range (`float`, *optional*):
            Maximum dynamic range in decibels. For example, with `db_range = 80` no value ends up more than 80 dB
            below its batch item's peak. Must be greater than zero.

    Returns:
        `np.ndarray`: the batch of spectrograms in decibels
    """
    if reference <= 0.0:
        raise ValueError("reference must be greater than zero")
    if min_value <= 0.0:
        raise ValueError("min_value must be greater than zero")

    # The reference itself must not fall below the clipping floor.
    reference = max(min_value, reference)

    clipped = np.clip(spectrogram, a_min=min_value, a_max=None)
    db_spectrogram = 20.0 * (np.log10(clipped) - np.log10(reference))

    if db_range is not None:
        if db_range <= 0.0:
            raise ValueError("db_range must be greater than zero")
        # Per-item peaks, broadcast back over each spectrogram's two trailing axes.
        per_item_max = db_spectrogram.max(axis=(1, 2), keepdims=True)
        db_spectrogram = np.clip(db_spectrogram, a_min=per_item_max - db_range, a_max=None)

    return db_spectrogram
|
| 979 |
+
|
| 980 |
+
|
| 981 |
+
### deprecated functions below this line ###
|
| 982 |
+
|
| 983 |
+
|
| 984 |
+
def get_mel_filter_banks(
    nb_frequency_bins: int,
    nb_mel_filters: int,
    frequency_min: float,
    frequency_max: float,
    sample_rate: int,
    norm: Optional[str] = None,
    mel_scale: str = "htk",
) -> np.array:
    """Deprecated alias for `mel_filter_bank`; forwards every argument under its renamed keyword."""
    warnings.warn(
        "The function `get_mel_filter_banks` is deprecated and will be removed in version 4.31.0 of Transformers",
        FutureWarning,
    )
    # Map the old parameter names onto the new API's keywords.
    renamed_kwargs = {
        "num_frequency_bins": nb_frequency_bins,
        "num_mel_filters": nb_mel_filters,
        "min_frequency": frequency_min,
        "max_frequency": frequency_max,
        "sampling_rate": sample_rate,
        "norm": norm,
        "mel_scale": mel_scale,
    }
    return mel_filter_bank(**renamed_kwargs)
|
| 1006 |
+
|
| 1007 |
+
|
| 1008 |
+
def fram_wave(waveform: np.array, hop_length: int = 160, fft_window_size: int = 400, center: bool = True):
    """
    In order to compute the short time fourier transform, the waveform needs to be split in overlapping windowed
    segments called `frames`.

    The window length (window_length) defines how much of the signal is contained in each frame, while the hop length
    defines the step between the beginning of each new frame.


    Args:
        waveform (`np.array` of shape `(sample_length,)`):
            The raw waveform which will be split into smaller chunks.
        hop_length (`int`, *optional*, defaults to 160):
            Step between each window of the waveform.
        fft_window_size (`int`, *optional*, defaults to 400):
            Defines the size of the window.
        center (`bool`, defaults to `True`):
            Whether or not to center each frame around the middle of the frame. Centering is done by reflecting the
            waveform on the left and on the right.

    Return:
        framed_waveform (`np.array` of shape `(waveform.shape // hop_length , fft_window_size)`):
            The framed waveforms that can be fed to `np.fft`.
    """
    warnings.warn(
        "The function `fram_wave` is deprecated and will be removed in version 4.31.0 of Transformers",
        FutureWarning,
    )
    frames = []
    # NOTE(review): the `+ 1` upper bound yields one extra frame position past the last full hop — this matches the
    # deprecated behavior and is preserved as-is.
    for i in range(0, waveform.shape[0] + 1, hop_length):
        if center:
            # Each frame is centered on sample `i`: take `half_window` samples on each side, then reflect-pad
            # whichever side ran past the waveform boundary back up to fft_window_size.
            half_window = (fft_window_size - 1) // 2 + 1
            start = i - half_window if i > half_window else 0
            end = i + half_window if i < waveform.shape[0] - half_window else waveform.shape[0]
            frame = waveform[start:end]
            if start == 0:
                # Frame was truncated on the left; reflect-pad to restore the full width.
                padd_width = (-i + half_window, 0)
                frame = np.pad(frame, pad_width=padd_width, mode="reflect")

            elif end == waveform.shape[0]:
                # Frame was truncated on the right; reflect-pad to restore the full width.
                padd_width = (0, (i - waveform.shape[0] + half_window))
                frame = np.pad(frame, pad_width=padd_width, mode="reflect")

        else:
            # No centering: frames start at `i` and the tail frame is zero-padded on the right.
            frame = waveform[i : i + fft_window_size]
            frame_width = frame.shape[0]
            # NOTE(review): this compares against the full waveform length rather than fft_window_size — looks like
            # it was intended to be `frame_width < fft_window_size`; preserved as-is in this deprecated function.
            if frame_width < waveform.shape[0]:
                frame = np.lib.pad(
                    frame, pad_width=(0, fft_window_size - frame_width), mode="constant", constant_values=0
                )
        frames.append(frame)

    frames = np.stack(frames, 0)
    return frames
|
| 1062 |
+
|
| 1063 |
+
|
| 1064 |
+
def stft(frames: np.array, windowing_function: np.array, fft_window_size: int = None):
    """
    Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal. Should give the same
    results as `torch.stft`.

    Args:
        frames (`np.array` of dimension `(num_frames, fft_window_size)`):
            A framed audio signal obtained using `audio_utils.fram_wav`.
        windowing_function (`np.array` of dimension `(nb_frequency_bins, nb_mel_filters)`:
            A array representing the function that will be used to reduces the amplitude of the discontinuities at
            the boundaries of each frame when computing the STFT. Each frame will be multiplied by the
            windowing_function. For more information on the discontinuities, called *Spectral leakage*, refer to
            [this tutorial]https://download.ni.com/evaluation/pxi/Understanding%20FFTs%20and%20Windowing.pdf
        fft_window_size (`int`, *optional*):
            Size of the window on which the Fourier transform is applied, controlling the frequency resolution of
            the spectrogram. The number of frequency bins (`nb_frequency_bins`) is `(1 + fft_window_size) // 2`.
            Larger values slow the computation proportionally. Defaults to the frame size.

    Example:

    ```python
    >>> from transformers.audio_utils import stft, fram_wave
    >>> import numpy as np

    >>> audio = np.random.rand(50)
    >>> fft_window_size = 10
    >>> hop_length = 2
    >>> framed_audio = fram_wave(audio, hop_length, fft_window_size)
    >>> spectrogram = stft(framed_audio, np.hanning(fft_window_size + 1))
    ```

    Returns:
        spectrogram (`np.ndarray`):
            A spectrogram of shape `(num_frames, nb_frequency_bins)` obtained using the STFT algorithm
    """
    warnings.warn(
        "The function `stft` is deprecated and will be removed in version 4.31.0 of Transformers",
        FutureWarning,
    )
    samples_per_frame = frames.shape[1]

    if fft_window_size is None:
        fft_window_size = samples_per_frame

    if fft_window_size < samples_per_frame:
        raise ValueError("FFT size must greater or equal the frame size")
    # Only the non-redundant half of the (real-input) FFT output is kept.
    nb_frequency_bins = (fft_window_size >> 1) + 1

    spectrogram = np.empty((len(frames), nb_frequency_bins), dtype=np.complex64)
    # Reusable FFT input buffer; positions past samples_per_frame stay zero (zero-padding up to fft_window_size).
    fft_signal = np.zeros(fft_window_size)

    for frame_index, frame in enumerate(frames):
        if windowing_function is None:
            fft_signal[:samples_per_frame] = frame
        else:
            fft_signal[:samples_per_frame] = frame * windowing_function
        spectrogram[frame_index] = np.fft.fft(fft_signal, axis=0)[:nb_frequency_bins]
    return spectrogram.T
|
.venv/lib/python3.11/site-packages/transformers/cache_utils.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
.venv/lib/python3.11/site-packages/transformers/configuration_utils.py
ADDED
|
@@ -0,0 +1,1187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
| 3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""Configuration base class and utilities."""
|
| 17 |
+
|
| 18 |
+
import copy
|
| 19 |
+
import json
|
| 20 |
+
import os
|
| 21 |
+
import re
|
| 22 |
+
import warnings
|
| 23 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 24 |
+
|
| 25 |
+
from packaging import version
|
| 26 |
+
|
| 27 |
+
from . import __version__
|
| 28 |
+
from .dynamic_module_utils import custom_object_save
|
| 29 |
+
from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
|
| 30 |
+
from .utils import (
|
| 31 |
+
CONFIG_NAME,
|
| 32 |
+
PushToHubMixin,
|
| 33 |
+
add_model_info_to_auto_map,
|
| 34 |
+
add_model_info_to_custom_pipelines,
|
| 35 |
+
cached_file,
|
| 36 |
+
copy_func,
|
| 37 |
+
download_url,
|
| 38 |
+
extract_commit_hash,
|
| 39 |
+
is_remote_url,
|
| 40 |
+
is_torch_available,
|
| 41 |
+
logging,
|
| 42 |
+
)
|
| 43 |
+
from .utils.generic import is_timm_config_dict
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# Module-level logger (transformers' logging wrapper around stdlib logging).
logger = logging.get_logger(__name__)

# Matches configuration file names of the form `config.<something>.json`;
# the capture group extracts the middle segment.
_re_configuration_file = re.compile(r"config\.(.*)\.json")
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class PretrainedConfig(PushToHubMixin):
|
| 52 |
+
# no-format
|
| 53 |
+
r"""
|
| 54 |
+
Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
|
| 55 |
+
methods for loading/downloading/saving configurations.
|
| 56 |
+
|
| 57 |
+
<Tip>
|
| 58 |
+
|
| 59 |
+
A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to
|
| 60 |
+
initialize a model does **not** load the model weights. It only affects the model's configuration.
|
| 61 |
+
|
| 62 |
+
</Tip>
|
| 63 |
+
|
| 64 |
+
Class attributes (overridden by derived classes):
|
| 65 |
+
|
| 66 |
+
- **model_type** (`str`) -- An identifier for the model type, serialized into the JSON file, and used to recreate
|
| 67 |
+
the correct object in [`~transformers.AutoConfig`].
|
| 68 |
+
- **is_composition** (`bool`) -- Whether the config class is composed of multiple sub-configs. In this case the
|
| 69 |
+
config has to be initialized from two or more configs of type [`~transformers.PretrainedConfig`] like:
|
| 70 |
+
[`~transformers.EncoderDecoderConfig`] or [`~RagConfig`].
|
| 71 |
+
- **keys_to_ignore_at_inference** (`List[str]`) -- A list of keys to ignore by default when looking at dictionary
|
| 72 |
+
outputs of the model during inference.
|
| 73 |
+
- **attribute_map** (`Dict[str, str]`) -- A dict that maps model specific attribute names to the standardized
|
| 74 |
+
naming of attributes.
|
| 75 |
+
- **base_model_tp_plan** (`Dict[str, Any]`) -- A dict that maps sub-modules FQNs of a base model to a tensor
|
| 76 |
+
parallel plan applied to the sub-module when `model.tensor_parallel` is called.
|
| 77 |
+
|
| 78 |
+
Common attributes (present in all subclasses):
|
| 79 |
+
|
| 80 |
+
- **vocab_size** (`int`) -- The number of tokens in the vocabulary, which is also the first dimension of the
|
| 81 |
+
embeddings matrix (this attribute may be missing for models that don't have a text modality like ViT).
|
| 82 |
+
- **hidden_size** (`int`) -- The hidden size of the model.
|
| 83 |
+
- **num_attention_heads** (`int`) -- The number of attention heads used in the multi-head attention layers of the
|
| 84 |
+
model.
|
| 85 |
+
- **num_hidden_layers** (`int`) -- The number of blocks in the model.
|
| 86 |
+
|
| 87 |
+
<Tip warning={true}>
|
| 88 |
+
|
| 89 |
+
Setting parameters for sequence generation in the model config is deprecated. For backward compatibility, loading
|
| 90 |
+
some of them will still be possible, but attempting to overwrite them will throw an exception -- you should set
|
| 91 |
+
them in a [`~transformers.GenerationConfig`]. Check the documentation of [`~transformers.GenerationConfig`] for more
|
| 92 |
+
information about the individual parameters.
|
| 93 |
+
|
| 94 |
+
</Tip>
|
| 95 |
+
|
| 96 |
+
Arg:
|
| 97 |
+
name_or_path (`str`, *optional*, defaults to `""`):
|
| 98 |
+
Store the string that was passed to [`PreTrainedModel.from_pretrained`] or
|
| 99 |
+
[`TFPreTrainedModel.from_pretrained`] as `pretrained_model_name_or_path` if the configuration was created
|
| 100 |
+
with such a method.
|
| 101 |
+
output_hidden_states (`bool`, *optional*, defaults to `False`):
|
| 102 |
+
Whether or not the model should return all hidden-states.
|
| 103 |
+
output_attentions (`bool`, *optional*, defaults to `False`):
|
| 104 |
+
Whether or not the model should return all attentions.
|
| 105 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 106 |
+
Whether or not the model should return a [`~transformers.utils.ModelOutput`] instead of a plain tuple.
|
| 107 |
+
is_encoder_decoder (`bool`, *optional*, defaults to `False`):
|
| 108 |
+
Whether the model is used as an encoder/decoder or not.
|
| 109 |
+
is_decoder (`bool`, *optional*, defaults to `False`):
|
| 110 |
+
Whether the model is used as decoder or not (in which case it's used as an encoder).
|
| 111 |
+
cross_attention_hidden_size (`bool`, *optional*):
|
| 112 |
+
The hidden size of the cross-attention layer in case the model is used as a decoder in an encoder-decoder
|
| 113 |
+
setting and the cross-attention hidden dimension differs from `self.config.hidden_size`.
|
| 114 |
+
add_cross_attention (`bool`, *optional*, defaults to `False`):
|
| 115 |
+
Whether cross-attention layers should be added to the model. Note, this option is only relevant for models
|
| 116 |
+
that can be used as decoder models within the [`EncoderDecoderModel`] class, which consists of all models
|
| 117 |
+
in `AUTO_MODELS_FOR_CAUSAL_LM`.
|
| 118 |
+
tie_encoder_decoder (`bool`, *optional*, defaults to `False`):
|
| 119 |
+
Whether all encoder weights should be tied to their equivalent decoder weights. This requires the encoder
|
| 120 |
+
and decoder model to have the exact same parameter names.
|
| 121 |
+
prune_heads (`Dict[int, List[int]]`, *optional*, defaults to `{}`):
|
| 122 |
+
Pruned heads of the model. The keys are the selected layer indices and the associated values, the list of
|
| 123 |
+
heads to prune in said layer.
|
| 124 |
+
|
| 125 |
+
For instance `{1: [0, 2], 2: [2, 3]}` will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
|
| 126 |
+
chunk_size_feed_forward (`int`, *optional*, defaults to `0`):
|
| 127 |
+
The chunk size of all feed forward layers in the residual attention blocks. A chunk size of `0` means that
|
| 128 |
+
the feed forward layer is not chunked. A chunk size of n means that the feed forward layer processes `n` <
|
| 129 |
+
sequence_length embeddings at a time. For more information on feed forward chunking, see [How does Feed
|
| 130 |
+
Forward Chunking work?](../glossary.html#feed-forward-chunking).
|
| 131 |
+
|
| 132 |
+
> Parameters for fine-tuning tasks
|
| 133 |
+
|
| 134 |
+
architectures (`List[str]`, *optional*):
|
| 135 |
+
Model architectures that can be used with the model pretrained weights.
|
| 136 |
+
finetuning_task (`str`, *optional*):
|
| 137 |
+
Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow
|
| 138 |
+
or PyTorch) checkpoint.
|
| 139 |
+
id2label (`Dict[int, str]`, *optional*):
|
| 140 |
+
A map from index (for instance prediction index, or target index) to label.
|
| 141 |
+
label2id (`Dict[str, int]`, *optional*): A map from label to index for the model.
|
| 142 |
+
num_labels (`int`, *optional*):
|
| 143 |
+
Number of labels to use in the last layer added to the model, typically for a classification task.
|
| 144 |
+
task_specific_params (`Dict[str, Any]`, *optional*):
|
| 145 |
+
Additional keyword arguments to store for the current task.
|
| 146 |
+
problem_type (`str`, *optional*):
|
| 147 |
+
Problem type for `XxxForSequenceClassification` models. Can be one of `"regression"`,
|
| 148 |
+
`"single_label_classification"` or `"multi_label_classification"`.
|
| 149 |
+
|
| 150 |
+
> Parameters linked to the tokenizer
|
| 151 |
+
|
| 152 |
+
tokenizer_class (`str`, *optional*):
|
| 153 |
+
The name of the associated tokenizer class to use (if none is set, will use the tokenizer associated to the
|
| 154 |
+
model by default).
|
| 155 |
+
prefix (`str`, *optional*):
|
| 156 |
+
A specific prompt that should be added at the beginning of each text before calling the model.
|
| 157 |
+
bos_token_id (`int`, *optional*): The id of the _beginning-of-stream_ token.
|
| 158 |
+
pad_token_id (`int`, *optional*): The id of the _padding_ token.
|
| 159 |
+
eos_token_id (`int`, *optional*): The id of the _end-of-stream_ token.
|
| 160 |
+
decoder_start_token_id (`int`, *optional*):
|
| 161 |
+
If an encoder-decoder model starts decoding with a different token than _bos_, the id of that token.
|
| 162 |
+
sep_token_id (`int`, *optional*): The id of the _separation_ token.
|
| 163 |
+
|
| 164 |
+
> PyTorch specific parameters
|
| 165 |
+
|
| 166 |
+
torchscript (`bool`, *optional*, defaults to `False`):
|
| 167 |
+
Whether or not the model should be used with Torchscript.
|
| 168 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
|
| 169 |
+
Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
|
| 170 |
+
model has a output word embedding layer.
|
| 171 |
+
torch_dtype (`str`, *optional*):
|
| 172 |
+
The `dtype` of the weights. This attribute can be used to initialize the model to a non-default `dtype`
|
| 173 |
+
(which is normally `float32`) and thus allow for optimal storage allocation. For example, if the saved
|
| 174 |
+
model is `float16`, ideally we want to load it back using the minimal amount of memory needed to load
|
| 175 |
+
`float16` weights. Since the config object is stored in plain text, this attribute contains just the
|
| 176 |
+
floating type string without the `torch.` prefix. For example, for `torch.float16` `torch_dtype` is the
|
| 177 |
+
`"float16"` string.
|
| 178 |
+
|
| 179 |
+
This attribute is currently not being used during model loading time, but this may change in the future
|
| 180 |
+
versions. But we can already start preparing for the future by saving the dtype with save_pretrained.
|
| 181 |
+
|
| 182 |
+
> TensorFlow specific parameters
|
| 183 |
+
|
| 184 |
+
use_bfloat16 (`bool`, *optional*, defaults to `False`):
|
| 185 |
+
Whether or not the model should use BFloat16 scalars (only used by some TensorFlow models).
|
| 186 |
+
tf_legacy_loss (`bool`, *optional*, defaults to `False`):
|
| 187 |
+
Whether the model should use legacy TensorFlow losses. Legacy losses have variable output shapes and may
|
| 188 |
+
not be XLA-compatible. This option is here for backward compatibility and will be removed in Transformers
|
| 189 |
+
v5.
|
| 190 |
+
loss_type (`str`, *optional*):
|
| 191 |
+
The type of loss that the model should use. It should be in `LOSS_MAPPING`'s keys, otherwise the loss will
|
| 192 |
+
be automatically inferred from the model architecture.
|
| 193 |
+
"""
|
| 194 |
+
|
| 195 |
+
# Identifier for the model type, serialized into the JSON file and used to
# recreate the correct object in `AutoConfig` (see class docstring).
model_type: str = ""
# Key under which a base sub-config lives inside a composite config dict;
# empty for plain configs. NOTE(review): exact lookup semantics are not
# visible in this chunk — confirm against `from_pretrained`/`get_config_dict`.
base_config_key: str = ""
# Names of nested sub-configs mapped to their config classes, populated by
# composite-model subclasses (empty here by default).
sub_configs: Dict[str, "PretrainedConfig"] = {}
# Whether the config class is composed of multiple sub-configs (see class docstring).
is_composition: bool = False
# Attribute-name alias map consulted by `__setattr__`/`__getattribute__`:
# an incoming attribute name that appears as a key is translated to the
# mapped name before the actual get/set.
attribute_map: Dict[str, str] = {}
# Maps sub-module FQNs of the base model to a tensor-parallel plan applied
# when `model.tensor_parallel` is called (see class docstring).
base_model_tp_plan: Optional[Dict[str, Any]] = None
# When not None, the config is treated as a custom (auto-registered) class and
# `save_pretrained` copies its defining module via `custom_object_save`.
_auto_class: Optional[str] = None
|
| 202 |
+
|
| 203 |
+
def __setattr__(self, key, value):
    """Set an attribute, translating `key` through the class-level `attribute_map` first.

    If `key` is an alias listed in `attribute_map`, the value is stored under the
    mapped name instead. `super().__getattribute__` is used to read the map so the
    lookup does not go through this class's own `__getattribute__` override.
    """
    if key in super().__getattribute__("attribute_map"):
        key = super().__getattribute__("attribute_map")[key]
    super().__setattr__(key, value)
|
| 207 |
+
|
| 208 |
+
def __getattribute__(self, key):
    """Read an attribute, translating `key` through the class-level `attribute_map` first.

    The `key != "attribute_map"` guard prevents infinite recursion when the map
    itself is being looked up.
    """
    if key != "attribute_map" and key in super().__getattribute__("attribute_map"):
        key = super().__getattribute__("attribute_map")[key]
    return super().__getattribute__(key)
|
| 212 |
+
|
| 213 |
+
def __init__(self, **kwargs):
    """Initialize the common configuration attributes from `kwargs`.

    Every recognized key is popped from `kwargs`; whatever remains at the end is
    set on the instance verbatim (so subclass-specific and user-extra parameters
    round-trip through serialization). Raises `ValueError` for malformed
    `label2id`/`id2label`/`problem_type` values.

    Fix: corrected the typo "wil" -> "will" in the num_labels warning message.
    """
    # Attributes with defaults
    self.return_dict = kwargs.pop("return_dict", True)
    self.output_hidden_states = kwargs.pop("output_hidden_states", False)
    self.output_attentions = kwargs.pop("output_attentions", False)
    self.torchscript = kwargs.pop("torchscript", False)  # Only used by PyTorch models
    self.torch_dtype = kwargs.pop("torch_dtype", None)  # Only used by PyTorch models
    self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
    self.tf_legacy_loss = kwargs.pop("tf_legacy_loss", False)  # Only used by TensorFlow models
    self.pruned_heads = kwargs.pop("pruned_heads", {})
    self.tie_word_embeddings = kwargs.pop(
        "tie_word_embeddings", True
    )  # Whether input and output word embeddings should be tied for all MLM, LM and Seq2Seq models.
    self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)

    # Is decoder is used in encoder-decoder models to differentiate encoder from decoder
    self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
    self.is_decoder = kwargs.pop("is_decoder", False)
    self.cross_attention_hidden_size = kwargs.pop("cross_attention_hidden_size", None)
    self.add_cross_attention = kwargs.pop("add_cross_attention", False)
    self.tie_encoder_decoder = kwargs.pop("tie_encoder_decoder", False)

    # Retrocompatibility: Parameters for sequence generation. While we will keep the ability to load these
    # parameters, saving them will be deprecated. In a distant future, we won't need to load them.
    for parameter_name, default_value in self._get_global_generation_defaults().items():
        setattr(self, parameter_name, kwargs.pop(parameter_name, default_value))

    # Fine-tuning task arguments
    self.architectures = kwargs.pop("architectures", None)
    self.finetuning_task = kwargs.pop("finetuning_task", None)
    self.id2label = kwargs.pop("id2label", None)
    self.label2id = kwargs.pop("label2id", None)
    if self.label2id is not None and not isinstance(self.label2id, dict):
        raise ValueError("Argument label2id should be a dictionary.")
    if self.id2label is not None:
        if not isinstance(self.id2label, dict):
            raise ValueError("Argument id2label should be a dictionary.")
        num_labels = kwargs.pop("num_labels", None)
        if num_labels is not None and len(self.id2label) != num_labels:
            # An explicit `num_labels` that disagrees with `id2label` is ignored;
            # the map wins (see the `num_labels` property).
            logger.warning(
                f"You passed along `num_labels={num_labels}` with an incompatible id to label map: "
                f"{self.id2label}. The number of labels will be overwritten to {self.num_labels}."
            )
        self.id2label = {int(key): value for key, value in self.id2label.items()}
        # Keys are always strings in JSON so convert ids to int here.
    else:
        self.num_labels = kwargs.pop("num_labels", 2)

    if self.torch_dtype is not None and isinstance(self.torch_dtype, str):
        # we will start using self.torch_dtype in v5, but to be consistent with
        # from_pretrained's torch_dtype arg convert it to an actual torch.dtype object
        if is_torch_available():
            import torch

            self.torch_dtype = getattr(torch, self.torch_dtype)

    # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
    self.tokenizer_class = kwargs.pop("tokenizer_class", None)
    self.prefix = kwargs.pop("prefix", None)
    self.bos_token_id = kwargs.pop("bos_token_id", None)
    self.pad_token_id = kwargs.pop("pad_token_id", None)
    self.eos_token_id = kwargs.pop("eos_token_id", None)
    self.sep_token_id = kwargs.pop("sep_token_id", None)

    self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)

    # task specific arguments
    self.task_specific_params = kwargs.pop("task_specific_params", None)

    # regression / multi-label classification
    self.problem_type = kwargs.pop("problem_type", None)
    allowed_problem_types = ("regression", "single_label_classification", "multi_label_classification")
    if self.problem_type is not None and self.problem_type not in allowed_problem_types:
        raise ValueError(
            f"The config parameter `problem_type` was not understood: received {self.problem_type} "
            "but only 'regression', 'single_label_classification' and 'multi_label_classification' are valid."
        )

    # TPU arguments
    if kwargs.pop("xla_device", None) is not None:
        logger.warning(
            "The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can "
            "safely remove it from your `config.json` file."
        )

    # Name or path to the pretrained checkpoint
    self._name_or_path = str(kwargs.pop("name_or_path", ""))
    # Config hash
    self._commit_hash = kwargs.pop("_commit_hash", None)

    # Attention implementation to use, if relevant.
    self._attn_implementation_internal = kwargs.pop("attn_implementation", None)
    self._attn_implementation_autoset = False

    # Drop the transformers version info
    self.transformers_version = kwargs.pop("transformers_version", None)

    # Deal with gradient checkpointing
    if kwargs.get("gradient_checkpointing", False):
        warnings.warn(
            "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 "
            "Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the "
            "`Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`."
        )

    # Additional attributes without default values
    for key, value in kwargs.items():
        try:
            setattr(self, key, value)
        except AttributeError as err:
            logger.error(f"Can't set {key} with value {value} for {self}")
            raise err
|
| 325 |
+
|
| 326 |
+
@property
def name_or_path(self) -> str:
    """`str`: The checkpoint name or path this config was created from, or `None` if never set."""
    return getattr(self, "_name_or_path", None)

@name_or_path.setter
def name_or_path(self, value):
    self._name_or_path = str(value)  # Make sure that name_or_path is a string (for JSON encoding)
|
| 333 |
+
|
| 334 |
+
@property
def use_return_dict(self) -> bool:
    """
    `bool`: Whether or not to return a [`~utils.ModelOutput`] instead of a tuple.
    """
    # If torchscript is set, force `return_dict=False` to avoid jit errors
    return self.return_dict and not self.torchscript
|
| 341 |
+
|
| 342 |
+
@property
def num_labels(self) -> int:
    """`int`: The number of labels for classification models (the size of `id2label`)."""
    return len(self.id2label)

@num_labels.setter
def num_labels(self, num_labels: int):
    """Rebuild generic `id2label`/`label2id` maps unless `id2label` already has `num_labels` entries."""
    current_map = getattr(self, "id2label", None)
    if current_map is None or len(current_map) != num_labels:
        self.id2label = {index: f"LABEL_{index}" for index in range(num_labels)}
        self.label2id = {label: index for index, label in self.id2label.items()}
|
| 354 |
+
|
| 355 |
+
@property
def _attn_implementation(self):
    # Kept private for now: it cannot be changed freely until a
    # PreTrainedModel.use_attn_implementation method is implemented.
    internal = getattr(self, "_attn_implementation_internal", None)
    # For backward compatibility, `config._attn_implementation` must never be None.
    return "eager" if internal is None else internal

@_attn_implementation.setter
def _attn_implementation(self, value):
    self._attn_implementation_internal = value
|
| 370 |
+
|
| 371 |
+
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
    """
    Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
    [`~PretrainedConfig.from_pretrained`] class method.

    Args:
        save_directory (`str` or `os.PathLike`):
            Directory where the configuration JSON file will be saved (will be created if it does not exist).
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
            repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
            namespace).
        kwargs (`Dict[str, Any]`, *optional*):
            Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.

    Raises:
        AssertionError: If `save_directory` points to an existing file rather than a directory.
    """
    # Normalize the deprecated `use_auth_token` kwarg into `token` before any hub call.
    self._set_token_in_kwargs(kwargs)

    if os.path.isfile(save_directory):
        raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

    non_default_generation_parameters = self._get_non_default_generation_parameters()
    if len(non_default_generation_parameters) > 0:
        # TODO (joao): this should be an exception if the user has modified the loaded config. See #33886
        warnings.warn(
            "Some non-default generation parameters are set in the model config. These should go into either a) "
            "`model.generation_config` (as opposed to `model.config`); OR b) a GenerationConfig file "
            "(https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model)."
            "This warning will become an exception in the future."
            f"\nNon-default generation parameters: {str(non_default_generation_parameters)}",
            UserWarning,
        )

    os.makedirs(save_directory, exist_ok=True)

    if push_to_hub:
        # Pop hub-only kwargs before they reach `_upload_modified_files`;
        # `repo_id` defaults to the last path component of `save_directory`.
        commit_message = kwargs.pop("commit_message", None)
        repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
        repo_id = self._create_repo(repo_id, **kwargs)
        # Snapshot file mtimes so only files modified below are uploaded.
        files_timestamps = self._get_files_timestamps(save_directory)

    # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
    # loaded from the Hub.
    if self._auto_class is not None:
        custom_object_save(self, save_directory, config=self)

    # If we save using the predefined names, we can load using `from_pretrained`
    output_config_file = os.path.join(save_directory, CONFIG_NAME)

    # `use_diff=True`: only serialize values that differ from the class defaults.
    self.to_json_file(output_config_file, use_diff=True)
    logger.info(f"Configuration saved in {output_config_file}")

    if push_to_hub:
        self._upload_modified_files(
            save_directory,
            repo_id,
            files_timestamps,
            commit_message=commit_message,
            token=kwargs.get("token"),
        )
|
| 430 |
+
|
| 431 |
+
@staticmethod
def _set_token_in_kwargs(kwargs, token=None):
    """Temporary helper that folds the deprecated `use_auth_token` kwarg into `token`, in place.

    Exists so model config classes that override `from_pretrained` don't each have to
    repeat this handling. `use_auth_token` clean-up is deferred to a follow-up PR.
    """
    if token is None:
        token = kwargs.pop("token", None)
    legacy_token = kwargs.pop("use_auth_token", None)

    if legacy_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
            FutureWarning,
        )
        if token is not None:
            raise ValueError(
                "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
            )
        token = legacy_token

    if token is not None:
        kwargs["token"] = token
|
| 457 |
+
|
| 458 |
+
@classmethod
def from_pretrained(
    cls,
    pretrained_model_name_or_path: Union[str, os.PathLike],
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    local_files_only: bool = False,
    token: Optional[Union[str, bool]] = None,
    revision: str = "main",
    **kwargs,
) -> "PretrainedConfig":
    r"""
    Instantiate a [`PretrainedConfig`] (or a derived class) from a pretrained model configuration.

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
              huggingface.co.
            - a path to a *directory* containing a configuration file saved using the
              [`~PretrainedConfig.save_pretrained`] method, e.g., `./my_model_directory/`.
            - a path or url to a saved configuration JSON *file*, e.g., `./my_model_directory/configuration.json`.
        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the
            standard cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force to (re-)download the configuration files and override the cached versions if
            they exist.
        resume_download:
            Deprecated and ignored. All downloads are now resumed by default when possible.
            Will be removed in v5 of Transformers.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        token (`str` or `bool`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
            the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.

            <Tip>

            To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.

            </Tip>

        return_unused_kwargs (`bool`, *optional*, defaults to `False`):
            If `False`, then this function returns just the final configuration object.

            If `True`, then this functions returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a
            dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the
            part of `kwargs` which has not been used to update `config` and is otherwise ignored.
        subfolder (`str`, *optional*, defaults to `""`):
            In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
            specify the folder name here.
        kwargs (`Dict[str, Any]`, *optional*):
            The values in kwargs of any keys which are configuration attributes will be used to override the loaded
            values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
            by the `return_unused_kwargs` keyword parameter.

    Returns:
        [`PretrainedConfig`]: The configuration object instantiated from this pretrained model.

    Examples:

    ```python
    # We can't instantiate directly the base class *PretrainedConfig* so let's show the examples on a
    # derived class: BertConfig
    config = BertConfig.from_pretrained(
        "google-bert/bert-base-uncased"
    )  # Download configuration from huggingface.co and cache.
    config = BertConfig.from_pretrained(
        "./test/saved_model/"
    )  # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')*
    config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
    config = BertConfig.from_pretrained("google-bert/bert-base-uncased", output_attentions=True, foo=False)
    assert config.output_attentions == True
    config, unused_kwargs = BertConfig.from_pretrained(
        "google-bert/bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
    )
    assert config.output_attentions == True
    assert unused_kwargs == {"foo": False}
    ```"""
    # Fold the explicit download options back into `kwargs` so the whole bundle can be forwarded to
    # `get_config_dict` (which pops what it needs) in one pass.
    kwargs["cache_dir"] = cache_dir
    kwargs["force_download"] = force_download
    kwargs["local_files_only"] = local_files_only
    kwargs["revision"] = revision

    # Normalizes deprecated `use_auth_token` into `kwargs["token"]`.
    cls._set_token_in_kwargs(kwargs, token)

    config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
    # Composite configs (e.g. vision-language models) may nest this class's config under a sub-key.
    if cls.base_config_key and cls.base_config_key in config_dict:
        config_dict = config_dict[cls.base_config_key]

    if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
        # sometimes the config has no `base_config_key` if the config is used in several composite models
        # e.g. LlamaConfig. In that case we try to see if there is match in `model_type` before raising a warning
        for k, v in config_dict.items():
            if isinstance(v, dict) and v.get("model_type") == cls.model_type:
                config_dict = v

        # raise warning only if we still can't see a match in `model_type`
        if config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

    return cls.from_dict(config_dict, **kwargs)
|
| 570 |
+
|
| 571 |
+
@classmethod
def get_config_dict(
    cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
    [`PretrainedConfig`] using `from_dict`.

    Parameters:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.

    Returns:
        `Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the configuration object.

    """
    cls._set_token_in_kwargs(kwargs)

    # Keep a pristine copy: `_get_config_dict` pops entries from `kwargs`, and we may need the full set
    # again for a second resolution pass below.
    original_kwargs = copy.deepcopy(kwargs)
    # Get config dict associated with the base config file
    config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
    if config_dict is None:
        # File could not be resolved; return an empty config dict rather than failing here.
        return {}, kwargs
    if "_commit_hash" in config_dict:
        # Pin the second pass (if any) to the same repo revision that was just resolved.
        original_kwargs["_commit_hash"] = config_dict["_commit_hash"]

    # That config file may point us toward another config file to use.
    if "configuration_files" in config_dict:
        configuration_file = get_configuration_file(config_dict["configuration_files"])
        config_dict, kwargs = cls._get_config_dict(
            pretrained_model_name_or_path, _configuration_file=configuration_file, **original_kwargs
        )

    return config_dict, kwargs
|
| 605 |
+
|
| 606 |
+
@classmethod
def _get_config_dict(
    cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    # Resolve `pretrained_model_name_or_path` (local file, local folder, URL, or Hub repo id) to a raw
    # config dict. Returns `(config_dict, remaining_kwargs)`, or `(None, kwargs)` when the config file
    # could not be located.
    # Pop every download-related option so the leftover `kwargs` can be returned untouched.
    cache_dir = kwargs.pop("cache_dir", None)
    force_download = kwargs.pop("force_download", False)
    resume_download = kwargs.pop("resume_download", None)
    proxies = kwargs.pop("proxies", None)
    token = kwargs.pop("token", None)
    local_files_only = kwargs.pop("local_files_only", False)
    revision = kwargs.pop("revision", None)
    trust_remote_code = kwargs.pop("trust_remote_code", None)
    subfolder = kwargs.pop("subfolder", "")
    from_pipeline = kwargs.pop("_from_pipeline", None)
    from_auto_class = kwargs.pop("_from_auto", False)
    commit_hash = kwargs.pop("_commit_hash", None)

    # Peeked at but deliberately not popped: downstream callers still need `gguf_file` in `kwargs`.
    gguf_file = kwargs.get("gguf_file", None)

    if trust_remote_code is True:
        logger.warning(
            "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
            " ignored."
        )

    # Telemetry payload attached to the download request.
    user_agent = {"file_type": "config", "from_auto_class": from_auto_class}
    if from_pipeline is not None:
        user_agent["using_pipeline"] = from_pipeline

    pretrained_model_name_or_path = str(pretrained_model_name_or_path)

    is_local = os.path.isdir(pretrained_model_name_or_path)
    if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
        # Special case when pretrained_model_name_or_path is a local file
        resolved_config_file = pretrained_model_name_or_path
        is_local = True
    elif is_remote_url(pretrained_model_name_or_path):
        configuration_file = pretrained_model_name_or_path if gguf_file is None else gguf_file
        resolved_config_file = download_url(pretrained_model_name_or_path)
    else:
        configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME) if gguf_file is None else gguf_file

        try:
            # Load from local folder or from cache or download from model Hub and cache
            resolved_config_file = cached_file(
                pretrained_model_name_or_path,
                configuration_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                token=token,
                user_agent=user_agent,
                revision=revision,
                subfolder=subfolder,
                _commit_hash=commit_hash,
            )
            if resolved_config_file is None:
                return None, kwargs
            commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
        except EnvironmentError:
            # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
            # the original exception.
            raise
        except Exception:
            # For any other exception, we throw a generic error.
            raise EnvironmentError(
                f"Can't load the configuration of '{pretrained_model_name_or_path}'. If you were trying to load it"
                " from 'https://huggingface.co/models', make sure you don't have a local directory with the same"
                f" name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory"
                f" containing a {configuration_file} file"
            )

    try:
        if gguf_file:
            # GGUF checkpoints embed the config; extract it without materializing the tensors.
            config_dict = load_gguf_checkpoint(resolved_config_file, return_tensors=False)["config"]
        else:
            # Load config dict
            config_dict = cls._dict_from_json_file(resolved_config_file)

        config_dict["_commit_hash"] = commit_hash
    except (json.JSONDecodeError, UnicodeDecodeError):
        raise EnvironmentError(
            f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
        )

    if is_local:
        logger.info(f"loading configuration file {resolved_config_file}")
    else:
        logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")

    # For remote configs, qualify `auto_map` / `custom_pipelines` entries with the repo id so custom code
    # can later be fetched from the right repository.
    if "auto_map" in config_dict and not is_local:
        config_dict["auto_map"] = add_model_info_to_auto_map(
            config_dict["auto_map"], pretrained_model_name_or_path
        )
    if "custom_pipelines" in config_dict and not is_local:
        config_dict["custom_pipelines"] = add_model_info_to_custom_pipelines(
            config_dict["custom_pipelines"], pretrained_model_name_or_path
        )

    # timm models are not saved with the model_type in the config file
    if "model_type" not in config_dict and is_timm_config_dict(config_dict):
        config_dict["model_type"] = "timm_wrapper"

    return config_dict, kwargs
|
| 712 |
+
|
| 713 |
+
@classmethod
def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
    """
    Instantiates a [`PretrainedConfig`] from a Python dictionary of parameters.

    Args:
        config_dict (`Dict[str, Any]`):
            Dictionary that will be used to instantiate the configuration object. Such a dictionary can be
            retrieved from a pretrained checkpoint by leveraging the [`~PretrainedConfig.get_config_dict`] method.
        kwargs (`Dict[str, Any]`):
            Additional parameters from which to initialize the configuration object.

    Returns:
        [`PretrainedConfig`]: The configuration object instantiated from those parameters.
    """
    return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
    # Those arguments may be passed along for our internal telemetry.
    # We remove them so they don't appear in `return_unused_kwargs`.
    kwargs.pop("_from_auto", None)
    kwargs.pop("_from_pipeline", None)
    # The commit hash might have been updated in the `config_dict`, we don't want the kwargs to erase that update.
    if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
        kwargs["_commit_hash"] = config_dict["_commit_hash"]

    # We remove it from kwargs so that it does not appear in `return_unused_kwargs`.
    config_dict["attn_implementation"] = kwargs.pop("attn_implementation", None)

    config = cls(**config_dict)

    if hasattr(config, "pruned_heads"):
        # JSON serialization turns int keys into strings; restore layer indices as ints.
        config.pruned_heads = {int(key): value for key, value in config.pruned_heads.items()}

    # Update config with kwargs if needed
    if "num_labels" in kwargs and "id2label" in kwargs:
        num_labels = kwargs["num_labels"]
        id2label = kwargs["id2label"] if kwargs["id2label"] is not None else []
        if len(id2label) != num_labels:
            raise ValueError(
                f"You passed along `num_labels={num_labels }` with an incompatible id to label map: "
                f"{kwargs['id2label']}. Since those arguments are inconsistent with each other, you should remove "
                "one of them."
            )
    # Apply remaining kwargs as attribute overrides; record which ones were consumed.
    to_remove = []
    for key, value in kwargs.items():
        if hasattr(config, key):
            current_attr = getattr(config, key)
            # To authorize passing a custom subconfig as kwarg in models that have nested configs.
            if isinstance(current_attr, PretrainedConfig) and isinstance(value, dict):
                value = current_attr.__class__(**value)
            setattr(config, key, value)
            # `torch_dtype` is intentionally kept in kwargs so callers (e.g. model loading) can see it too.
            if key != "torch_dtype":
                to_remove.append(key)
    for key in to_remove:
        kwargs.pop(key, None)

    logger.info(f"Model config {config}")
    if return_unused_kwargs:
        return config, kwargs
    else:
        return config
|
| 773 |
+
|
| 774 |
+
@classmethod
def from_json_file(cls, json_file: Union[str, os.PathLike]) -> "PretrainedConfig":
    """
    Instantiates a [`PretrainedConfig`] from the path to a JSON file of parameters.

    Args:
        json_file (`str` or `os.PathLike`):
            Path to the JSON file containing the parameters.

    Returns:
        [`PretrainedConfig`]: The configuration object instantiated from that JSON file.

    """
    # Parse the file into a plain dict, then expand it into constructor arguments.
    return cls(**cls._dict_from_json_file(json_file))
|
| 789 |
+
|
| 790 |
+
@classmethod
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
    # Read the whole file as UTF-8 text and parse it as JSON into a plain dict.
    with open(json_file, "r", encoding="utf-8") as reader:
        return json.loads(reader.read())
|
| 795 |
+
|
| 796 |
+
def __eq__(self, other):
    # Configs are equal when the other object is also a PretrainedConfig and
    # every stored attribute matches.
    if not isinstance(other, PretrainedConfig):
        return False
    return self.__dict__ == other.__dict__
|
| 798 |
+
|
| 799 |
+
def __repr__(self):
    # e.g. `BertConfig { ...pretty-printed JSON... }`
    return "{} {}".format(self.__class__.__name__, self.to_json_string())
|
| 801 |
+
|
| 802 |
+
def __iter__(self):
    # Iterating a config yields its attribute names, mirroring dict iteration.
    yield from self.__dict__
|
| 805 |
+
|
| 806 |
+
def to_diff_dict(self) -> Dict[str, Any]:
    """
    Removes all attributes from config which correspond to the default config attributes for better readability and
    serializes to a Python dictionary.

    Returns:
        `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance,
    """
    config_dict = self.to_dict()

    # get the default config dict
    default_config_dict = PretrainedConfig().to_dict()

    # get class specific config dict
    class_config_dict = self.__class__().to_dict() if not self.is_composition else {}

    serializable_config_dict = {}

    # only serialize values that differ from the default config
    for key, value in config_dict.items():
        if (
            isinstance(getattr(self, key, None), PretrainedConfig)
            and key in class_config_dict
            and isinstance(class_config_dict[key], dict)
        ):
            # For nested configs we need to clean the diff recursively
            diff = recursive_diff_dict(value, class_config_dict[key], config_obj=getattr(self, key, None))
            if "model_type" in value:
                # Needs to be set even if it's not in the diff
                diff["model_type"] = value["model_type"]
            if len(diff) > 0:
                serializable_config_dict[key] = diff
        elif (
            key not in default_config_dict
            or key == "transformers_version"
            or value != default_config_dict[key]
            or (key in class_config_dict and value != class_config_dict[key])
        ):
            serializable_config_dict[key] = value

    if hasattr(self, "quantization_config"):
        # Quantization config may be an object or already a plain dict; always emit a dict.
        serializable_config_dict["quantization_config"] = (
            self.quantization_config.to_dict()
            if not isinstance(self.quantization_config, dict)
            else self.quantization_config
        )

    # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
    _ = serializable_config_dict.pop("_pre_quantization_dtype", None)

    # Convert any torch.dtype values (including nested ones) to plain strings for JSON.
    self.dict_torch_dtype_to_str(serializable_config_dict)

    if "_attn_implementation_internal" in serializable_config_dict:
        del serializable_config_dict["_attn_implementation_internal"]
    # Do not serialize `base_model_tp_plan` for now
    if "base_model_tp_plan" in serializable_config_dict:
        del serializable_config_dict["base_model_tp_plan"]

    return serializable_config_dict
|
| 865 |
+
|
| 866 |
+
def to_dict(self) -> Dict[str, Any]:
    """
    Serializes this instance to a Python dictionary.

    Returns:
        `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
    """
    output = copy.deepcopy(self.__dict__)
    if hasattr(self.__class__, "model_type"):
        # `model_type` lives on the class, not the instance dict, so add it explicitly.
        output["model_type"] = self.__class__.model_type
    # Strip internal bookkeeping attributes that should not be serialized.
    if "_auto_class" in output:
        del output["_auto_class"]
    if "_commit_hash" in output:
        del output["_commit_hash"]
    if "_attn_implementation_internal" in output:
        del output["_attn_implementation_internal"]
    # Do not serialize `base_model_tp_plan` for now
    if "base_model_tp_plan" in output:
        del output["base_model_tp_plan"]

    # Transformers version when serializing the model
    output["transformers_version"] = __version__

    for key, value in output.items():
        # Deal with nested configs like CLIP
        if isinstance(value, PretrainedConfig):
            value = value.to_dict()
            # Only record the version once, at the top level.
            del value["transformers_version"]

        output[key] = value

    if hasattr(self, "quantization_config"):
        # Quantization config may be an object or already a plain dict; always emit a dict.
        output["quantization_config"] = (
            self.quantization_config.to_dict()
            if not isinstance(self.quantization_config, dict)
            else self.quantization_config
        )

    # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
    _ = output.pop("_pre_quantization_dtype", None)

    # Convert any torch.dtype values (including nested ones) to plain strings for JSON.
    self.dict_torch_dtype_to_str(output)

    return output
|
| 910 |
+
|
| 911 |
+
def to_json_string(self, use_diff: bool = True) -> str:
    """
    Serializes this instance to a JSON string.

    Args:
        use_diff (`bool`, *optional*, defaults to `True`):
            If set to `True`, only the difference between the config instance and the default `PretrainedConfig()`
            is serialized to JSON string.

    Returns:
        `str`: String containing all the attributes that make up this configuration instance in JSON format.
    """
    # Serialize either the diff against defaults or the full attribute dict.
    payload = self.to_diff_dict() if use_diff is True else self.to_dict()
    return json.dumps(payload, indent=2, sort_keys=True) + "\n"
|
| 928 |
+
|
| 929 |
+
def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):
    """
    Save this instance to a JSON file.

    Args:
        json_file_path (`str` or `os.PathLike`):
            Path to the JSON file in which this configuration instance's parameters will be saved.
        use_diff (`bool`, *optional*, defaults to `True`):
            If set to `True`, only the difference between the config instance and the default `PretrainedConfig()`
            is serialized to JSON file.
    """
    # Render the JSON first, then write it out in one shot.
    serialized = self.to_json_string(use_diff=use_diff)
    with open(json_file_path, "w", encoding="utf-8") as fp:
        fp.write(serialized)
|
| 942 |
+
|
| 943 |
+
def update(self, config_dict: Dict[str, Any]):
    """
    Updates attributes of this class with attributes from `config_dict`.

    Args:
        config_dict (`Dict[str, Any]`): Dictionary of attributes that should be updated for this class.
    """
    # Plain attribute assignment; new keys create new attributes, existing ones are overwritten.
    for attribute, new_value in config_dict.items():
        setattr(self, attribute, new_value)
|
| 952 |
+
|
| 953 |
+
def update_from_string(self, update_str: str):
    """
    Updates attributes of this class with attributes from `update_str`.

    The expected format is ints, floats and strings as is, and for booleans use `true` or `false`. For example:
    "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"

    The keys to change have to already exist in the config object.

    Args:
        update_str (`str`): String with attributes that should be updated for this class.

    """

    # Parse "k1=v1,k2=v2,..." into a mapping (later duplicates win, as with dict()).
    assignments = dict(entry.split("=") for entry in update_str.split(","))
    for attr_name, raw_value in assignments.items():
        if not hasattr(self, attr_name):
            raise ValueError(f"key {attr_name} isn't in the original config dict")

        # Coerce the string value to the type of the attribute it replaces.
        # bool is checked before int since bool is a subclass of int.
        current_value = getattr(self, attr_name)
        if isinstance(current_value, bool):
            lowered = raw_value.lower()
            if lowered in ["true", "1", "y", "yes"]:
                converted = True
            elif lowered in ["false", "0", "n", "no"]:
                converted = False
            else:
                raise ValueError(f"can't derive true or false from {raw_value} (key {attr_name})")
        elif isinstance(current_value, int):
            converted = int(raw_value)
        elif isinstance(current_value, float):
            converted = float(raw_value)
        elif isinstance(current_value, str):
            converted = raw_value
        else:
            raise TypeError(
                f"You can only update int, float, bool or string values in the config, got {raw_value} for key {attr_name}"
            )

        setattr(self, attr_name, converted)
|
| 990 |
+
|
| 991 |
+
def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
    """
    Checks whether the passed dictionary and its nested dicts have a *torch_dtype* key and if it's not None,
    converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"*
    string, which can then be stored in the json format.
    """
    dtype_value = d.get("torch_dtype", None)
    if dtype_value is not None and not isinstance(dtype_value, str):
        # `str(torch.float32)` is "torch.float32"; keep only the part after the dot.
        d["torch_dtype"] = str(dtype_value).split(".")[1]
    # Recurse into nested dicts (e.g. sub-configs of composite models).
    for nested in d.values():
        if isinstance(nested, dict):
            self.dict_torch_dtype_to_str(nested)
|
| 1002 |
+
|
| 1003 |
+
@classmethod
def register_for_auto_class(cls, auto_class="AutoConfig"):
    """
    Register this class with a given auto class. This should only be used for custom configurations as the ones in
    the library are already mapped with `AutoConfig`.

    <Tip warning={true}>

    This API is experimental and may have some slight breaking changes in the next releases.

    </Tip>

    Args:
        auto_class (`str` or `type`, *optional*, defaults to `"AutoConfig"`):
            The auto class to register this new configuration with.
    """
    # Accept either the auto class object itself or its name.
    registered_name = auto_class if isinstance(auto_class, str) else auto_class.__name__

    import transformers.models.auto as auto_module

    # Validate against the real auto module so typos fail loudly at registration time.
    if not hasattr(auto_module, registered_name):
        raise ValueError(f"{registered_name} is not a valid auto class.")

    cls._auto_class = registered_name
|
| 1028 |
+
|
| 1029 |
+
@staticmethod
def _get_global_generation_defaults() -> Dict[str, Any]:
    # Library-wide default values for generation parameters, used to detect
    # generation options that were overridden on a model config.
    return dict(
        max_length=20,
        min_length=0,
        do_sample=False,
        early_stopping=False,
        num_beams=1,
        num_beam_groups=1,
        diversity_penalty=0.0,
        temperature=1.0,
        top_k=50,
        top_p=1.0,
        typical_p=1.0,
        repetition_penalty=1.0,
        length_penalty=1.0,
        no_repeat_ngram_size=0,
        encoder_no_repeat_ngram_size=0,
        bad_words_ids=None,
        num_return_sequences=1,
        output_scores=False,
        return_dict_in_generate=False,
        forced_bos_token_id=None,
        forced_eos_token_id=None,
        remove_invalid_values=False,
        exponential_decay_length_penalty=None,
        suppress_tokens=None,
        begin_suppress_tokens=None,
    )
|
| 1058 |
+
|
| 1059 |
+
def _get_non_default_generation_parameters(self) -> Dict[str, Any]:
    """
    Gets the non-default generation parameters on the PretrainedConfig instance
    """
    non_default_generation_parameters = {}
    # NOTE(review): `decoder_attribute_name` is initialized here and never reassigned, so the
    # `getattr(self, decoder_attribute_name)` branch below is currently unreachable — confirm intent.
    decoder_attribute_name = None

    # Composite models don't have a default config, use their decoder config as a fallback for default values
    # If no known pattern is matched, then `default_config = None` -> check against the global generation defaults
    try:
        default_config = self.__class__()
    except ValueError:
        decoder_config = self.get_text_config(decoder=True)
        if decoder_config is not self:
            default_config = decoder_config.__class__()
        else:
            default_config = None

    # If it is a composite model, we want to check the subconfig that will be used for generation
    self_decoder_config = self if decoder_attribute_name is None else getattr(self, decoder_attribute_name)

    for parameter_name, default_global_value in self._get_global_generation_defaults().items():
        if hasattr(self_decoder_config, parameter_name):
            is_default_in_config = is_default_generation_value = None
            parameter_value = getattr(self_decoder_config, parameter_name)
            # Three cases in which is okay for the model config to hold generation config parameters:
            # 1. The parameter is set to `None`, effectivelly delegating its value to the generation config
            if parameter_value is None:
                continue
            # 2. If we have a default config, then the instance should hold the same generation defaults
            if default_config is not None:
                is_default_in_config = parameter_value == getattr(default_config, parameter_name)
            # 3. if we don't have a default config, then the instance should hold the global generation defaults
            else:
                is_default_generation_value = parameter_value == default_global_value

            # Non-default when it differs from the class default, or (with no class default) from the
            # global generation default.
            is_non_default = (is_default_in_config is False) or (
                is_default_in_config is None and is_default_generation_value is False
            )
            if is_non_default:
                non_default_generation_parameters[parameter_name] = getattr(self_decoder_config, parameter_name)

    return non_default_generation_parameters
|
| 1102 |
+
|
| 1103 |
+
def get_text_config(self, decoder=False) -> "PretrainedConfig":
    """
    Returns the config that is meant to be used with text IO. On most models, it is the original config instance
    itself. On specific composite models, it is under a set of valid names.

    Args:
        decoder (`bool`, *optional*, defaults to `False`):
            If set to `True`, then only search for decoder config names (skip `text_encoder`).

    Returns:
        `PretrainedConfig`: the single matching sub-config, or `self` when no sub-config attribute is set.

    Raises:
        ValueError: if more than one candidate sub-config attribute is set, making the lookup ambiguous.
    """
    decoder_possible_text_config_names = ("decoder", "generator", "text_config")
    encoder_possible_text_config_names = ("text_encoder",)
    if decoder:
        possible_text_config_names = decoder_possible_text_config_names
    else:
        possible_text_config_names = encoder_possible_text_config_names + decoder_possible_text_config_names

    # Collect every candidate attribute that exists AND is not None.
    valid_text_config_names = []
    for text_config_name in possible_text_config_names:
        if hasattr(self, text_config_name):
            text_config = getattr(self, text_config_name, None)
            if text_config is not None:
                valid_text_config_names += [text_config_name]

    if len(valid_text_config_names) > 1:
        # Typo fix: "desied" -> "desired" in the user-facing error message.
        raise ValueError(
            f"Multiple valid text configs were found in the model config: {valid_text_config_names}. In this "
            "case, using `get_text_config()` would be ambiguous. Please specify the desired text config directly."
        )
    elif len(valid_text_config_names) == 1:
        return getattr(self, valid_text_config_names[0])
    return self
|
| 1132 |
+
|
| 1133 |
+
|
| 1134 |
+
def get_configuration_file(configuration_files: List[str]) -> str:
    """
    Get the configuration file to use for this version of transformers.

    Args:
        configuration_files (`List[str]`): The list of available configuration files.

    Returns:
        `str`: The configuration file to use.
    """
    # Map every versioned configuration file to the version it targets.
    version_to_file = {}
    for candidate in configuration_files:
        match = _re_configuration_file.search(candidate)
        if match is not None:
            version_to_file[match.groups()[0]] = candidate

    # Default to the plain CONFIG_NAME, then walk the versions in ascending
    # order and keep the newest one that is not ahead of our own version.
    configuration_file = CONFIG_NAME
    current_version = version.parse(__version__)
    for candidate_version in sorted(version_to_file.keys()):
        if version.parse(candidate_version) > current_version:
            # Versions are sorted, so nothing later can match either.
            break
        configuration_file = version_to_file[candidate_version]

    return configuration_file
|
| 1163 |
+
|
| 1164 |
+
|
| 1165 |
+
def recursive_diff_dict(dict_a, dict_b, config_obj=None):
    """
    Helper function to recursively take the diff between two nested dictionaries. The resulting diff only contains the
    values from `dict_a` that are different from values in `dict_b`.
    """
    # Defaults of the config class, used as a secondary comparison baseline.
    default = config_obj.__class__().to_dict() if config_obj is not None else {}
    diff = {}
    for key, value in dict_a.items():
        attr_value = getattr(config_obj, str(key), None)
        nested = isinstance(attr_value, PretrainedConfig) and key in dict_b and isinstance(dict_b[key], dict)
        if nested:
            # Recurse into sub-configs and only keep non-empty sub-diffs.
            nested_diff = recursive_diff_dict(value, dict_b[key], config_obj=attr_value)
            if nested_diff:
                diff[key] = nested_diff
        elif key not in dict_b or value != dict_b[key] or key not in default or value != default[key]:
            diff[key] = value
    return diff
|
| 1181 |
+
|
| 1182 |
+
|
| 1183 |
+
# Give `PretrainedConfig.push_to_hub` its own function object (copy_func
# presumably returns a fresh copy of the function — confirm) so that filling
# in the docstring template below does not affect other classes that share
# the original `push_to_hub` implementation.
PretrainedConfig.push_to_hub = copy_func(PretrainedConfig.push_to_hub)
if PretrainedConfig.push_to_hub.__doc__ is not None:
    PretrainedConfig.push_to_hub.__doc__ = PretrainedConfig.push_to_hub.__doc__.format(
        object="config", object_class="AutoConfig", object_files="configuration file"
    )
|
.venv/lib/python3.11/site-packages/transformers/convert_graph_to_onnx.py
ADDED
|
@@ -0,0 +1,551 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import warnings
|
| 16 |
+
from argparse import ArgumentParser
|
| 17 |
+
from os import listdir, makedirs
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from typing import Dict, List, Optional, Tuple
|
| 20 |
+
|
| 21 |
+
from packaging.version import Version, parse
|
| 22 |
+
|
| 23 |
+
from transformers.pipelines import Pipeline, pipeline
|
| 24 |
+
from transformers.tokenization_utils import BatchEncoding
|
| 25 |
+
from transformers.utils import ModelOutput, is_tf_available, is_torch_available
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# This is the minimal required version to
|
| 29 |
+
# support some ONNX Runtime features
|
| 30 |
+
ORT_QUANTIZE_MINIMUM_VERSION = parse("1.4.0")


# Pipeline task identifiers this conversion script accepts for `--pipeline`.
SUPPORTED_PIPELINES = [
    "feature-extraction",
    "ner",
    "sentiment-analysis",
    "fill-mask",
    "question-answering",
    "text-generation",
    "translation_en_to_fr",
    "translation_en_to_de",
    "translation_en_to_ro",
]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class OnnxConverterArgumentParser(ArgumentParser):
    """
    Argument parser for the ONNX export script: declares every CLI flag
    supported when exporting transformers models to the ONNX IR.
    """

    def __init__(self):
        super().__init__("ONNX Converter")

        self.add_argument("--pipeline", type=str, choices=SUPPORTED_PIPELINES, default="feature-extraction")
        self.add_argument(
            "--model", type=str, required=True, help="Model's id or path (ex: google-bert/bert-base-cased)"
        )
        self.add_argument("--tokenizer", type=str, help="Tokenizer's id or path (ex: google-bert/bert-base-cased)")
        self.add_argument("--framework", type=str, choices=["pt", "tf"], help="Framework for loading the model")
        self.add_argument("--opset", type=int, default=11, help="ONNX opset to use")
        self.add_argument("--check-loading", action="store_true", help="Check ONNX is able to load the model")
        self.add_argument("--use-external-format", action="store_true", help="Allow exporting model >= than 2Gb")
        self.add_argument("--quantize", action="store_true", help="Quantize the neural network to be run with int8")
        # Positional argument: destination path for the exported graph.
        self.add_argument("output")
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def generate_identified_filename(filename: Path, identifier: str) -> Path:
    """
    Append a string-identifier at the end (before the extension, if any) to the provided filepath.

    Args:
        filename: pathlib.Path the identifier suffix should be appended to
        identifier: The suffix to add

    Returns: Path with the identifier inserted between the stem and the extension
    """
    stem_with_identifier = filename.stem + identifier
    return (filename.parent / stem_with_identifier).with_suffix(filename.suffix)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def check_onnxruntime_requirements(minimum_version: Version):
    """
    Check onnxruntime is installed and if the installed version match is recent enough

    Args:
        minimum_version: Minimum onnxruntime version required (used in the error message).

    Raises:
        ImportError: If onnxruntime is not installed or too old version is found
    """
    try:
        import onnxruntime
    except ImportError:
        raise ImportError(
            "onnxruntime doesn't seem to be currently installed. "
            "Please install the onnxruntime by running `pip install onnxruntime`"
            " and relaunch the conversion."
        )

    # Parse the version of the installed onnxruntime
    ort_version = parse(onnxruntime.__version__)

    # We require 1.4.0 minimum.
    # BUGFIX: this check used to live inside the `try` block above, so the
    # "too old version" ImportError was swallowed by `except ImportError`
    # and misreported as "onnxruntime doesn't seem to be installed".
    if ort_version < ORT_QUANTIZE_MINIMUM_VERSION:
        raise ImportError(
            f"We found an older version of onnxruntime ({onnxruntime.__version__}) "
            f"but we require onnxruntime to be >= {minimum_version} to enable all the conversions options.\n"
            "Please update onnxruntime by running `pip install --upgrade onnxruntime`"
        )
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def ensure_valid_input(model, tokens, input_names):
    """
    Ensure inputs are presented in the correct order, without any None

    Args:
        model: The model used to forward the input data
        tokens: BatchEncoding holding the input data
        input_names: The name of the inputs

    Returns: Tuple of (ordered input names, ordered input tensors)
    """
    print("Ensuring inputs are in correct order")

    forward_parameters = model.forward.__code__.co_varnames
    ordered_input_names = []
    model_args = []
    # Walk forward()'s parameters (index 0 is "self") and collect arguments in
    # the order the model expects; stop at the first one we cannot supply.
    for parameter in forward_parameters[1:]:
        if parameter not in input_names:
            print(f"{parameter} is not present in the generated input list.")
            break
        ordered_input_names.append(parameter)
        model_args.append(tokens[parameter])

    print(f"Generated inputs order: {ordered_input_names}")
    return ordered_input_names, tuple(model_args)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], Dict, BatchEncoding]:
    """
    Attempt to infer the static vs dynamic axes for each input and output tensors for a specific model

    Args:
        nlp: The pipeline object holding the model to be exported
        framework: The framework identifier to dispatch to the correct inference scheme (pt/tf)

    Returns:

        - List of the inferred input variable names
        - List of the inferred output variable names
        - Dictionary with input/output variables names as key and shape tensor as value
        - a BatchEncoding reference which was used to infer all the above information
    """

    def build_shape_dict(name: str, tensor, is_input: bool, seq_len: int):
        # Nested structures (e.g. past key/values) are processed element-wise.
        if isinstance(tensor, (tuple, list)):
            return [build_shape_dict(name, t, is_input, seq_len) for t in tensor]

        # Let's assume batch is the first axis with only 1 element (~~ might not be always true ...)
        batch_axis = [axis for axis, numel in enumerate(tensor.shape) if numel == 1][0]
        axes = {batch_axis: "batch"}
        if is_input:
            if len(tensor.shape) != 2:
                raise ValueError(f"Unable to infer tensor axes ({len(tensor.shape)})")
            axes[1] = "sequence"
        else:
            # Every output axis whose size matches the sample sequence length
            # is assumed to be the dynamic "sequence" axis.
            for dim, dim_size in enumerate(tensor.shape):
                if dim_size == seq_len:
                    axes[dim] = "sequence"

        print(f"Found {'input' if is_input else 'output'} {name} with shape: {axes}")
        return axes

    tokens = nlp.tokenizer("This is a sample output", return_tensors=framework)
    seq_len = tokens.input_ids.shape[-1]
    outputs = nlp.model(**tokens) if framework == "pt" else nlp.model(tokens)
    if isinstance(outputs, ModelOutput):
        outputs = outputs.to_tuple()
    if not isinstance(outputs, (list, tuple)):
        outputs = (outputs,)

    # Generate input names & axes
    input_vars = list(tokens.keys())
    input_dynamic_axes = {name: build_shape_dict(name, tensor, True, seq_len) for name, tensor in tokens.items()}

    # flatten potentially grouped outputs (past for gpt2, attentions)
    outputs_flat = []
    for output in outputs:
        if isinstance(output, (tuple, list)):
            outputs_flat.extend(output)
        else:
            outputs_flat.append(output)

    # Generate output names & axes
    output_names = [f"output_{i}" for i in range(len(outputs_flat))]
    output_dynamic_axes = {
        name: build_shape_dict(name, tensor, False, seq_len) for name, tensor in zip(output_names, outputs_flat)
    }

    # Create the aggregated axes representation
    dynamic_axes = {**input_dynamic_axes, **output_dynamic_axes}
    return input_vars, output_names, dynamic_axes, tokens
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def load_graph_from_args(
    pipeline_name: str, framework: str, model: str, tokenizer: Optional[str] = None, **models_kwargs
) -> Pipeline:
    """
    Convert the set of arguments provided through the CLI to an actual pipeline reference (tokenizer + model)

    Args:
        pipeline_name: The kind of pipeline to use (ner, question-answering, etc.)
        framework: The actual model to convert the pipeline from ("pt" or "tf")
        model: The model name which will be loaded by the pipeline
        tokenizer: The tokenizer name which will be loaded by the pipeline, default to the model's value

    Returns: Pipeline object
    """
    # Default the tokenizer to the model identifier when none was supplied.
    tokenizer = model if tokenizer is None else tokenizer

    # Check the wanted framework is available
    if framework == "pt" and not is_torch_available():
        raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")
    if framework == "tf" and not is_tf_available():
        raise Exception("Cannot convert because TF is not installed. Please install tensorflow first.")

    print(f"Loading pipeline (model: {model}, tokenizer: {tokenizer})")

    # Allocate tokenizer and model
    return pipeline(pipeline_name, model=model, tokenizer=tokenizer, framework=framework, model_kwargs=models_kwargs)
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format: bool):
    """
    Export a PyTorch backed pipeline to ONNX Intermediate Representation (IR).

    Args:
        nlp: The pipeline to be exported
        opset: The actual version of the ONNX operator set to use
        output: Path where will be stored the generated ONNX model
        use_external_format: Split the model definition from its parameters to allow model bigger than 2GB

    Returns:
        None. The ONNX graph is written to `output` as a side effect.
    """
    if not is_torch_available():
        raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")

    import torch
    from torch.onnx import export

    print(f"Using framework PyTorch: {torch.__version__}")

    # Trace/export without recording autograd state.
    with torch.no_grad():
        # Infer dynamic axes from a sample forward pass, then order the
        # sample tensors to match the model's forward() signature.
        input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "pt")
        ordered_input_names, model_args = ensure_valid_input(nlp.model, tokens, input_names)

        # NOTE(review): `use_external_format` is accepted but not forwarded
        # to `export()` in this body — confirm whether that is intentional.
        export(
            nlp.model,
            model_args,
            f=output.as_posix(),
            input_names=ordered_input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            do_constant_folding=True,
            opset_version=opset,
        )
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def convert_tensorflow(nlp: Pipeline, opset: int, output: Path):
    """
    Export a TensorFlow backed pipeline to ONNX Intermediate Representation (IR)

    Args:
        nlp: The pipeline to be exported
        opset: The actual version of the ONNX operator set to use
        output: Path where will be stored the generated ONNX model

    Notes: TensorFlow cannot export model bigger than 2GB due to internal constraint from TensorFlow
    """
    if not is_tf_available():
        raise Exception("Cannot convert because TF is not installed. Please install tensorflow first.")

    print("/!\\ Please note TensorFlow doesn't support exporting model > 2Gb /!\\")

    try:
        import tensorflow as tf
        import tf2onnx
        from tf2onnx import __version__ as t2ov

        print(f"Using framework TensorFlow: {tf.version.VERSION}, tf2onnx: {t2ov}")

        # Build: infer dynamic axes from a sample encoding.
        input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "tf")

        # Forward: run a prediction once so the Keras model is built/traced.
        nlp.model.predict(tokens.data)
        input_signature = [tf.TensorSpec.from_tensor(tensor, name=key) for key, tensor in tokens.items()]
        # tf2onnx writes the graph to `output_path` directly; the returned
        # proto is not used further here.
        model_proto, _ = tf2onnx.convert.from_keras(
            nlp.model, input_signature, opset=opset, output_path=output.as_posix()
        )

    except ImportError as e:
        raise Exception(
            f"Cannot import {e.name} required to convert TF model to ONNX. Please install {e.name} first. {e}"
        )
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def convert(
    framework: str,
    model: str,
    output: Path,
    opset: int,
    tokenizer: Optional[str] = None,
    use_external_format: bool = False,
    pipeline_name: str = "feature-extraction",
    **model_kwargs,
):
    """
    Convert the pipeline object to the ONNX Intermediate Representation (IR) format

    Args:
        framework: The framework the pipeline is backed by ("pt" or "tf")
        model: The name of the model to load for the pipeline
        output: The path where the ONNX graph will be stored
        opset: The actual version of the ONNX operator set to use
        tokenizer: The name of the model to load for the pipeline, default to the model's name if not provided
        use_external_format:
            Split the model definition from its parameters to allow model bigger than 2GB (PyTorch only)
        pipeline_name: The kind of pipeline to instantiate (ner, question-answering, etc.)
        model_kwargs: Keyword arguments to be forwarded to the model constructor

    Returns:
        None. The exported graph is written to `output`.
    """
    warnings.warn(
        "The `transformers.convert_graph_to_onnx` package is deprecated and will be removed in version 5 of"
        " Transformers",
        FutureWarning,
    )
    print(f"ONNX opset version set to: {opset}")

    # Load the pipeline
    nlp = load_graph_from_args(pipeline_name, framework, model, tokenizer, **model_kwargs)

    parent_dir = output.parent
    if parent_dir.exists():
        # Refuse to write into a destination folder that already has content.
        if len(listdir(parent_dir.as_posix())) > 0:
            raise Exception(f"Folder {parent_dir.as_posix()} is not empty, aborting conversion")
    else:
        print(f"Creating folder {parent_dir}")
        makedirs(parent_dir.as_posix())

    # Export the graph
    if framework == "pt":
        convert_pytorch(nlp, opset, output, use_external_format)
    else:
        convert_tensorflow(nlp, opset, output)
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
def optimize(onnx_model_path: Path) -> Path:
    """
    Load the model at the specified path and let onnxruntime look at transformations on the graph to enable all the
    optimizations possible

    Args:
        onnx_model_path: filepath where the model binary description is stored

    Returns: Path where the optimized model binary description has been saved
    """
    from onnxruntime import InferenceSession, SessionOptions

    # Generate model name with suffix "optimized"
    opt_model_path = generate_identified_filename(onnx_model_path, "-optimized")
    session_options = SessionOptions()
    session_options.optimized_model_filepath = opt_model_path.as_posix()
    # Creating the session triggers graph optimization and writes the
    # optimized model to `optimized_model_filepath` as a side effect.
    InferenceSession(onnx_model_path.as_posix(), session_options)

    print(f"Optimized model has been written at {opt_model_path}: \N{HEAVY CHECK MARK}")
    print("/!\\ Optimized model contains hardware specific operators which might not be portable. /!\\")

    return opt_model_path
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
def quantize(onnx_model_path: Path) -> Path:
    """
    Quantize the weights of the model from float32 to in8 to allow very efficient inference on modern CPU

    Args:
        onnx_model_path: Path to location the exported ONNX model is stored

    Returns: The Path generated for the quantized model
    """
    import onnx
    import onnxruntime
    from onnx.onnx_pb import ModelProto
    from onnxruntime.quantization import QuantizationMode
    from onnxruntime.quantization.onnx_quantizer import ONNXQuantizer
    from onnxruntime.quantization.registry import IntegerOpsRegistry

    # Load the ONNX model
    onnx_model = onnx.load(onnx_model_path.as_posix())

    if parse(onnx.__version__) < parse("1.5.0"):
        print(
            "Models larger than 2GB will fail to quantize due to protobuf constraint.\n"
            "Please upgrade to onnxruntime >= 1.5.0."
        )

    # Work on a copy so the loaded model proto stays untouched.
    copy_model = ModelProto()
    copy_model.CopyFrom(onnx_model)

    # Construct quantizer
    # onnxruntime renamed input_qType to activation_qType in v1.13.1, so we
    # check the onnxruntime version to ensure backward compatibility.
    # See also: https://github.com/microsoft/onnxruntime/pull/12873
    quantizer_kwargs = {
        "model": copy_model,
        "per_channel": False,
        "reduce_range": False,
        "mode": QuantizationMode.IntegerOps,
        "static": False,
        "weight_qType": True,
        "tensors_range": None,
        "nodes_to_quantize": None,
        "nodes_to_exclude": None,
        "op_types_to_quantize": list(IntegerOpsRegistry),
    }
    if parse(onnxruntime.__version__) < parse("1.13.1"):
        quantizer_kwargs["input_qType"] = False
    else:
        quantizer_kwargs["activation_qType"] = False
    quantizer = ONNXQuantizer(**quantizer_kwargs)

    # Quantize and export
    quantizer.quantize_model()

    # Append "-quantized" at the end of the model's name
    quantized_model_path = generate_identified_filename(onnx_model_path, "-quantized")

    # Save model
    print(f"Quantized model has been written at {quantized_model_path}: \N{HEAVY CHECK MARK}")
    onnx.save_model(quantizer.model.model, quantized_model_path.as_posix())

    return quantized_model_path
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
def verify(path: Path):
    """
    Sanity-check that onnxruntime can load the ONNX model stored at `path`,
    printing the outcome instead of raising on load failure.
    """
    from onnxruntime import InferenceSession, SessionOptions
    from onnxruntime.capi.onnxruntime_pybind11_state import RuntimeException

    print(f"Checking ONNX model loading from: {path} ...")
    try:
        session_options = SessionOptions()
        InferenceSession(path.as_posix(), session_options, providers=["CPUExecutionProvider"])
        print(f"Model {path} correctly loaded: \N{HEAVY CHECK MARK}")
    except RuntimeException as re:
        print(f"Error while loading the model {re}: \N{HEAVY BALLOT X}")
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
# Script entry point: parse CLI arguments, export the model to ONNX, then
# optionally optimize + quantize it and verify that the result loads.
if __name__ == "__main__":
    parser = OnnxConverterArgumentParser()
    args = parser.parse_args()

    # Make sure output is absolute path
    args.output = Path(args.output).absolute()

    try:
        print("\n====== Converting model to ONNX ======")
        # Convert
        convert(
            args.framework,
            args.model,
            args.output,
            args.opset,
            args.tokenizer,
            args.use_external_format,
            args.pipeline,
        )

        if args.quantize:
            # Ensure requirements for quantization on onnxruntime is met
            check_onnxruntime_requirements(ORT_QUANTIZE_MINIMUM_VERSION)

            # onnxruntime optimizations doesn't provide the same level of performances on TensorFlow than PyTorch
            if args.framework == "tf":
                print(
                    "\t Using TensorFlow might not provide the same optimization level compared to PyTorch.\n"
                    "\t For TensorFlow users you can try optimizing the model directly through onnxruntime_tools.\n"
                    "\t For more information, please refer to the onnxruntime documentation:\n"
                    "\t\thttps://github.com/microsoft/onnxruntime/tree/master/onnxruntime/python/tools/transformers\n"
                )

            print("\n====== Optimizing ONNX model ======")

            # Quantization works best when using the optimized version of the model
            args.optimized_output = optimize(args.output)

            # Do the quantization on the right graph
            args.quantized_output = quantize(args.optimized_output)

        # And verify
        if args.check_loading:
            print("\n====== Check exported ONNX model(s) ======")
            verify(args.output)

            # These attributes only exist when --quantize was requested.
            if hasattr(args, "optimized_output"):
                verify(args.optimized_output)

            if hasattr(args, "quantized_output"):
                verify(args.quantized_output)

    # Broad catch at the top-level CLI boundary: report and exit non-zero.
    except Exception as e:
        print(f"Error while converting the model: {e}")
        exit(1)
|
.venv/lib/python3.11/site-packages/transformers/convert_pytorch_checkpoint_to_tf2.py
ADDED
|
@@ -0,0 +1,446 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Convert pytorch checkpoints to TensorFlow"""
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import os
|
| 19 |
+
|
| 20 |
+
from . import (
|
| 21 |
+
AlbertConfig,
|
| 22 |
+
BartConfig,
|
| 23 |
+
BertConfig,
|
| 24 |
+
CamembertConfig,
|
| 25 |
+
CTRLConfig,
|
| 26 |
+
DistilBertConfig,
|
| 27 |
+
DPRConfig,
|
| 28 |
+
ElectraConfig,
|
| 29 |
+
FlaubertConfig,
|
| 30 |
+
GPT2Config,
|
| 31 |
+
LayoutLMConfig,
|
| 32 |
+
LxmertConfig,
|
| 33 |
+
OpenAIGPTConfig,
|
| 34 |
+
RobertaConfig,
|
| 35 |
+
T5Config,
|
| 36 |
+
TFAlbertForPreTraining,
|
| 37 |
+
TFBartForConditionalGeneration,
|
| 38 |
+
TFBartForSequenceClassification,
|
| 39 |
+
TFBertForPreTraining,
|
| 40 |
+
TFBertForQuestionAnswering,
|
| 41 |
+
TFBertForSequenceClassification,
|
| 42 |
+
TFCamembertForMaskedLM,
|
| 43 |
+
TFCTRLLMHeadModel,
|
| 44 |
+
TFDistilBertForMaskedLM,
|
| 45 |
+
TFDistilBertForQuestionAnswering,
|
| 46 |
+
TFDPRContextEncoder,
|
| 47 |
+
TFDPRQuestionEncoder,
|
| 48 |
+
TFDPRReader,
|
| 49 |
+
TFElectraForPreTraining,
|
| 50 |
+
TFFlaubertWithLMHeadModel,
|
| 51 |
+
TFGPT2LMHeadModel,
|
| 52 |
+
TFLayoutLMForMaskedLM,
|
| 53 |
+
TFLxmertForPreTraining,
|
| 54 |
+
TFLxmertVisualFeatureEncoder,
|
| 55 |
+
TFOpenAIGPTLMHeadModel,
|
| 56 |
+
TFRobertaForCausalLM,
|
| 57 |
+
TFRobertaForMaskedLM,
|
| 58 |
+
TFRobertaForSequenceClassification,
|
| 59 |
+
TFT5ForConditionalGeneration,
|
| 60 |
+
TFTransfoXLLMHeadModel,
|
| 61 |
+
TFWav2Vec2Model,
|
| 62 |
+
TFXLMRobertaForMaskedLM,
|
| 63 |
+
TFXLMWithLMHeadModel,
|
| 64 |
+
TFXLNetLMHeadModel,
|
| 65 |
+
TransfoXLConfig,
|
| 66 |
+
Wav2Vec2Config,
|
| 67 |
+
Wav2Vec2Model,
|
| 68 |
+
XLMConfig,
|
| 69 |
+
XLMRobertaConfig,
|
| 70 |
+
XLNetConfig,
|
| 71 |
+
is_torch_available,
|
| 72 |
+
load_pytorch_checkpoint_in_tf2_model,
|
| 73 |
+
)
|
| 74 |
+
from .utils import CONFIG_NAME, WEIGHTS_NAME, cached_file, logging
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
if is_torch_available():
|
| 78 |
+
import numpy as np
|
| 79 |
+
import torch
|
| 80 |
+
|
| 81 |
+
from . import (
|
| 82 |
+
AlbertForPreTraining,
|
| 83 |
+
BartForConditionalGeneration,
|
| 84 |
+
BertForPreTraining,
|
| 85 |
+
BertForQuestionAnswering,
|
| 86 |
+
BertForSequenceClassification,
|
| 87 |
+
CamembertForMaskedLM,
|
| 88 |
+
CTRLLMHeadModel,
|
| 89 |
+
DistilBertForMaskedLM,
|
| 90 |
+
DistilBertForQuestionAnswering,
|
| 91 |
+
DPRContextEncoder,
|
| 92 |
+
DPRQuestionEncoder,
|
| 93 |
+
DPRReader,
|
| 94 |
+
ElectraForPreTraining,
|
| 95 |
+
FlaubertWithLMHeadModel,
|
| 96 |
+
GPT2LMHeadModel,
|
| 97 |
+
LayoutLMForMaskedLM,
|
| 98 |
+
LxmertForPreTraining,
|
| 99 |
+
LxmertVisualFeatureEncoder,
|
| 100 |
+
OpenAIGPTLMHeadModel,
|
| 101 |
+
RobertaForMaskedLM,
|
| 102 |
+
RobertaForSequenceClassification,
|
| 103 |
+
T5ForConditionalGeneration,
|
| 104 |
+
TransfoXLLMHeadModel,
|
| 105 |
+
XLMRobertaForMaskedLM,
|
| 106 |
+
XLMWithLMHeadModel,
|
| 107 |
+
XLNetLMHeadModel,
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
# Emit INFO-level progress messages from the conversion utilities.
logging.set_verbosity_info()

# Maps a model-type key (or a specific fine-tuned checkpoint name) to the
# classes needed for conversion. Most values are 3-tuples of
# (config class, TF model class, PyTorch model class); "bart" and "roberta"
# carry an extra TF head class, and "dpr" lists all three TF/PT
# encoder/reader pairs.
# NOTE(review): the consumers below unpack 4 and 5 items from these tuples,
# which does not match most entries here — looks like a leftover from when
# the table also carried shortcut->URL maps; confirm against upstream history.
MODEL_CLASSES = {
    "bart": (BartConfig, TFBartForConditionalGeneration, TFBartForSequenceClassification, BartForConditionalGeneration),
    "bert": (BertConfig, TFBertForPreTraining, BertForPreTraining),
    "google-bert/bert-large-uncased-whole-word-masking-finetuned-squad": (BertConfig, TFBertForQuestionAnswering, BertForQuestionAnswering),
    "google-bert/bert-large-cased-whole-word-masking-finetuned-squad": (BertConfig, TFBertForQuestionAnswering, BertForQuestionAnswering),
    "google-bert/bert-base-cased-finetuned-mrpc": (BertConfig, TFBertForSequenceClassification, BertForSequenceClassification),
    "dpr": (DPRConfig, TFDPRQuestionEncoder, TFDPRContextEncoder, TFDPRReader, DPRQuestionEncoder, DPRContextEncoder, DPRReader),
    "openai-community/gpt2": (GPT2Config, TFGPT2LMHeadModel, GPT2LMHeadModel),
    "xlnet": (XLNetConfig, TFXLNetLMHeadModel, XLNetLMHeadModel),
    "xlm": (XLMConfig, TFXLMWithLMHeadModel, XLMWithLMHeadModel),
    "xlm-roberta": (XLMRobertaConfig, TFXLMRobertaForMaskedLM, XLMRobertaForMaskedLM),
    "transfo-xl": (TransfoXLConfig, TFTransfoXLLMHeadModel, TransfoXLLMHeadModel),
    "openai-community/openai-gpt": (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, OpenAIGPTLMHeadModel),
    "roberta": (RobertaConfig, TFRobertaForCausalLM, TFRobertaForMaskedLM, RobertaForMaskedLM),
    "layoutlm": (LayoutLMConfig, TFLayoutLMForMaskedLM, LayoutLMForMaskedLM),
    "FacebookAI/roberta-large-mnli": (RobertaConfig, TFRobertaForSequenceClassification, RobertaForSequenceClassification),
    "camembert": (CamembertConfig, TFCamembertForMaskedLM, CamembertForMaskedLM),
    "flaubert": (FlaubertConfig, TFFlaubertWithLMHeadModel, FlaubertWithLMHeadModel),
    "distilbert": (DistilBertConfig, TFDistilBertForMaskedLM, DistilBertForMaskedLM),
    "distilbert-base-distilled-squad": (DistilBertConfig, TFDistilBertForQuestionAnswering, DistilBertForQuestionAnswering),
    "lxmert": (LxmertConfig, TFLxmertForPreTraining, LxmertForPreTraining),
    "lxmert-visual-feature-encoder": (LxmertConfig, TFLxmertVisualFeatureEncoder, LxmertVisualFeatureEncoder),
    "Salesforce/ctrl": (CTRLConfig, TFCTRLLMHeadModel, CTRLLMHeadModel),
    "albert": (AlbertConfig, TFAlbertForPreTraining, AlbertForPreTraining),
    "t5": (T5Config, TFT5ForConditionalGeneration, T5ForConditionalGeneration),
    "electra": (ElectraConfig, TFElectraForPreTraining, ElectraForPreTraining),
    "wav2vec2": (Wav2Vec2Config, TFWav2Vec2Model, Wav2Vec2Model),
}
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def convert_pt_checkpoint_to_tf(
    model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True
):
    """Convert one PyTorch checkpoint to a TensorFlow 2 ``.h5`` weights file.

    Args:
        model_type: Key into ``MODEL_CLASSES`` selecting the architecture.
        pytorch_checkpoint_path: Path (or hub shortcut name) of the PyTorch weights.
        config_file: Path (or hub shortcut name) of the model config JSON.
        tf_dump_path: Destination path for the converted TF weights (HDF5).
        compare_with_pt_model: When True, run both models on dummy inputs and
            assert their max absolute output difference is <= 2e-2.
        use_cached_models: When False, force re-download of remote files.
    """
    if model_type not in MODEL_CLASSES:
        raise ValueError(f"Unrecognized model type, should be one of {list(MODEL_CLASSES.keys())}.")

    # NOTE(review): this unpacks 4 items, but most MODEL_CLASSES entries are
    # 3-tuples (config class, TF class, PT class) — for those keys this line
    # raises ValueError at runtime. Presumably a leftover from when the table
    # also carried shortcut->URL maps; confirm against upstream history.
    config_class, model_class, pt_model_class, aws_config_map = MODEL_CLASSES[model_type]

    # Initialise TF model
    if config_file in aws_config_map:
        config_file = cached_file(config_file, CONFIG_NAME, force_download=not use_cached_models)
    config = config_class.from_json_file(config_file)
    # Expose all intermediate activations so both frameworks produce the same
    # output structure for the comparison below.
    config.output_hidden_states = True
    config.output_attentions = True
    print(f"Building TensorFlow model from configuration: {config}")
    tf_model = model_class(config)

    # Load weights from tf checkpoint
    if pytorch_checkpoint_path in aws_config_map.keys():
        pytorch_checkpoint_path = cached_file(
            pytorch_checkpoint_path, WEIGHTS_NAME, force_download=not use_cached_models
        )
    # Load PyTorch checkpoint in tf2 model:
    tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)

    if compare_with_pt_model:
        tfo = tf_model(tf_model.dummy_inputs, training=False)  # build the network

        # weights_only avoids unpickling arbitrary objects from the checkpoint.
        weights_only_kwarg = {"weights_only": True}
        state_dict = torch.load(
            pytorch_checkpoint_path,
            map_location="cpu",
            **weights_only_kwarg,
        )
        pt_model = pt_model_class.from_pretrained(
            pretrained_model_name_or_path=None, config=config, state_dict=state_dict
        )

        with torch.no_grad():
            pto = pt_model(**pt_model.dummy_inputs)

        # Compare only the first output tensor of each framework.
        np_pt = pto[0].numpy()
        np_tf = tfo[0].numpy()
        diff = np.amax(np.abs(np_pt - np_tf))
        print(f"Max absolute difference between models outputs {diff}")
        assert diff <= 2e-2, f"Error, model absolute difference is >2e-2: {diff}"

    # Save pytorch-model
    print(f"Save TensorFlow model to {tf_dump_path}")
    tf_model.save_weights(tf_dump_path, save_format="h5")
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def convert_all_pt_checkpoints_to_tf(
    args_model_type,
    tf_dump_path,
    model_shortcut_names_or_path=None,
    config_shortcut_names_or_path=None,
    compare_with_pt_model=False,
    use_cached_models=False,
    remove_cached_files=False,
    only_convert_finetuned_models=False,
):
    """Batch-convert PyTorch checkpoints to TF2 weights under ``tf_dump_path``.

    Args:
        args_model_type: A single ``MODEL_CLASSES`` key, or None to iterate
            over every known model type.
        tf_dump_path: Output directory; each checkpoint is written as
            ``<shortcut>-tf_model.h5``.
        model_shortcut_names_or_path: Checkpoint paths/shortcuts to convert,
            or None to derive them from the model-class table.
        config_shortcut_names_or_path: Matching config paths/shortcuts; falls
            back to the model shortcuts when None.
        compare_with_pt_model: Forwarded to ``convert_pt_checkpoint_to_tf``.
        use_cached_models: When False, force re-download of remote files.
        remove_cached_files: When True, delete the local config/weights files
            after each conversion.
        only_convert_finetuned_models: When True, convert only -squad/-mrpc/
            -mnli checkpoints; when False, skip them instead.
    """
    if args_model_type is None:
        model_types = list(MODEL_CLASSES.keys())
    else:
        model_types = [args_model_type]

    for j, model_type in enumerate(model_types, start=1):
        print("=" * 100)
        print(f" Converting model type {j}/{len(model_types)}: {model_type}")
        print("=" * 100)
        if model_type not in MODEL_CLASSES:
            raise ValueError(f"Unrecognized model type {model_type}, should be one of {list(MODEL_CLASSES.keys())}.")

        # NOTE(review): this unpacks 5 items, but the MODEL_CLASSES entries
        # above are mostly 3-tuples — for those keys this raises ValueError.
        # Apparently stale since the shortcut->URL maps were removed; confirm
        # against upstream history.
        config_class, model_class, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]

        # NOTE(review): these defaults are filled in on the first model type
        # and then reused for every later type in the loop — presumably
        # intentional only for the single-type call path; verify.
        if model_shortcut_names_or_path is None:
            model_shortcut_names_or_path = list(aws_model_maps.keys())
        if config_shortcut_names_or_path is None:
            config_shortcut_names_or_path = model_shortcut_names_or_path

        for i, (model_shortcut_name, config_shortcut_name) in enumerate(
            zip(model_shortcut_names_or_path, config_shortcut_names_or_path), start=1
        ):
            print("-" * 100)
            # Fine-tuned checkpoints are identified purely by name suffix.
            if "-squad" in model_shortcut_name or "-mrpc" in model_shortcut_name or "-mnli" in model_shortcut_name:
                if not only_convert_finetuned_models:
                    print(f"    Skipping finetuned checkpoint {model_shortcut_name}")
                    continue
                # Fine-tuned checkpoints have their own MODEL_CLASSES entry
                # keyed by the full checkpoint name.
                model_type = model_shortcut_name
            elif only_convert_finetuned_models:
                print(f"    Skipping not finetuned checkpoint {model_shortcut_name}")
                continue
            print(
                f"    Converting checkpoint {i}/{len(aws_config_map)}: {model_shortcut_name} - model_type {model_type}"
            )
            print("-" * 100)

            # Resolve remote shortcuts to local cached files; anything not in
            # the maps is treated as an already-local path.
            if config_shortcut_name in aws_config_map:
                config_file = cached_file(config_shortcut_name, CONFIG_NAME, force_download=not use_cached_models)
            else:
                config_file = config_shortcut_name

            if model_shortcut_name in aws_model_maps:
                model_file = cached_file(model_shortcut_name, WEIGHTS_NAME, force_download=not use_cached_models)
            else:
                model_file = model_shortcut_name

            # A local file path would make an awkward output name; use a
            # generic one instead.
            if os.path.isfile(model_shortcut_name):
                model_shortcut_name = "converted_model"

            convert_pt_checkpoint_to_tf(
                model_type=model_type,
                pytorch_checkpoint_path=model_file,
                config_file=config_file,
                tf_dump_path=os.path.join(tf_dump_path, model_shortcut_name + "-tf_model.h5"),
                compare_with_pt_model=compare_with_pt_model,
            )
            if remove_cached_files:
                os.remove(config_file)
                os.remove(model_file)
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
if __name__ == "__main__":
    # Command-line entry point: parse arguments and dispatch to the batch
    # converter (single checkpoints are wrapped in one-element lists).
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--tf_dump_path", default=None, type=str, required=True, help="Path to the output Tensorflow dump file."
    )
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        help=(
            f"Model type selected in the list of {list(MODEL_CLASSES.keys())}. If not given, will download and "
            "convert all the models from AWS."
        ),
    )
    parser.add_argument(
        "--pytorch_checkpoint_path",
        default=None,
        type=str,
        help=(
            "Path to the PyTorch checkpoint path or shortcut name to download from AWS. "
            "If not given, will download and convert all the checkpoints from AWS."
        ),
    )
    parser.add_argument(
        "--config_file",
        default=None,
        type=str,
        help=(
            "The config json file corresponding to the pre-trained model. \n"
            "This specifies the model architecture. If not given and "
            "--pytorch_checkpoint_path is not given or is a shortcut name "
            "use the configuration associated to the shortcut name on the AWS"
        ),
    )
    parser.add_argument(
        "--compare_with_pt_model", action="store_true", help="Compare Tensorflow and PyTorch model predictions."
    )
    parser.add_argument(
        "--use_cached_models",
        action="store_true",
        help="Use cached models if possible instead of updating to latest checkpoint versions.",
    )
    parser.add_argument(
        "--remove_cached_files",
        action="store_true",
        help="Remove pytorch models after conversion (save memory when converting in batches).",
    )
    parser.add_argument("--only_convert_finetuned_models", action="store_true", help="Only convert finetuned models.")
    args = parser.parse_args()

    # Legacy single-checkpoint code path, kept for reference:
    # if args.pytorch_checkpoint_path is not None:
    #     convert_pt_checkpoint_to_tf(args.model_type.lower(),
    #                                 args.pytorch_checkpoint_path,
    #                                 args.config_file if args.config_file is not None else args.pytorch_checkpoint_path,
    #                                 args.tf_dump_path,
    #                                 compare_with_pt_model=args.compare_with_pt_model,
    #                                 use_cached_models=args.use_cached_models)
    # else:
    # Model-type keys are lower-case; a single checkpoint/config is passed as
    # a one-element list so the batch helper can handle both cases.
    convert_all_pt_checkpoints_to_tf(
        args.model_type.lower() if args.model_type is not None else None,
        args.tf_dump_path,
        model_shortcut_names_or_path=[args.pytorch_checkpoint_path]
        if args.pytorch_checkpoint_path is not None
        else None,
        config_shortcut_names_or_path=[args.config_file] if args.config_file is not None else None,
        compare_with_pt_model=args.compare_with_pt_model,
        use_cached_models=args.use_cached_models,
        remove_cached_files=args.remove_cached_files,
        only_convert_finetuned_models=args.only_convert_finetuned_models,
    )
|
.venv/lib/python3.11/site-packages/transformers/convert_slow_tokenizer.py
ADDED
|
@@ -0,0 +1,1642 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""
|
| 16 |
+
Utilities to convert slow tokenizers in their fast tokenizers counterparts.
|
| 17 |
+
|
| 18 |
+
All the conversions are grouped here to gather SentencePiece dependencies outside of the fast tokenizers files and
|
| 19 |
+
allow to make our dependency on SentencePiece optional.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import warnings
|
| 23 |
+
from typing import Dict, List, Tuple
|
| 24 |
+
|
| 25 |
+
from packaging import version
|
| 26 |
+
from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
|
| 27 |
+
from tokenizers.models import BPE, Unigram, WordPiece
|
| 28 |
+
|
| 29 |
+
from .utils import is_protobuf_available, is_sentencepiece_available, logging, requires_backends
|
| 30 |
+
from .utils.import_utils import PROTOBUF_IMPORT_ERROR
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
logger = logging.get_logger(__name__)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def import_protobuf(error_message=""):
    """Return a usable ``sentencepiece_model_pb2`` module.

    Preference order: the module shipped with sentencepiece, then a vendored
    copy matched to the installed protobuf major version.

    Raises:
        ImportError: When neither sentencepiece nor protobuf is available;
            ``error_message`` is interpolated into the error text.
    """
    if is_sentencepiece_available():
        from sentencepiece import sentencepiece_model_pb2

        return sentencepiece_model_pb2

    if not is_protobuf_available():
        raise ImportError(PROTOBUF_IMPORT_ERROR.format(error_message))

    import google.protobuf

    # protobuf 4.x generates incompatible descriptors, so a separate vendored
    # module is kept for it.
    if version.parse(google.protobuf.__version__) < version.parse("4.0.0"):
        from transformers.utils import sentencepiece_model_pb2
    else:
        from transformers.utils import sentencepiece_model_pb2_new as sentencepiece_model_pb2
    return sentencepiece_model_pb2
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _get_prepend_scheme(add_prefix_space: bool, original_tokenizer) -> str:
|
| 54 |
+
if add_prefix_space:
|
| 55 |
+
prepend_scheme = "always"
|
| 56 |
+
if not getattr(original_tokenizer, "legacy", True):
|
| 57 |
+
prepend_scheme = "first"
|
| 58 |
+
else:
|
| 59 |
+
prepend_scheme = "never"
|
| 60 |
+
return prepend_scheme
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def generate_merges(vocab, vocab_scores):
    """Build BPE merge rules from a vocabulary.

    Every token is split at every internal position; a split (left, right) is a
    merge candidate when both halves are themselves in ``vocab``. When
    ``vocab_scores`` is provided, merges are ranked by piece score (descending);
    otherwise by vocabulary id (ascending).
    """
    score_ranked = vocab_scores is not None
    scores = dict(vocab_scores) if score_ranked else vocab

    merges = []
    for token, score in scores.items():
        candidates = []
        for split in range(1, len(token)):
            left, right = token[:split], token[split:]
            if left in vocab and right in vocab:
                candidates.append((left, right, score))
        # Within a token, order candidate merges by the ids of their halves.
        candidates.sort(key=lambda entry: (vocab[entry[0]], vocab[entry[1]]))
        merges.extend(candidates)

    merges.sort(key=lambda entry: (entry[2], len(entry[0]), len(entry[1])), reverse=score_ranked)
    return [(entry[0], entry[1]) for entry in merges]
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class SentencePieceExtractor:
    """
    Extractor implementation for SentencePiece trained models. https://github.com/google/sentencepiece
    """

    def __init__(self, model: str):
        # The sentencepiece backend is required to load the .model file.
        requires_backends(self, "sentencepiece")
        from sentencepiece import SentencePieceProcessor

        self.sp = SentencePieceProcessor()
        self.sp.Load(model)

    def extract(self, vocab_scores=None) -> Tuple[Dict[str, int], List[Tuple]]:
        """
        Return the model vocabulary and its BPE merges. Merges follow the
        vocabulary order by default; pass ``vocab_scores`` to rank them by
        piece score instead.
        """
        processor = self.sp
        vocab = {processor.id_to_piece(i): i for i in range(processor.GetPieceSize())}
        return vocab, generate_merges(vocab, vocab_scores)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class GemmaSentencePieceExtractor(SentencePieceExtractor):
    """Extractor variant for Gemma models, whose vocab lacks a literal tab token."""

    def extract(self, vocab_scores=None) -> Tuple[Dict[str, int], List[Tuple]]:
        """
        Return the vocab and merges like the base extractor, after patching the
        missing tab entry so merges involving "\t" can be produced.
        """
        processor = self.sp
        vocab = {processor.id_to_piece(i): i for i in range(processor.GetPieceSize())}

        # "<0x09>" is the byte-fallback piece for `\t`; alias it under the
        # literal tab character to support merges.
        vocab["\t"] = vocab.get("<0x09>")

        return vocab, generate_merges(vocab, vocab_scores)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def check_number_comma(piece: str) -> bool:
    """Return False only for pieces that end with a digit followed by a comma."""
    if len(piece) < 2:
        return True
    return piece[-1] != "," or not piece[-2].isdigit()
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class Converter:
    """Base class for slow-to-fast tokenizer converters.

    Subclasses implement `converted()` to build the fast tokenizer from
    `self.original_tokenizer`.
    """

    def __init__(self, original_tokenizer):
        # Keep a handle on the slow tokenizer being converted.
        self.original_tokenizer = original_tokenizer

    def converted(self) -> Tokenizer:
        """Build and return the fast tokenizer; must be overridden."""
        raise NotImplementedError()
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class BertConverter(Converter):
    """Convert a slow BERT WordPiece tokenizer into its fast counterpart."""

    def converted(self) -> Tokenizer:
        slow = self.original_tokenizer
        tokenizer = Tokenizer(WordPiece(slow.vocab, unk_token=str(slow.unk_token)))

        # Mirror the slow tokenizer's basic-tokenizer options when present.
        basic = getattr(slow, "basic_tokenizer", None)
        handle_chinese_chars = basic.tokenize_chinese_chars if basic is not None else False
        strip_accents = basic.strip_accents if basic is not None else False
        lowercase = basic.do_lower_case if basic is not None else False

        tokenizer.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=handle_chinese_chars,
            strip_accents=strip_accents,
            lowercase=lowercase,
        )
        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

        cls_token, sep_token = str(slow.cls_token), str(slow.sep_token)
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls_token}:0 $A:0 {sep_token}:0",
            pair=f"{cls_token}:0 $A:0 {sep_token}:0 $B:1 {sep_token}:1",
            special_tokens=[
                (cls_token, slow.cls_token_id),
                (sep_token, slow.sep_token_id),
            ],
        )
        tokenizer.decoder = decoders.WordPiece(prefix="##")

        return tokenizer
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
class SplinterConverter(Converter):
    """Convert the slow Splinter tokenizer, including its QASS question token."""

    def converted(self) -> Tokenizer:
        slow = self.original_tokenizer
        tokenizer = Tokenizer(WordPiece(slow.vocab, unk_token=str(slow.unk_token)))

        # Mirror the slow tokenizer's basic-tokenizer options when present.
        basic = getattr(slow, "basic_tokenizer", None)
        handle_chinese_chars = basic.tokenize_chinese_chars if basic is not None else False
        strip_accents = basic.strip_accents if basic is not None else False
        lowercase = basic.do_lower_case if basic is not None else False

        tokenizer.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=handle_chinese_chars,
            strip_accents=strip_accents,
            lowercase=lowercase,
        )
        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

        cls = str(slow.cls_token)
        sep = str(slow.sep_token)
        question = str(slow.question_token)
        dot = "."

        # The question token followed by a dot wraps the question segment; its
        # placement depends on which side padding goes.
        if slow.padding_side == "right":
            pair = f"{cls}:0 $A:0 {question} {dot} {sep}:0 $B:1 {sep}:1"
        else:
            pair = f"{cls}:0 $A:0 {sep}:0 $B:1 {question} {dot} {sep}:1"

        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls}:0 $A:0 {sep}:0",
            pair=pair,
            special_tokens=[
                (cls, slow.cls_token_id),
                (sep, slow.sep_token_id),
                (question, slow.question_token_id),
                (dot, slow.convert_tokens_to_ids(".")),
            ],
        )
        tokenizer.decoder = decoders.WordPiece(prefix="##")

        return tokenizer
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class FunnelConverter(Converter):
    """Convert the slow Funnel Transformer WordPiece tokenizer to a fast one."""

    def converted(self) -> Tokenizer:
        slow = self.original_tokenizer
        tokenizer = Tokenizer(WordPiece(slow.vocab, unk_token=str(slow.unk_token)))

        # Mirror the slow tokenizer's basic-tokenizer options when present.
        basic = getattr(slow, "basic_tokenizer", None)
        handle_chinese_chars = basic.tokenize_chinese_chars if basic is not None else False
        strip_accents = basic.strip_accents if basic is not None else False
        lowercase = basic.do_lower_case if basic is not None else False

        tokenizer.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=handle_chinese_chars,
            strip_accents=strip_accents,
            lowercase=lowercase,
        )
        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

        cls_token, sep_token = str(slow.cls_token), str(slow.sep_token)
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls_token}:2 $A:0 {sep_token}:0",  # token_type_id is 2 for Funnel transformer
            pair=f"{cls_token}:2 $A:0 {sep_token}:0 $B:1 {sep_token}:1",
            special_tokens=[
                (cls_token, slow.cls_token_id),
                (sep_token, slow.sep_token_id),
            ],
        )
        tokenizer.decoder = decoders.WordPiece(prefix="##")

        return tokenizer
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
class MPNetConverter(Converter):
    """Convert the slow MPNet WordPiece tokenizer to a fast one."""

    def converted(self) -> Tokenizer:
        slow = self.original_tokenizer
        tokenizer = Tokenizer(WordPiece(slow.vocab, unk_token=str(slow.unk_token)))

        # Mirror the slow tokenizer's basic-tokenizer options when present.
        basic = getattr(slow, "basic_tokenizer", None)
        handle_chinese_chars = basic.tokenize_chinese_chars if basic is not None else False
        strip_accents = basic.strip_accents if basic is not None else False
        lowercase = basic.do_lower_case if basic is not None else False

        tokenizer.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=handle_chinese_chars,
            strip_accents=strip_accents,
            lowercase=lowercase,
        )
        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

        cls_token, sep_token = str(slow.cls_token), str(slow.sep_token)
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls_token}:0 $A:0 {sep_token}:0",
            pair=f"{cls_token}:0 $A:0 {sep_token}:0 {sep_token}:0 $B:1 {sep_token}:1",  # MPNet uses two [SEP] tokens
            special_tokens=[
                (cls_token, slow.cls_token_id),
                (sep_token, slow.sep_token_id),
            ],
        )
        tokenizer.decoder = decoders.WordPiece(prefix="##")

        return tokenizer
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
class OpenAIGPTConverter(Converter):
    """Convert the slow OpenAI GPT BPE tokenizer to a fast one."""

    def converted(self) -> Tokenizer:
        slow = self.original_tokenizer
        unk = str(slow.unk_token)

        tokenizer = Tokenizer(
            BPE(
                vocab=slow.encoder,
                merges=list(slow.bpe_ranks.keys()),
                dropout=None,
                unk_token=unk,
                end_of_word_suffix="</w>",
                fuse_unk=False,
            )
        )

        # Register the unknown token as special when it exists in the vocab.
        if tokenizer.token_to_id(unk) is not None:
            tokenizer.add_special_tokens([unk])

        tokenizer.normalizer = normalizers.BertNormalizer(lowercase=True)
        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
        tokenizer.decoder = decoders.BPEDecoder(suffix="</w>")

        return tokenizer
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
class GPT2Converter(Converter):
    """Convert a slow GPT-2 byte-level BPE tokenizer to a fast one."""

    def converted(self, vocab: Dict[str, int] = None, merges: List[Tuple[str, str]] = None) -> Tokenizer:
        slow = self.original_tokenizer
        # Fall back to the slow tokenizer's tables when none are supplied.
        if not vocab:
            vocab = slow.encoder
        if not merges:
            merges = list(slow.bpe_ranks)

        tokenizer = Tokenizer(
            BPE(
                vocab=vocab,
                merges=merges,
                dropout=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
            )
        )

        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(
            add_prefix_space=getattr(slow, "add_prefix_space", False)
        )
        tokenizer.decoder = decoders.ByteLevel()

        if getattr(slow, "add_bos_token", False):
            bos = slow.bos_token
            tokenizer.post_processor = processors.TemplateProcessing(
                single=f"{bos}:0 $A:0",
                pair=f"{bos}:0 $A:0 $B:1",
                special_tokens=[(bos, slow.bos_token_id)],
            )
        else:
            # XXX trim_offsets=False actually means this post_processor doesn't
            # really do anything.
            tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
        return tokenizer
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
class HerbertConverter(Converter):
    """Convert the slow HerBERT BPE tokenizer to a fast one."""

    def converted(self) -> Tokenizer:
        slow = self.original_tokenizer
        suffix = "</w>"

        merges = list(slow.bpe_ranks.keys())
        # The merges file may start with a "#version:" header; drop it.
        if "#version:" in merges[0][0]:
            merges = merges[1:]

        tokenizer = Tokenizer(
            BPE(
                slow.encoder,
                merges,
                dropout=None,
                unk_token=slow.unk_token,
                end_of_word_suffix=suffix,
            )
        )

        tokenizer.normalizer = normalizers.BertNormalizer(lowercase=False, strip_accents=False)
        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
        tokenizer.decoder = decoders.BPEDecoder(suffix=suffix)
        tokenizer.post_processor = processors.BertProcessing(
            sep=(slow.sep_token, slow.sep_token_id),
            cls=(slow.cls_token, slow.cls_token_id),
        )

        return tokenizer
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
class Qwen2Converter(Converter):
    """Convert the slow Qwen2 byte-level BPE tokenizer to a fast one."""

    def converted(self, vocab: Dict[str, int] = None, merges: List[Tuple[str, str]] = None) -> Tokenizer:
        slow = self.original_tokenizer
        # Fall back to the slow tokenizer's tables when none are supplied.
        if not vocab:
            vocab = slow.encoder
        if not merges:
            merges = list(slow.bpe_ranks.keys())

        tokenizer = Tokenizer(
            BPE(
                vocab=vocab,
                merges=merges,
                dropout=None,
                unk_token=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
                byte_fallback=False,
            )
        )

        tokenizer.normalizer = normalizers.NFC()

        # Split with a GPT-2-style regex (case-insensitive contractions) first,
        # then apply byte-level encoding without its own regex.
        split_pattern = Regex(
            r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
        )
        tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
            [
                pre_tokenizers.Split(split_pattern, behavior="isolated", invert=False),
                pre_tokenizers.ByteLevel(
                    add_prefix_space=getattr(slow, "add_prefix_space", False),
                    use_regex=False,
                ),
            ]
        )

        tokenizer.decoder = decoders.ByteLevel()
        tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)

        return tokenizer
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
class RobertaConverter(Converter):
    """Convert the slow RoBERTa byte-level BPE tokenizer to a fast one."""

    def converted(self) -> Tokenizer:
        slow = self.original_tokenizer

        tokenizer = Tokenizer(
            BPE(
                vocab=slow.encoder,
                merges=list(slow.bpe_ranks.keys()),
                dropout=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
            )
        )

        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=slow.add_prefix_space)
        tokenizer.decoder = decoders.ByteLevel()
        tokenizer.post_processor = processors.RobertaProcessing(
            sep=(slow.sep_token, slow.sep_token_id),
            cls=(slow.cls_token, slow.cls_token_id),
            add_prefix_space=slow.add_prefix_space,
            trim_offsets=True,  # True by default on Roberta (historical)
        )

        return tokenizer
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
class RoFormerConverter(Converter):
    """Convert the slow RoFormer tokenizer, which pre-tokenizes with jieba."""

    def converted(self) -> Tokenizer:
        from .models.roformer.tokenization_utils import JiebaPreTokenizer

        slow = self.original_tokenizer
        vocab = slow.vocab
        tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(slow.unk_token)))

        # Mirror the slow tokenizer's basic-tokenizer options when present.
        basic = getattr(slow, "basic_tokenizer", None)
        strip_accents = basic.strip_accents if basic is not None else False
        lowercase = basic.do_lower_case if basic is not None else False

        tokenizer.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=False,  # Chinese text is handled by jieba instead
            strip_accents=strip_accents,
            lowercase=lowercase,
        )
        tokenizer.pre_tokenizer = pre_tokenizers.PreTokenizer.custom(JiebaPreTokenizer(vocab))

        cls_token, sep_token = str(slow.cls_token), str(slow.sep_token)
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls_token}:0 $A:0 {sep_token}:0",
            pair=f"{cls_token}:0 $A:0 {sep_token}:0 $B:1 {sep_token}:1",
            special_tokens=[
                (cls_token, slow.cls_token_id),
                (sep_token, slow.sep_token_id),
            ],
        )
        tokenizer.decoder = decoders.WordPiece(prefix="##")

        return tokenizer
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
class DebertaConverter(Converter):
    """Convert the slow DeBERTa byte-level BPE tokenizer to a fast one."""

    def converted(self) -> Tokenizer:
        slow = self.original_tokenizer

        tokenizer = Tokenizer(
            BPE(
                vocab=slow.encoder,
                merges=list(slow.bpe_ranks.keys()),
                dropout=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
            )
        )

        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=slow.add_prefix_space)
        tokenizer.decoder = decoders.ByteLevel()
        tokenizer.post_processor = processors.TemplateProcessing(
            single="[CLS]:0 $A:0 [SEP]:0",
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            special_tokens=[
                ("[CLS]", slow.convert_tokens_to_ids("[CLS]")),
                ("[SEP]", slow.convert_tokens_to_ids("[SEP]")),
            ],
        )

        return tokenizer
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
class SpmConverter(Converter):
    """Generic converter for SentencePiece-based slow tokenizers.

    Parses the serialized SentencePiece model behind the slow tokenizer and
    rebuilds an equivalent fast `tokenizers.Tokenizer` (Unigram or BPE).
    Subclasses customize `vocab`, `unk_id`, `normalizer`, `pre_tokenizer`,
    `post_processor` and `decoder` as needed.
    """

    # Whether the fast model should emulate SentencePiece's byte fallback.
    handle_byte_fallback = False
    # Extractor class used to recover BPE merges from the .model file.
    SpmExtractor = SentencePieceExtractor
    # Pieces that must be flagged as special even if not control tokens.
    special_tokens = {}

    def __init__(self, *args):
        requires_backends(self, "protobuf")

        super().__init__(*args)

        # from .utils import sentencepiece_model_pb2 as model_pb2
        model_pb2 = import_protobuf()

        # Parse the SentencePiece model protobuf shipped with the slow tokenizer.
        m = model_pb2.ModelProto()
        with open(self.original_tokenizer.vocab_file, "rb") as f:
            m.ParseFromString(f.read())
        self.proto = m

        if self.proto.trainer_spec.byte_fallback and not self.handle_byte_fallback:
            warnings.warn(
                "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
                " which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
                " tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
                "unknown tokens into a sequence of byte tokens matching the original piece of text."
            )

    def vocab(self, proto):
        """Return the (piece, score) pairs used to build the fast model."""
        return [(piece.piece, piece.score) for piece in proto.pieces]

    def unk_id(self, proto):
        """Return the id of the unknown token."""
        return proto.trainer_spec.unk_id

    def tokenizer(self, proto):
        """Build the core model (Unigram or BPE) from the parsed proto."""
        model_type = proto.trainer_spec.model_type
        vocab_scores = self.vocab(proto)

        if model_type == 1:  # Unigram
            tokenizer = Tokenizer(
                Unigram(
                    vocab_scores,
                    unk_id=self.unk_id(proto),
                    byte_fallback=self.handle_byte_fallback,
                )
            )

        elif model_type == 2:  # BPE
            # Recover merges from the raw model file; ids follow vocab order.
            _, merges = self.SpmExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
            bpe_vocab = {word: i for i, (word, score) in enumerate(vocab_scores)}
            tokenizer = Tokenizer(
                BPE(
                    bpe_vocab,
                    merges,
                    unk_token=proto.trainer_spec.unk_piece,
                    fuse_unk=True,
                    byte_fallback=self.handle_byte_fallback,
                    dropout=None,
                )
            )

        else:
            # Fix: "you're file" -> "your file" in the error message.
            raise Exception(
                "You're trying to run a `Unigram` model but your file was trained with a different algorithm"
            )

        # control tokens are special
        # user defined symbols are not
        # both user and control tokens are AddedTokens
        # Add user defined symbols (type == 4) from sentencepiece (https://github.com/google/sentencepiece/blob/6225e08edb2577757163b3f5dbba4c0b670ef445/src/sentencepiece_model.proto#L299C29-L299C33)
        spm_added_tokens = [
            (id, p.piece, p.type == 3 or p.piece in self.special_tokens)
            for id, p in enumerate(proto.pieces)
            if p.type in [3, 4]
        ]
        tokenizer.add_tokens(
            [
                AddedToken(token, normalized=False, special=special)
                for id, token, special in sorted(spm_added_tokens, key=lambda x: x[0])
            ]
        )

        return tokenizer

    def normalizer(self, proto):
        """Default normalizer: precompiled charsmap (if any) plus whitespace handling."""
        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        _normalizers = [
            normalizers.Strip(left=False, right=True),  # stripping is important
            normalizers.Replace(Regex(" {2,}"), "▁"),
        ]
        if not precompiled_charsmap:
            return normalizers.Sequence(_normalizers)
        else:
            return normalizers.Sequence([normalizers.Precompiled(precompiled_charsmap)] + _normalizers)

    def pre_tokenizer(self, replacement, add_prefix_space):
        """Metaspace pre-tokenizer matching the slow tokenizer's prefix-space behavior."""
        prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
        return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)

    def post_processor(self):
        # No post-processing by default; subclasses add template processors.
        return None

    def decoder(self, replacement, add_prefix_space):
        """Metaspace decoder mirroring the pre-tokenizer configuration."""
        prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
        return decoders.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)

    def converted(self) -> Tokenizer:
        """Assemble the full fast tokenizer pipeline."""
        tokenizer = self.tokenizer(self.proto)

        # Tokenizer assemble
        normalizer = self.normalizer(self.proto)
        if normalizer is not None:
            tokenizer.normalizer = normalizer

        replacement = "▁"
        add_prefix_space = True
        if hasattr(self.original_tokenizer, "add_prefix_space"):
            add_prefix_space = self.original_tokenizer.add_prefix_space

        pre_tokenizer = self.pre_tokenizer(replacement, add_prefix_space)
        if pre_tokenizer is not None:
            tokenizer.pre_tokenizer = pre_tokenizer

        tokenizer.decoder = self.decoder(replacement, add_prefix_space)
        post_processor = self.post_processor()
        if post_processor:
            tokenizer.post_processor = post_processor

        return tokenizer
|
| 669 |
+
|
| 670 |
+
|
| 671 |
+
class AlbertConverter(SpmConverter):
    """SentencePiece-to-fast conversion for ALBERT."""

    def vocab(self, proto):
        # Penalize pieces ending in "<digit>," so the Unigram model avoids
        # fusing numbers with a trailing comma.
        scored = []
        for piece in proto.pieces:
            score = piece.score if check_number_comma(piece.piece) else piece.score - 100
            scored.append((piece.piece, score))
        return scored

    def normalizer(self, proto):
        # Normalize quote pairs, optionally strip accents / lowercase, then
        # apply the model's precompiled charsmap and collapse runs of spaces.
        steps = [
            normalizers.Replace("``", '"'),
            normalizers.Replace("''", '"'),
        ]
        if not self.original_tokenizer.keep_accents:
            steps.append(normalizers.NFKD())
            steps.append(normalizers.StripAccents())
        if self.original_tokenizer.do_lower_case:
            steps.append(normalizers.Lowercase())

        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        if precompiled_charsmap:
            steps.append(normalizers.Precompiled(precompiled_charsmap))

        steps.append(normalizers.Replace(Regex(" {2,}"), " "))
        return normalizers.Sequence(steps)

    def post_processor(self):
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="[CLS]:0 $A:0 [SEP]:0",
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            special_tokens=[
                ("[CLS]", to_id("[CLS]")),
                ("[SEP]", to_id("[SEP]")),
            ],
        )
|
| 706 |
+
|
| 707 |
+
|
| 708 |
+
class BarthezConverter(SpmConverter):
    """SentencePiece-to-fast conversion for BARThez."""

    def unk_id(self, proto):
        # BARThez fixes <unk> at index 3 regardless of the trainer spec.
        return 3

    def post_processor(self):
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="<s> $A </s>",
            pair="<s> $A </s> </s> $B </s>",
            special_tokens=[
                ("<s>", to_id("<s>")),
                ("</s>", to_id("</s>")),
            ],
        )
|
| 722 |
+
|
| 723 |
+
|
| 724 |
+
class CamembertConverter(SpmConverter):
    """SentencePiece-to-fast conversion for CamemBERT."""

    def vocab(self, proto):
        vocab = [
            ("<s>NOTUSED", 0.0),
            ("<pad>", 0.0),
            ("</s>NOTUSED", 0.0),
            ("<unk>", 0.0),
            # Down-grade the original SentencePiece <unk> by -100 to avoid
            # using it; the added <unk> above is used instead.
            ("<unk>NOTUSED", -100),
        ]
        vocab.extend((piece.piece, piece.score) for piece in proto.pieces[1:])
        vocab.append(("<mask>", 0.0))
        return vocab

    def unk_id(self, proto):
        # See vocab(): <unk> sits at index 3.
        return 3

    def post_processor(self):
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="<s> $A </s>",
            pair="<s> $A </s> </s> $B </s>",
            special_tokens=[
                ("<s>", to_id("<s>")),
                ("</s>", to_id("</s>")),
            ],
        )
|
| 751 |
+
|
| 752 |
+
|
| 753 |
+
class DebertaV2Converter(SpmConverter):
    """SentencePiece-to-fast conversion for DeBERTa-v2."""

    def pre_tokenizer(self, replacement, add_prefix_space):
        parts = []
        # Optionally isolate punctuation before the Metaspace step.
        if self.original_tokenizer.split_by_punct:
            parts.append(pre_tokenizers.Punctuation(behavior="isolated"))
        scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
        parts.append(pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=scheme))
        return pre_tokenizers.Sequence(parts)

    def normalizer(self, proto):
        steps = []
        if self.original_tokenizer.do_lower_case:
            steps.append(normalizers.Lowercase())
        steps.append(normalizers.Strip())

        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        if precompiled_charsmap:
            steps.append(normalizers.Precompiled(precompiled_charsmap))
        steps.append(normalizers.Replace(Regex(" {2,}"), " "))

        return normalizers.Sequence(steps)

    def post_processor(self):
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="[CLS]:0 $A:0 [SEP]:0",
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            special_tokens=[
                ("[CLS]", to_id("[CLS]")),
                ("[SEP]", to_id("[SEP]")),
            ],
        )
|
| 784 |
+
|
| 785 |
+
|
| 786 |
+
class MBartConverter(SpmConverter):
    """SentencePiece-to-fast conversion for mBART (25 languages)."""

    def vocab(self, proto):
        # The 25 mBART language codes, appended after the SentencePiece pieces.
        lang_codes = (
            "ar_AR", "cs_CZ", "de_DE", "en_XX", "es_XX", "et_EE", "fi_FI",
            "fr_XX", "gu_IN", "hi_IN", "it_IT", "ja_XX", "kk_KZ", "ko_KR",
            "lt_LT", "lv_LV", "my_MM", "ne_NP", "nl_XX", "ro_RO", "ru_RU",
            "si_LK", "tr_TR", "vi_VN", "zh_CN",
        )
        vocab = [
            ("<s>", 0.0),
            ("<pad>", 0.0),
            ("</s>", 0.0),
            ("<unk>", 0.0),
        ]
        vocab.extend((piece.piece, piece.score) for piece in proto.pieces[3:])
        vocab.extend((code, 0.0) for code in lang_codes)
        vocab.append(("<mask>", 0.0))
        return vocab

    def unk_id(self, proto):
        # See vocab(): <unk> sits at index 3.
        return 3

    def post_processor(self):
        to_id = self.original_tokenizer.convert_tokens_to_ids
        # mBART appends the language code after </s>.
        return processors.TemplateProcessing(
            single="$A </s> en_XX",
            pair="$A $B </s> en_XX",
            special_tokens=[
                ("en_XX", to_id("en_XX")),
                ("</s>", to_id("</s>")),
            ],
        )
|
| 837 |
+
|
| 838 |
+
|
| 839 |
+
class MBart50Converter(SpmConverter):
    """SentencePiece-to-fast conversion for mBART-50 (52 languages)."""

    def vocab(self, proto):
        # The 52 mBART-50 language codes, appended after the SentencePiece pieces.
        lang_codes = (
            "ar_AR", "cs_CZ", "de_DE", "en_XX", "es_XX", "et_EE", "fi_FI",
            "fr_XX", "gu_IN", "hi_IN", "it_IT", "ja_XX", "kk_KZ", "ko_KR",
            "lt_LT", "lv_LV", "my_MM", "ne_NP", "nl_XX", "ro_RO", "ru_RU",
            "si_LK", "tr_TR", "vi_VN", "zh_CN", "af_ZA", "az_AZ", "bn_IN",
            "fa_IR", "he_IL", "hr_HR", "id_ID", "ka_GE", "km_KH", "mk_MK",
            "ml_IN", "mn_MN", "mr_IN", "pl_PL", "ps_AF", "pt_XX", "sv_SE",
            "sw_KE", "ta_IN", "te_IN", "th_TH", "tl_XX", "uk_UA", "ur_PK",
            "xh_ZA", "gl_ES", "sl_SI",
        )
        vocab = [
            ("<s>", 0.0),
            ("<pad>", 0.0),
            ("</s>", 0.0),
            ("<unk>", 0.0),
        ]
        vocab.extend((piece.piece, piece.score) for piece in proto.pieces[3:])
        vocab.extend((code, 0.0) for code in lang_codes)
        vocab.append(("<mask>", 0.0))
        return vocab

    def unk_id(self, proto):
        # See vocab(): <unk> sits at index 3.
        return 3

    def post_processor(self):
        to_id = self.original_tokenizer.convert_tokens_to_ids
        # mBART-50 prepends the language code before the sequence.
        return processors.TemplateProcessing(
            single="en_XX $A </s>",
            pair="en_XX $A $B </s>",
            special_tokens=[
                ("en_XX", to_id("en_XX")),
                ("</s>", to_id("</s>")),
            ],
        )
|
| 864 |
+
|
| 865 |
+
|
| 866 |
+
class NllbConverter(SpmConverter):
    """Converts a slow NLLB tokenizer into its fast counterpart."""

    def vocab(self, proto):
        """Fixed specials first, then the SentencePiece pieces (skipping the first 3)."""
        entries = [("<s>", 0.0), ("<pad>", 0.0), ("</s>", 0.0), ("<unk>", 0.0)]
        entries.extend((piece.piece, piece.score) for piece in proto.pieces[3:])
        return entries

    def unk_id(self, proto):
        """<unk> is pinned at index 3."""
        return 3

    def post_processor(self):
        """NLLB prefixes the source language token (eng_Latn) and appends `</s>`."""
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="eng_Latn $A </s>",
            pair="eng_Latn $A $B </s>",
            special_tokens=[("eng_Latn", to_id("eng_Latn")), ("</s>", to_id("</s>"))],
        )
|
| 889 |
+
|
| 890 |
+
|
| 891 |
+
class SeamlessM4TConverter(SpmConverter):
    """Converts a slow SeamlessM4T tokenizer into its fast counterpart."""

    def vocab(self, proto):
        """<pad>/<unk>/<s>/</s> first, then the pieces (skipping the first 3)."""
        entries = [("<pad>", 0.0), ("<unk>", 0.0), ("<s>", 0.0), ("</s>", 0.0)]
        entries.extend((piece.piece, piece.score) for piece in proto.pieces[3:])
        return entries

    def unk_id(self, proto):
        """Defer to the slow tokenizer instead of hard-coding an index."""
        return self.original_tokenizer.unk_token_id

    def post_processor(self):
        """SeamlessM4T prefixes the language token (__eng__) and appends `</s>`."""
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="__eng__ $A </s>",
            pair="__eng__ $A $B </s>",
            special_tokens=[("__eng__", to_id("__eng__")), ("</s>", to_id("</s>"))],
        )
|
| 914 |
+
|
| 915 |
+
|
| 916 |
+
class XLMRobertaConverter(SpmConverter):
    """Converts a slow XLM-RoBERTa tokenizer into its fast counterpart."""

    def vocab(self, proto):
        """Fixed specials, the pieces (skipping the first 3), then a trailing <mask>."""
        entries = [("<s>", 0.0), ("<pad>", 0.0), ("</s>", 0.0), ("<unk>", 0.0)]
        entries.extend((piece.piece, piece.score) for piece in proto.pieces[3:])
        entries.append(("<mask>", 0.0))
        return entries

    def unk_id(self, proto):
        """<unk> is pinned at index 3."""
        return 3

    def post_processor(self):
        """RoBERTa-style wrapping: <s> ... </s>, with a double separator for pairs."""
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="<s> $A </s>",
            pair="<s> $A </s> </s> $B </s>",
            special_tokens=[("<s>", to_id("<s>")), ("</s>", to_id("</s>"))],
        )
|
| 941 |
+
|
| 942 |
+
|
| 943 |
+
class XLNetConverter(SpmConverter):
    """Converts a slow XLNet tokenizer into its fast counterpart."""

    def vocab(self, proto):
        """Penalize pieces flagged by `check_number_comma` with a -100 score shift."""
        entries = []
        for piece in proto.pieces:
            score = piece.score if check_number_comma(piece.piece) else piece.score - 100
            entries.append((piece.piece, score))
        return entries

    def normalizer(self, proto):
        """Quote normalization, optional accent stripping/lowercasing, then whitespace cleanup."""
        steps = [
            normalizers.Replace("``", '"'),
            normalizers.Replace("''", '"'),
        ]
        if not self.original_tokenizer.keep_accents:
            steps.append(normalizers.NFKD())
            steps.append(normalizers.StripAccents())
        if self.original_tokenizer.do_lower_case:
            steps.append(normalizers.Lowercase())

        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        if precompiled_charsmap:
            steps.append(normalizers.Precompiled(precompiled_charsmap))

        # Collapse runs of spaces last.
        steps.append(normalizers.Replace(Regex(" {2,}"), " "))
        return normalizers.Sequence(steps)

    def post_processor(self):
        """XLNet puts <sep>/<cls> at the END of the sequence, with type id 2 for <cls>."""
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="$A:0 <sep>:0 <cls>:2",
            pair="$A:0 <sep>:0 $B:1 <sep>:1 <cls>:2",
            special_tokens=[("<sep>", to_id("<sep>")), ("<cls>", to_id("<cls>"))],
        )
|
| 978 |
+
|
| 979 |
+
|
| 980 |
+
class ReformerConverter(SpmConverter):
    """Reformer needs no customization beyond the generic SentencePiece conversion."""
|
| 982 |
+
|
| 983 |
+
|
| 984 |
+
class RemBertConverter(SpmConverter):
    """Converts a slow RemBert tokenizer into its fast counterpart.

    Inspired from AlbertConverter.
    """

    def normalizer(self, proto):
        """Quote/whitespace normalization plus optional accent stripping and lowercasing."""
        steps = [
            normalizers.Replace("``", '"'),
            normalizers.Replace("''", '"'),
            normalizers.Replace(Regex(" {2,}"), " "),
        ]
        if not self.original_tokenizer.keep_accents:
            steps.append(normalizers.NFKD())
            steps.append(normalizers.StripAccents())
        if self.original_tokenizer.do_lower_case:
            steps.append(normalizers.Lowercase())

        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        if precompiled_charsmap:
            steps.append(normalizers.Precompiled(precompiled_charsmap))

        return normalizers.Sequence(steps)

    def post_processor(self):
        """BERT-style [CLS] ... [SEP] wrapping with segment ids 0/1."""
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="[CLS]:0 $A:0 [SEP]:0",
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            special_tokens=[("[CLS]", to_id("[CLS]")), ("[SEP]", to_id("[SEP]"))],
        )
|
| 1014 |
+
|
| 1015 |
+
|
| 1016 |
+
class BertGenerationConverter(SpmConverter):
    """BertGeneration uses the generic SentencePiece conversion unchanged."""
|
| 1018 |
+
|
| 1019 |
+
|
| 1020 |
+
class PegasusConverter(SpmConverter):
    """Converts a slow Pegasus tokenizer into its fast counterpart."""

    def vocab(self, proto):
        """pad/eos, optional mask tokens, the reserved <unk_i> fillers, then the pieces."""
        tok = self.original_tokenizer
        entries = [(tok.pad_token, 0.0), (tok.eos_token, 0.0)]

        if tok.mask_token_sent is not None:
            entries.append((tok.mask_token_sent, 0.0))

        # The word-level mask only lives inside the reserved-offset region for
        # some checkpoints; add it there when it does.
        if tok.mask_token is not None and tok.mask_token_id < tok.offset:
            entries.append((tok.mask_token, 0.0))

        entries.extend((f"<unk_{i}>", -100.0) for i in range(2, tok.offset))
        entries.extend((piece.piece, piece.score) for piece in proto.pieces[2:])
        return entries

    def unk_id(self, proto):
        """Shift the proto's unk id by the tokenizer's reserved-token offset."""
        return proto.trainer_spec.unk_id + self.original_tokenizer.offset

    def pre_tokenizer(self, replacement, add_prefix_space):
        """Whitespace split first, then the Metaspace marker insertion."""
        prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
        return pre_tokenizers.Sequence(
            [
                pre_tokenizers.WhitespaceSplit(),
                pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme),
            ]
        )

    def post_processor(self):
        """Pegasus only appends EOS to single sequences and pairs."""
        eos = self.original_tokenizer.eos_token
        specials = [(eos, self.original_tokenizer.eos_token_id)]
        return processors.TemplateProcessing(single=["$A", eos], pair=["$A", "$B", eos], special_tokens=specials)
|
| 1058 |
+
|
| 1059 |
+
|
| 1060 |
+
class T5Converter(SpmConverter):
    """Converts a slow T5 tokenizer into its fast counterpart."""

    def vocab(self, proto):
        """All pieces first, then the sentinel tokens <extra_id_N-1> ... <extra_id_0>."""
        num_extra_ids = self.original_tokenizer._extra_ids
        entries = [(piece.piece, piece.score) for piece in proto.pieces]
        # Sentinels are appended in descending index order.
        entries.extend((f"<extra_id_{i}>", 0.0) for i in reversed(range(num_extra_ids)))
        return entries

    def post_processor(self):
        """T5 appends `</s>` after each sequence."""
        eos_id = self.original_tokenizer.convert_tokens_to_ids("</s>")
        return processors.TemplateProcessing(
            single=["$A", "</s>"],
            pair=["$A", "</s>", "$B", "</s>"],
            special_tokens=[("</s>", eos_id)],
        )
|
| 1075 |
+
|
| 1076 |
+
|
| 1077 |
+
class UdopConverter(SpmConverter):
    """Converts a slow UDOP tokenizer; only the post-processor differs from the base."""

    def post_processor(self):
        """UDOP appends `</s>` after each sequence, like T5."""
        eos_id = self.original_tokenizer.convert_tokens_to_ids("</s>")
        return processors.TemplateProcessing(
            single=["$A", "</s>"],
            pair=["$A", "</s>", "$B", "</s>"],
            special_tokens=[("</s>", eos_id)],
        )
|
| 1086 |
+
|
| 1087 |
+
|
| 1088 |
+
class WhisperConverter(Converter):
    """Converts a slow Whisper (byte-level BPE) tokenizer into its fast counterpart."""

    def converted(self) -> Tokenizer:
        ot = self.original_tokenizer
        tokenizer = Tokenizer(
            BPE(
                vocab=ot.encoder,
                merges=list(ot.bpe_ranks.keys()),
                dropout=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
            )
        )

        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
        tokenizer.decoder = decoders.ByteLevel()

        # Whisper prepends its special prefix tokens to every encoded sequence.
        prefix_token_ids = ot.prefix_tokens
        prefixes = ot.convert_ids_to_tokens(prefix_token_ids)
        eos = ot.eos_token
        eos_token_id = ot.eos_token_id
        prefix_template = " ".join(f"{token}:0" for token in prefixes)
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{prefix_template} $A:0 {eos}:0",
            pair=f"{prefix_template} $A:0 $B:1 {eos}:1",
            special_tokens=[(eos, eos_token_id), *zip(prefixes, prefix_token_ids)],
        )

        return tokenizer
|
| 1122 |
+
|
| 1123 |
+
|
| 1124 |
+
class BigBirdConverter(SpmConverter):
    """Converts a slow BigBird tokenizer; only the post-processor differs from the base."""

    def post_processor(self):
        """BERT-style [CLS] ... [SEP] wrapping with segment ids 0/1."""
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="[CLS]:0 $A:0 [SEP]:0",
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            special_tokens=[("[CLS]", to_id("[CLS]")), ("[SEP]", to_id("[SEP]"))],
        )
|
| 1134 |
+
|
| 1135 |
+
|
| 1136 |
+
class CLIPConverter(Converter):
    """Converts a slow CLIP tokenizer into its fast counterpart."""

    def converted(self) -> Tokenizer:
        ot = self.original_tokenizer
        tokenizer = Tokenizer(
            BPE(
                vocab=ot.encoder,
                merges=list(ot.bpe_ranks.keys()),
                dropout=None,
                continuing_subword_prefix="",
                end_of_word_suffix="</w>",
                fuse_unk=False,
                unk_token=str(ot.unk_token),
            )
        )

        tokenizer.normalizer = normalizers.Sequence(
            [normalizers.NFC(), normalizers.Replace(Regex(r"\s+"), " "), normalizers.Lowercase()]
        )
        tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
            [
                pre_tokenizers.Split(
                    Regex(r"""'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+"""),
                    behavior="removed",
                    invert=True,
                ),
                pre_tokenizers.ByteLevel(add_prefix_space=False),
            ]
        )
        tokenizer.decoder = decoders.ByteLevel()

        # Hack: RobertaProcessing combines byte-level offsets with template-style
        # special-token insertion (bos ... eos).
        tokenizer.post_processor = processors.RobertaProcessing(
            sep=(ot.eos_token, ot.eos_token_id),
            cls=(ot.bos_token, ot.bos_token_id),
            add_prefix_space=False,
            trim_offsets=False,
        )
        return tokenizer
|
| 1177 |
+
|
| 1178 |
+
|
| 1179 |
+
class LayoutLMv2Converter(Converter):
    """Converts a slow LayoutLMv2 (WordPiece/BERT-style) tokenizer into its fast counterpart."""

    def converted(self) -> Tokenizer:
        ot = self.original_tokenizer
        tokenizer = Tokenizer(WordPiece(ot.vocab, unk_token=str(ot.unk_token)))

        # Defaults apply when the slow tokenizer exposes no basic_tokenizer.
        tokenize_chinese_chars = False
        strip_accents = False
        do_lower_case = True
        basic = getattr(ot, "basic_tokenizer", None)
        if basic is not None:
            tokenize_chinese_chars = basic.tokenize_chinese_chars
            strip_accents = basic.strip_accents
            do_lower_case = basic.do_lower_case

        tokenizer.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            lowercase=do_lower_case,
        )
        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

        cls_tok = str(ot.cls_token)
        sep_tok = str(ot.sep_token)
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls_tok}:0 $A:0 {sep_tok}:0",
            pair=f"{cls_tok}:0 $A:0 {sep_tok}:0 $B:1 {sep_tok}:1",
            special_tokens=[(cls_tok, ot.cls_token_id), (sep_tok, ot.sep_token_id)],
        )
        tokenizer.decoder = decoders.WordPiece(prefix="##")

        return tokenizer
|
| 1216 |
+
|
| 1217 |
+
|
| 1218 |
+
class BlenderbotConverter(Converter):
    """Converts a slow Blenderbot (byte-level BPE) tokenizer into its fast counterpart."""

    def converted(self) -> Tokenizer:
        ot = self.original_tokenizer
        tokenizer = Tokenizer(
            BPE(
                vocab=ot.encoder,
                merges=list(ot.bpe_ranks.keys()),
                dropout=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
            )
        )

        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
        tokenizer.decoder = decoders.ByteLevel()
        # Blenderbot only appends EOS; no pair template is defined.
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"$A:0 {ot.eos_token}:0",
            special_tokens=[(ot.eos_token, ot.eos_token_id)],
        )

        return tokenizer
|
| 1245 |
+
|
| 1246 |
+
|
| 1247 |
+
class XGLMConverter(SpmConverter):
    """Converts a slow XGLM tokenizer into its fast counterpart."""

    def vocab(self, proto):
        """Fixed specials, the pieces, then the reserved <madeupword*> placeholders."""
        entries = [("<s>", 0.0), ("<pad>", 0.0), ("</s>", 0.0), ("<unk>", 0.0)]
        entries.extend((piece.piece, piece.score) for piece in proto.pieces[3:])
        # Placeholder tokens reserved at the end of the embedding matrix.
        entries.extend((f"<madeupword{i}>", 0.0) for i in range(7))
        return entries

    def unk_id(self, proto):
        """<unk> is pinned at index 3."""
        return 3

    def post_processor(self):
        """XGLM prepends `</s>` to sequences (with a double separator for pairs)."""
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="</s> $A",
            pair="</s> $A </s> </s> $B",
            special_tokens=[("<s>", to_id("<s>")), ("</s>", to_id("</s>"))],
        )
|
| 1272 |
+
|
| 1273 |
+
|
| 1274 |
+
class GemmaConverter(SpmConverter):
    """Converts a slow Gemma SentencePiece tokenizer into its fast counterpart."""

    # Out-of-vocabulary characters fall back to raw byte tokens.
    handle_byte_fallback = True
    # Gemma needs its own extractor for vocab/merge recovery.
    SpmExtractor = GemmaSentencePieceExtractor
    # start and end of turn tokens must be marked as special
    special_tokens = {"<start_of_turn>", "<end_of_turn>"}

    """"
    split_by_unicode_script: true
    split_by_number: true
    split_by_whitespace: true
    treat_whitespace_as_suffix: false
    allow_whitespace_only_pieces: true
    split_digits: true
    byte_fallback: true
    """

    def normalizer(self, proto):
        # Gemma only maps spaces to the SentencePiece underline marker.
        return normalizers.Replace(" ", "▁")

    def vocab(self, proto):
        # Ids 0-2 (pad/eos/bos) come from the slow tokenizer, then the proto pieces.
        vocab = [
            (self.original_tokenizer.pad_token, 0.0),
            (self.original_tokenizer.eos_token, 0.0),
            (self.original_tokenizer.bos_token, 0.0),
        ]
        for piece in proto.pieces[3:]:
            if piece.piece == "<0x09>":
                # The tab byte token is surfaced as a literal "\t" in the fast vocab.
                vocab += [("\t", piece.score)]
            else:
                vocab += [(piece.piece, piece.score)]
        # vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        return vocab

    def pre_tokenizer(self, replacement, add_prefix_space):
        # Split on spaces, gluing each space onto the preceding token.
        return pre_tokenizers.Split(" ", "merged_with_previous")

    def unk_id(self, proto):
        # <unk> is pinned at index 3 in the rebuilt vocab.
        unk_id = 3
        return unk_id

    def decoder(self, replacement, add_prefix_space):
        # Undo the underline marker, re-assemble byte-fallback tokens, fuse the result.
        return decoders.Sequence(
            [
                decoders.Replace("▁", " "),
                decoders.ByteFallback(),
                decoders.Fuse(),
            ]
        )
|
| 1322 |
+
|
| 1323 |
+
|
| 1324 |
+
class LlamaConverter(SpmConverter):
    """Converts a slow Llama tokenizer into its fast counterpart."""

    handle_byte_fallback = True

    def vocab(self, proto):
        """The first three ids come from the slow tokenizer, then the proto pieces."""
        to_token = self.original_tokenizer.convert_ids_to_tokens
        entries = [(to_token(0), 0.0), (to_token(1), 0.0), (to_token(2), 0.0)]
        entries.extend((piece.piece, piece.score) for piece in proto.pieces[3:])
        return entries

    def unk_id(self, proto):
        """Llama places <unk> at index 0."""
        return 0

    def decoder(self, replacement, add_prefix_space):
        """Undo the underline marker, re-assemble byte-fallback tokens, fuse the result."""
        steps = [
            decoders.Replace("▁", " "),
            decoders.ByteFallback(),
            decoders.Fuse(),
        ]
        if add_prefix_space:
            steps.append(decoders.Strip(content=" ", left=1))
        return decoders.Sequence(steps)

    def normalizer(self, proto):
        """Legacy mode normalizes in the normalizer; non-legacy uses the pre-tokenizer."""
        if not getattr(self.original_tokenizer, "legacy", True):
            return None  # non-legacy, no normalizer
        steps = []
        if getattr(self.original_tokenizer, "add_prefix_space", True):
            steps.append(normalizers.Prepend(prepend="▁"))
        steps.append(normalizers.Replace(pattern=" ", content="▁"))
        return normalizers.Sequence(steps)

    def pre_tokenizer(self, replacement, add_prefix_space):
        """Only non-legacy mode needs a Metaspace pre-tokenizer (the replace happens here)."""
        if getattr(self.original_tokenizer, "legacy", True):
            return None
        prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
        return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme, split=False)

    def post_processor(self):
        # the processor is defined in the LlamaTokenizerFast class.
        return None
|
| 1368 |
+
|
| 1369 |
+
|
| 1370 |
+
class MarkupLMConverter(Converter):
    """Converts a slow MarkupLM (byte-level BPE) tokenizer into its fast counterpart."""

    def converted(self) -> Tokenizer:
        ot = self.original_tokenizer
        tokenizer = Tokenizer(
            BPE(
                vocab=ot.encoder,
                merges=list(ot.bpe_ranks.keys()),
                dropout=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
                unk_token=ot.unk_token,
            )
        )

        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
        tokenizer.decoder = decoders.ByteLevel()

        cls_tok = str(ot.cls_token)
        sep_tok = str(ot.sep_token)
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls_tok} $A {sep_tok}",
            pair=f"{cls_tok} $A {sep_tok} $B {sep_tok}",
            special_tokens=[(cls_tok, ot.cls_token_id), (sep_tok, ot.sep_token_id)],
        )

        return tokenizer
|
| 1406 |
+
|
| 1407 |
+
|
| 1408 |
+
class MoshiConverter(SpmConverter):
    """Converts a slow Moshi tokenizer from a SentencePiece model file."""

    handle_byte_fallback = True

    def __init__(self, vocab_file, model_max_length=None, **kwargs):
        requires_backends(self, "protobuf")

        Converter.__init__(self, vocab_file)

        # Lazily import the protobuf-generated sentencepiece model definition.
        model_pb2 = import_protobuf()

        proto = model_pb2.ModelProto()
        with open(vocab_file, "rb") as handle:
            proto.ParseFromString(handle.read())
        self.proto = proto

    def normalizer(self, proto):
        """Optional precompiled charsmap followed by the space -> ▁ replacement."""
        steps = [normalizers.Replace(" ", "▁")]
        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        if precompiled_charsmap:
            steps.insert(0, normalizers.Precompiled(precompiled_charsmap))
        return normalizers.Sequence(steps)

    def decoder(self, replacement, add_prefix_space):
        """Undo the underline marker, re-assemble byte-fallback tokens, fuse the result."""
        steps = [
            decoders.Replace("▁", " "),
            decoders.ByteFallback(),
            decoders.Fuse(),
        ]
        if add_prefix_space:
            steps.append(decoders.Strip(content=" ", left=1))
        return decoders.Sequence(steps)

    def pre_tokenizer(self, replacement, add_prefix_space):
        # Moshi always uses the "first" prepend scheme regardless of add_prefix_space.
        return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme="first", split=False)
|
| 1447 |
+
|
| 1448 |
+
|
| 1449 |
+
# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
def bytes_to_unicode():
    """
    Build the GPT-2 byte -> unicode-character lookup table.

    Maps every possible byte (0-255) to a printable unicode character so the
    reversible BPE codes can operate on unicode strings without hitting the
    whitespace/control characters the bpe code barfs on. Printable ASCII and the
    common Latin-1 ranges map to themselves; every remaining byte is shifted up
    by 256 into an unused codepoint, keeping the mapping injective. Avoiding UNKs
    this way matters because a large (~10B token) dataset otherwise needs ~5K
    raw characters in a 32K vocab just for coverage.
    """
    printable = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    byte_values = printable[:]
    char_codes = printable[:]
    shift = 0
    for byte in range(2**8):
        if byte not in byte_values:
            byte_values.append(byte)
            char_codes.append(2**8 + shift)
            shift += 1
    return {b: chr(c) for b, c in zip(byte_values, char_codes)}
|
| 1472 |
+
|
| 1473 |
+
|
| 1474 |
+
class TikTokenConverter:
    """
    A general tiktoken converter.
    """

    def __init__(
        self,
        vocab_file=None,
        pattern=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
        add_prefix_space=False,
        additional_special_tokens=None,
        *args,
        **kwargs,
    ):
        # vocab_file: path/URL of the tiktoken BPE rank file.
        # pattern: pre-tokenization split regex (defaults to a GPT-4-style pattern).
        # additional_special_tokens: passed through to `add_special_tokens` in `converted`.
        super().__init__(*args)
        self.vocab_file = vocab_file
        self.pattern = pattern
        self.add_prefix_space = add_prefix_space
        self.additional_special_tokens = additional_special_tokens

    def extract_vocab_merges_from_model(self, tiktoken_url: str):
        """Rebuild a (vocab, merges) pair from a tiktoken rank file.

        tiktoken only stores token -> rank. The BPE merge list is reconstructed
        by splitting each multi-byte token at every position and keeping the
        splits whose halves are themselves tokens, ordered by rank.
        """
        try:
            from tiktoken.load import load_tiktoken_bpe
        except Exception:
            raise ValueError(
                "`tiktoken` is required to read a `tiktoken` file. Install it with " "`pip install tiktoken`."
            )

        bpe_ranks = load_tiktoken_bpe(tiktoken_url)
        byte_encoder = bytes_to_unicode()

        def token_bytes_to_string(b):
            # Render raw token bytes through the GPT-2 byte-to-unicode table.
            return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])

        merges = []
        vocab = {}
        for token, rank in bpe_ranks.items():
            vocab[token_bytes_to_string(token)] = rank
            if len(token) == 1:
                # Single bytes are atomic; they produce no merges.
                continue
            local = []
            for index in range(1, len(token)):
                piece_l, piece_r = token[:index], token[index:]
                if piece_l in bpe_ranks and piece_r in bpe_ranks and (piece_l + piece_r) in bpe_ranks:
                    local.append((piece_l, piece_r, rank))
            # Order this token's candidate splits by the ranks of their halves.
            local = sorted(local, key=lambda x: (bpe_ranks[x[0]], bpe_ranks[x[1]]), reverse=False)
            merges.extend(local)
        # Global merge order follows the rank of the merged token.
        merges = sorted(merges, key=lambda val: val[2], reverse=False)
        merges = [(token_bytes_to_string(val[0]), token_bytes_to_string(val[1])) for val in merges]
        return vocab, merges

    def tokenizer(self):
        """Build the bare BPE tokenizer (no pre/post-processing) from the rank file."""
        vocab_scores, merges = self.extract_vocab_merges_from_model(self.vocab_file)
        tokenizer = Tokenizer(BPE(vocab_scores, merges, fuse_unk=False))
        if hasattr(tokenizer.model, "ignore_merges"):
            # Prefer exact vocab hits over re-deriving them through merges
            # (only available on newer `tokenizers` releases).
            tokenizer.model.ignore_merges = True
        return tokenizer

    def converted(self) -> Tokenizer:
        """Return the fully assembled fast tokenizer (pre-tokenizer, decoder, specials)."""
        tokenizer = self.tokenizer()
        tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
            [
                pre_tokenizers.Split(Regex(self.pattern), behavior="isolated", invert=False),
                pre_tokenizers.ByteLevel(add_prefix_space=self.add_prefix_space, use_regex=False),
            ]
        )
        tokenizer.decoder = decoders.ByteLevel()
        tokenizer.add_special_tokens(self.additional_special_tokens)

        tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)

        return tokenizer
|
| 1546 |
+
|
| 1547 |
+
|
| 1548 |
+
# Maps each slow tokenizer class name to the Converter able to build its fast
# (`tokenizers`-backed) equivalent. Looked up by class name in
# `convert_slow_tokenizer` below.
SLOW_TO_FAST_CONVERTERS = {
    "AlbertTokenizer": AlbertConverter,
    "BartTokenizer": RobertaConverter,
    "BarthezTokenizer": BarthezConverter,
    "BertTokenizer": BertConverter,
    "BigBirdTokenizer": BigBirdConverter,
    "BlenderbotTokenizer": BlenderbotConverter,
    "CamembertTokenizer": CamembertConverter,
    "CLIPTokenizer": CLIPConverter,
    "CodeGenTokenizer": GPT2Converter,
    "ConvBertTokenizer": BertConverter,
    "DebertaTokenizer": DebertaConverter,
    "DebertaV2Tokenizer": DebertaV2Converter,
    "DistilBertTokenizer": BertConverter,
    "DPRReaderTokenizer": BertConverter,
    "DPRQuestionEncoderTokenizer": BertConverter,
    "DPRContextEncoderTokenizer": BertConverter,
    "ElectraTokenizer": BertConverter,
    "FNetTokenizer": AlbertConverter,
    "FunnelTokenizer": FunnelConverter,
    "GPT2Tokenizer": GPT2Converter,
    "HerbertTokenizer": HerbertConverter,
    "LayoutLMTokenizer": BertConverter,
    "LayoutLMv2Tokenizer": BertConverter,
    "LayoutLMv3Tokenizer": RobertaConverter,
    "LayoutXLMTokenizer": XLMRobertaConverter,
    "LongformerTokenizer": RobertaConverter,
    "LEDTokenizer": RobertaConverter,
    "LxmertTokenizer": BertConverter,
    "MarkupLMTokenizer": MarkupLMConverter,
    "MBartTokenizer": MBartConverter,
    "MBart50Tokenizer": MBart50Converter,
    "MPNetTokenizer": MPNetConverter,
    "MobileBertTokenizer": BertConverter,
    "MvpTokenizer": RobertaConverter,
    "NllbTokenizer": NllbConverter,
    "OpenAIGPTTokenizer": OpenAIGPTConverter,
    "PegasusTokenizer": PegasusConverter,
    "Qwen2Tokenizer": Qwen2Converter,
    "RealmTokenizer": BertConverter,
    "ReformerTokenizer": ReformerConverter,
    "RemBertTokenizer": RemBertConverter,
    "RetriBertTokenizer": BertConverter,
    "RobertaTokenizer": RobertaConverter,
    "RoFormerTokenizer": RoFormerConverter,
    "SeamlessM4TTokenizer": SeamlessM4TConverter,
    "SqueezeBertTokenizer": BertConverter,
    "T5Tokenizer": T5Converter,
    "UdopTokenizer": UdopConverter,
    "WhisperTokenizer": WhisperConverter,
    "XLMRobertaTokenizer": XLMRobertaConverter,
    "XLNetTokenizer": XLNetConverter,
    "SplinterTokenizer": SplinterConverter,
    "XGLMTokenizer": XGLMConverter,
    "LlamaTokenizer": LlamaConverter,
    "CodeLlamaTokenizer": LlamaConverter,
    "GemmaTokenizer": GemmaConverter,
    "Phi3Tokenizer": LlamaConverter,
}
|
| 1607 |
+
|
| 1608 |
+
|
| 1609 |
+
def convert_slow_tokenizer(transformer_tokenizer, from_tiktoken=False) -> Tokenizer:
    """
    Utilities to convert a slow tokenizer instance in a fast tokenizer instance.

    Args:
        transformer_tokenizer ([`~tokenization_utils_base.PreTrainedTokenizer`]):
            Instance of a slow tokenizer to convert in the backend tokenizer for
            [`~tokenization_utils_base.PreTrainedTokenizerFast`].
        from_tiktoken (bool, optional): Whether to use the `tiktoken` library to convert the tokenizer instead of sentencepiece.
            Defaults to False.

    Return:
        A instance of [`~tokenizers.Tokenizer`] to be used as the backend tokenizer of a
        [`~tokenization_utils_base.PreTrainedTokenizerFast`]

    Raises:
        ValueError: if no registered converter matches and the tiktoken fallback fails.
    """

    tokenizer_class_name = transformer_tokenizer.__class__.__name__
    if tokenizer_class_name in SLOW_TO_FAST_CONVERTERS and not from_tiktoken:
        # A dedicated converter exists for this tokenizer class; use it directly.
        converter_class = SLOW_TO_FAST_CONVERTERS[tokenizer_class_name]
        return converter_class(transformer_tokenizer).converted()

    else:
        try:
            logger.info("Converting from Tiktoken")
            return TikTokenConverter(
                vocab_file=transformer_tokenizer.vocab_file,
                additional_special_tokens=transformer_tokenizer.additional_special_tokens,
            ).converted()
        except Exception as exc:
            # Chain the underlying failure (`from exc`) so users can see exactly
            # why the tiktoken fallback broke instead of only this generic message.
            raise ValueError(
                "Converting from Tiktoken failed, if a converter for SentencePiece is available, provide a model path "
                "with a SentencePiece tokenizer.model file."
                f"Currently available slow->fast convertors: {list(SLOW_TO_FAST_CONVERTERS.keys())}"
            ) from exc
|
.venv/lib/python3.11/site-packages/transformers/convert_slow_tokenizers_checkpoints_to_fast.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Convert slow tokenizers checkpoints in fast (serialization format of the `tokenizers` library)"""
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import os
|
| 19 |
+
|
| 20 |
+
import transformers
|
| 21 |
+
|
| 22 |
+
from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS
|
| 23 |
+
from .utils import logging
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
logging.set_verbosity_info()
|
| 27 |
+
|
| 28 |
+
logger = logging.get_logger(__name__)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Map each slow tokenizer class name to its fast counterpart class.
# The fast class is normally obtained by appending "Fast" to the slow class name;
# Phi3 is the one special case — it reuses the Llama tokenizer, so "Phi3Tokenizer"
# maps to `LlamaTokenizerFast` instead of a (non-existent) "Phi3TokenizerFast".
TOKENIZER_CLASSES = {
    # Phi3 uses Llama tokenizer
    name: getattr(transformers, "LlamaTokenizerFast" if name == "Phi3Tokenizer" else name + "Fast")
    for name in SLOW_TO_FAST_CONVERTERS
}
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def convert_slow_checkpoint_to_fast(tokenizer_name, checkpoint_name, dump_path, force_download):
    """
    Convert slow tokenizer checkpoint(s) to the fast `tokenizer.json` serialization format.

    Args:
        tokenizer_name (`str` or `None`): Name of a slow tokenizer class (e.g. `"BertTokenizer"`).
            If `None`, every class in `TOKENIZER_CLASSES` is converted.
        checkpoint_name (`str` or `None`): A specific checkpoint to convert. If `None`, all canonical
            checkpoints known to the tokenizer class (`max_model_input_sizes`) are converted.
        dump_path (`str`): Directory where the generated `tokenizer.json` files are written.
        force_download (`bool`): Whether to re-download checkpoints even if they are cached.

    Raises:
        ValueError: If `tokenizer_name` is given but not present in `TOKENIZER_CLASSES`.
    """
    if tokenizer_name is not None and tokenizer_name not in TOKENIZER_CLASSES:
        raise ValueError(f"Unrecognized tokenizer name, should be one of {list(TOKENIZER_CLASSES.keys())}.")

    if tokenizer_name is None:
        tokenizer_names = TOKENIZER_CLASSES
    else:
        # Resolve through the precomputed mapping so special cases (e.g. Phi3Tokenizer ->
        # LlamaTokenizerFast) are handled consistently, instead of blindly appending "Fast".
        tokenizer_names = {tokenizer_name: TOKENIZER_CLASSES[tokenizer_name]}

    logger.info(f"Loading tokenizer classes: {tokenizer_names}")

    for tokenizer_name in tokenizer_names:
        tokenizer_class = TOKENIZER_CLASSES[tokenizer_name]

        add_prefix = True
        if checkpoint_name is None:
            checkpoint_names = list(tokenizer_class.max_model_input_sizes.keys())
        else:
            checkpoint_names = [checkpoint_name]

        # `tokenizer_class` is already a class, so the readable name is `__name__`
        # (`__class__.__name__` would log "type").
        logger.info(f"For tokenizer {tokenizer_class.__name__} loading checkpoints: {checkpoint_names}")

        for checkpoint in checkpoint_names:
            logger.info(f"Loading {tokenizer_class.__name__} {checkpoint}")

            # Load tokenizer
            tokenizer = tokenizer_class.from_pretrained(checkpoint, force_download=force_download)

            # Save fast tokenizer
            logger.info(f"Save fast tokenizer to {dump_path} with prefix {checkpoint} add_prefix {add_prefix}")

            # For organization names we create sub-directories
            if "/" in checkpoint:
                checkpoint_directory, checkpoint_prefix_name = checkpoint.split("/")
                dump_path_full = os.path.join(dump_path, checkpoint_directory)
            elif add_prefix:
                checkpoint_prefix_name = checkpoint
                dump_path_full = dump_path
            else:
                checkpoint_prefix_name = None
                dump_path_full = dump_path

            logger.info(f"=> {dump_path_full} with prefix {checkpoint_prefix_name}, add_prefix {add_prefix}")

            # If the vocab file for this checkpoint lives in a sub-directory on the remote,
            # mirror that layout locally instead of using a filename prefix.
            if checkpoint in list(tokenizer.pretrained_vocab_files_map.values())[0]:
                file_path = list(tokenizer.pretrained_vocab_files_map.values())[0][checkpoint]
                next_char = file_path.split(checkpoint)[-1][0]
                if next_char == "/":
                    dump_path_full = os.path.join(dump_path_full, checkpoint_prefix_name)
                    checkpoint_prefix_name = None

                logger.info(f"=> {dump_path_full} with prefix {checkpoint_prefix_name}, add_prefix {add_prefix}")

            file_names = tokenizer.save_pretrained(
                dump_path_full, legacy_format=False, filename_prefix=checkpoint_prefix_name
            )
            logger.info(f"=> File names {file_names}")

            # Keep only the fast serialization; drop the slow-format companion files.
            for file_name in file_names:
                if not file_name.endswith("tokenizer.json"):
                    os.remove(file_name)
                    logger.info(f"=> removing {file_name}")
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
if __name__ == "__main__":
    # Command-line entry point: parse the conversion options and run the converter.
    cli = argparse.ArgumentParser()
    # Required parameters
    cli.add_argument(
        "--dump_path", default=None, type=str, required=True, help="Path to output generated fast tokenizer files."
    )
    cli.add_argument(
        "--tokenizer_name",
        default=None,
        type=str,
        help=(
            f"Optional tokenizer type selected in the list of {list(TOKENIZER_CLASSES.keys())}. If not given, will "
            "download and convert all the checkpoints from AWS."
        ),
    )
    cli.add_argument(
        "--checkpoint_name",
        default=None,
        type=str,
        help="Optional checkpoint name. If not given, will download and convert the canonical checkpoints from AWS.",
    )
    cli.add_argument(
        "--force_download",
        action="store_true",
        help="Re-download checkpoints.",
    )
    parsed = cli.parse_args()

    convert_slow_checkpoint_to_fast(parsed.tokenizer_name, parsed.checkpoint_name, parsed.dump_path, parsed.force_download)
|
.venv/lib/python3.11/site-packages/transformers/convert_tf_hub_seq_to_seq_bert_to_pytorch.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2020 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Convert Seq2Seq TF Hub checkpoint."""
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
|
| 19 |
+
from . import (
|
| 20 |
+
BertConfig,
|
| 21 |
+
BertGenerationConfig,
|
| 22 |
+
BertGenerationDecoder,
|
| 23 |
+
BertGenerationEncoder,
|
| 24 |
+
load_tf_weights_in_bert_generation,
|
| 25 |
+
logging,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
logging.set_verbosity_info()
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def convert_tf_checkpoint_to_pytorch(tf_hub_path, pytorch_dump_path, is_encoder_named_decoder, vocab_size, is_encoder):
    """Convert a seq2seq BERT TF Hub checkpoint into a BertGeneration PyTorch model and save it."""
    # Initialise PyTorch model: start from the bert-large-cased config, then adapt it
    # to a decoder-style BertGeneration config.
    base_config = BertConfig.from_pretrained(
        "google-bert/bert-large-cased",
        vocab_size=vocab_size,
        max_position_embeddings=512,
        is_decoder=True,
        add_cross_attention=True,
    )
    config_kwargs = base_config.to_dict()
    # BertGeneration has no token-type embeddings, so drop that field before re-building.
    config_kwargs.pop("type_vocab_size")
    config = BertGenerationConfig(**config_kwargs)
    model_cls = BertGenerationEncoder if is_encoder else BertGenerationDecoder
    model = model_cls(config)
    print(f"Building PyTorch model from configuration: {config}")

    # Load weights from tf checkpoint
    load_tf_weights_in_bert_generation(
        model,
        tf_hub_path,
        model_class="bert",
        is_encoder_named_decoder=is_encoder_named_decoder,
        is_encoder=is_encoder,
    )

    # Save pytorch-model
    print(f"Save PyTorch model and config to {pytorch_dump_path}")
    model.save_pretrained(pytorch_dump_path)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
if __name__ == "__main__":
    # CLI wrapper around `convert_tf_checkpoint_to_pytorch`.
    arg_parser = argparse.ArgumentParser()
    # Required parameters
    arg_parser.add_argument(
        "--tf_hub_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    arg_parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    arg_parser.add_argument(
        "--is_encoder_named_decoder",
        action="store_true",
        help="If decoder has to be renamed to encoder in PyTorch model.",
    )
    arg_parser.add_argument("--is_encoder", action="store_true", help="If model is an encoder.")
    arg_parser.add_argument("--vocab_size", default=50358, type=int, help="Vocab size of model")
    cli_args = arg_parser.parse_args()
    convert_tf_checkpoint_to_pytorch(
        cli_args.tf_hub_path,
        cli_args.pytorch_dump_path,
        cli_args.is_encoder_named_decoder,
        cli_args.vocab_size,
        is_encoder=cli_args.is_encoder,
    )
|
.venv/lib/python3.11/site-packages/transformers/debug_utils.py
ADDED
|
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import collections
|
| 16 |
+
|
| 17 |
+
from .utils import ExplicitEnum, is_torch_available, logging
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if is_torch_available():
|
| 21 |
+
import torch
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
logger = logging.get_logger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class DebugUnderflowOverflow:
    """
    This debug class helps detect and understand where the model starts getting very large or very small, and more
    importantly `nan` or `inf` weight and activation elements.

    There are 2 working modes:

    1. Underflow/overflow detection (default)
    2. Specific batch absolute min/max tracing without detection

    Mode 1: Underflow/overflow detection

    To activate the underflow/overflow detection, initialize the object with the model :

    ```python
    debug_overflow = DebugUnderflowOverflow(model)
    ```

    then run the training as normal and if `nan` or `inf` gets detected in at least one of the weight, input or output
    elements this module will throw an exception and will print `max_frames_to_save` frames that lead to this event,
    each frame reporting

    1. the fully qualified module name plus the class name whose `forward` was run
    2. the absolute min and max value of all elements for each module weights, and the inputs and output

    For example, here is the header and the last few frames in detection report for `google/mt5-small` run in fp16
    mixed precision :

    ```
    Detected inf/nan during batch_number=0
    Last 21 forward frames:
    abs min  abs max  metadata
    [...]
                      encoder.block.2.layer.1.DenseReluDense.wi_0 Linear
    2.17e-07 4.50e+00 weight
    1.79e-06 4.65e+00 input[0]
    2.68e-06 3.70e+01 output
                      encoder.block.2.layer.1.DenseReluDense.wi_1 Linear
    8.08e-07 2.66e+01 weight
    1.79e-06 4.65e+00 input[0]
    1.27e-04 2.37e+02 output
                      encoder.block.2.layer.1.DenseReluDense.wo Linear
    1.01e-06 6.44e+00 weight
    0.00e+00 9.74e+03 input[0]
    3.18e-04 6.27e+04 output
                      encoder.block.2.layer.1.DenseReluDense T5DenseGatedGeluDense
    1.79e-06 4.65e+00 input[0]
    3.18e-04 6.27e+04 output
                      encoder.block.2.layer.1.dropout Dropout
    3.18e-04 6.27e+04 input[0]
    0.00e+00 inf output
    ```

    You can see here, that `T5DenseGatedGeluDense.forward` resulted in output activations, whose absolute max value
    was around 62.7K, which is very close to fp16's top limit of 64K. In the next frame we have `Dropout` which
    renormalizes the weights, after it zeroed some of the elements, which pushes the absolute max value to more than
    64K, and we get an overflow.

    As you can see it's the previous frames that we need to look into when the numbers start going into very large
    for fp16 numbers.

    The tracking is done in a forward hook, which gets invoked immediately after `forward` has completed.

    By default the last 21 frames are printed. You can change the default to adjust for your needs. For example :

    ```python
    debug_overflow = DebugUnderflowOverflow(model, max_frames_to_save=100)
    ```

    To validate that you have set up this debugging feature correctly, and you intend to use it in a training that
    may take hours to complete, first run it with normal tracing enabled for one of a few batches as explained in
    the next section.


    Mode 2. Specific batch absolute min/max tracing without detection

    The second work mode is per-batch tracing with the underflow/overflow detection feature turned off.

    Let's say you want to watch the absolute min and max values for all the ingredients of each `forward` call of a
    given batch, and only do that for batches 1 and 3. Then you instantiate this class as :

    ```python
    debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3])
    ```

    And now full batches 1 and 3 will be traced using the same format as explained above. Batches are 0-indexed.

    This is helpful if you know that the program starts misbehaving after a certain batch number, so you can
    fast-forward right to that area.


    Early stopping:

    You can also specify the batch number after which to stop the training, with :

    ```python
    debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3], abort_after_batch_num=3)
    ```

    This feature is mainly useful in the tracing mode, but you can use it for any mode.


    **Performance**:

    As this module measures absolute `min`/`max` of each weight of the model on every forward it'll slow the
    training down. Therefore remember to turn it off once the debugging needs have been met.

    Args:
        model (`nn.Module`):
            The model to debug.
        max_frames_to_save (`int`, *optional*, defaults to 21):
            How many frames back to record
        trace_batch_nums(`List[int]`, *optional*, defaults to `[]`):
            Which batch numbers to trace (turns detection off)
        abort_after_batch_num  (`int``, *optional*):
            Whether to abort after a certain batch number has finished
    """

    def __init__(self, model, max_frames_to_save=21, trace_batch_nums=None, abort_after_batch_num=None):
        self.model = model
        # `None` sentinel instead of a mutable `[]` default, to avoid the shared
        # mutable-default-argument pitfall; behavior for callers is unchanged.
        self.trace_batch_nums = trace_batch_nums if trace_batch_nums is not None else []
        self.abort_after_batch_num = abort_after_batch_num

        # keep a LIFO buffer of frames to dump as soon as inf/nan is encountered to give context to the problem emergence
        self.frames = collections.deque([], max_frames_to_save)
        self.frame = []
        self.batch_number = 0
        self.total_calls = 0
        self.detected_overflow = False
        self.prefix = " "

        self.analyse_model()

        self.register_forward_hook()

    def save_frame(self, frame=None):
        # Finalize the current frame (optionally appending one last line) and start a new one.
        if frame is not None:
            self.expand_frame(frame)
        self.frames.append("\n".join(self.frame))
        self.frame = []  # start a new frame

    def expand_frame(self, line):
        # Append one report line to the frame currently being built.
        self.frame.append(line)

    def trace_frames(self):
        # Print everything collected so far (trace mode) and reset the buffer.
        # NOTE(review): this rebinds `self.frames` to a plain list, dropping the deque's
        # maxlen bound — presumably harmless since trace mode resets every hook call.
        print("\n".join(self.frames))
        self.frames = []

    def reset_saved_frames(self):
        self.frames = []

    def dump_saved_frames(self):
        # Emit the buffered frames that led up to the detected inf/nan.
        print(f"\nDetected inf/nan during batch_number={self.batch_number}")
        print(f"Last {len(self.frames)} forward frames:")
        print(f"{'abs min':8} {'abs max':8} metadata")
        print("\n".join(self.frames))
        print("\n\n")
        self.frames = []

    def analyse_model(self):
        # extract the fully qualified module names, to be able to report at run time. e.g.:
        # encoder.block.2.layer.0.SelfAttention.o
        #
        # for shared weights only the first shared module name will be registered
        self.module_names = {m: name for name, m in self.model.named_modules()}
        # self.longest_module_name = max(len(v) for v in self.module_names.values())

    def analyse_variable(self, var, ctx):
        # Record abs min/max for a tensor (and flag overflow); report non-tensors distinctly.
        if torch.is_tensor(var):
            self.expand_frame(get_abs_min_max(var, ctx))
            if detect_overflow(var, ctx):
                self.detected_overflow = True
        elif var is None:
            self.expand_frame(f"{'None':>17} {ctx}")
        else:
            self.expand_frame(f"{'not a tensor':>17} {ctx}")

    def batch_start_frame(self):
        self.expand_frame(f"\n\n{self.prefix} *** Starting batch number={self.batch_number} ***")
        self.expand_frame(f"{'abs min':8} {'abs max':8} metadata")

    def batch_end_frame(self):
        self.expand_frame(f"{self.prefix} *** Finished batch number={self.batch_number-1} ***\n\n")

    def create_frame(self, module, input, output):
        # Build one report frame for a single module's forward call: header, params,
        # inputs, then outputs (flattening one level of nested tuples).
        self.expand_frame(f"{self.prefix} {self.module_names[module]} {module.__class__.__name__}")

        # params
        for name, p in module.named_parameters(recurse=False):
            self.analyse_variable(p, name)

        # inputs
        if isinstance(input, tuple):
            for i, x in enumerate(input):
                self.analyse_variable(x, f"input[{i}]")
        else:
            self.analyse_variable(input, "input")

        # outputs
        if isinstance(output, tuple):
            for i, x in enumerate(output):
                # possibly a tuple of tuples
                if isinstance(x, tuple):
                    for j, y in enumerate(x):
                        self.analyse_variable(y, f"output[{i}][{j}]")
                else:
                    self.analyse_variable(x, f"output[{i}]")
        else:
            self.analyse_variable(output, "output")

        self.save_frame()

    def register_forward_hook(self):
        # Install the hook on every submodule (and the model itself).
        self.model.apply(self._register_forward_hook)

    def _register_forward_hook(self, module):
        module.register_forward_hook(self.forward_hook)

    def forward_hook(self, module, input, output):
        # - input is a tuple of packed inputs (could be non-Tensors)
        # - output could be a Tensor or a tuple of Tensors and non-Tensors

        last_frame_of_batch = False

        trace_mode = self.batch_number in self.trace_batch_nums
        if trace_mode:
            self.reset_saved_frames()

        if self.total_calls == 0:
            self.batch_start_frame()
        self.total_calls += 1

        # count batch numbers - the very first forward hook of the batch will be called when the
        # batch completes - i.e. it gets called very last - we know this batch has finished
        if module == self.model:
            self.batch_number += 1
            last_frame_of_batch = True

        self.create_frame(module, input, output)

        # if last_frame_of_batch:
        #     self.batch_end_frame()

        if trace_mode:
            self.trace_frames()

        if last_frame_of_batch:
            self.batch_start_frame()

        if self.detected_overflow and not trace_mode:
            self.dump_saved_frames()

            # now we can abort, as it's pointless to continue running
            raise ValueError(
                "DebugUnderflowOverflow: inf/nan detected, aborting as there is no point running further. "
                "Please scroll up above this traceback to see the activation values prior to this event."
            )

        # abort after certain batch if requested to do so
        if self.abort_after_batch_num is not None and self.batch_number > self.abort_after_batch_num:
            raise ValueError(
                f"DebugUnderflowOverflow: aborting after {self.batch_number} batches due to"
                f" `abort_after_batch_num={self.abort_after_batch_num}` arg"
            )
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def get_abs_min_max(var, ctx):
    """Return a report line: the absolute min and max of `var` in scientific notation, then `ctx`."""
    magnitudes = var.abs()
    lo, hi = magnitudes.min(), magnitudes.max()
    return f"{lo:8.2e} {hi:8.2e} {ctx}"
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def detect_overflow(var, ctx):
    """
    Report whether the tensor contains any `nan` or `inf` entries.

    This is useful for detecting overflows/underflows and best to call right after the function that did some math
    that modified the tensor in question.

    This function contains a few other helper features that you can enable and tweak directly if you want to track
    various other things.

    Args:
        var: the tensor variable to check
        ctx: the message to print as a context

    Return:
        `True` if `inf` or `nan` was detected, `False` otherwise
    """
    found_invalid = False
    # Check both pathological conditions; print a message for each one that fires.
    for predicate, label in ((torch.isnan, "nans"), (torch.isinf, "infs")):
        if predicate(var).any().item():
            found_invalid = True
            print(f"{ctx} has {label}")

    # The blocks below are deliberately disabled debug toggles — flip `if 0` to `if 1`
    # to enable the extra reporting.

    # if needed to monitor large elements can enable the following
    if 0:  # and found_invalid:
        n100 = var[torch.ge(var.abs(), 100)]
        if n100.numel() > 0:
            print(f"{ctx}: n100={n100.numel()}")
        n1000 = var[torch.ge(var.abs(), 1000)]
        if n1000.numel() > 0:
            print(f"{ctx}: n1000={n1000.numel()}")
        n10000 = var[torch.ge(var.abs(), 10000)]
        if n10000.numel() > 0:
            print(f"{ctx}: n10000={n10000.numel()}")

    if 0:
        print(f"min={var.min():9.2e} max={var.max():9.2e}")

    if 0:
        print(f"min={var.min():9.2e} max={var.max():9.2e} var={var.var():9.2e} mean={var.mean():9.2e} ({ctx})")

    return found_invalid
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
class DebugOption(ExplicitEnum):
    """String-valued enumeration of the supported debug options."""

    # Enables underflow/overflow detection during training (see `DebugUnderflowOverflow`).
    UNDERFLOW_OVERFLOW = "underflow_overflow"
    # Presumably enables TPU metrics debugging — semantics defined by the consumer; verify against usage.
    TPU_METRICS_DEBUG = "tpu_metrics_debug"
|
.venv/lib/python3.11/site-packages/transformers/dependency_versions_check.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from .dependency_versions_table import deps
|
| 16 |
+
from .utils.versions import require_version, require_version_core
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# define which module versions we always want to check at run time
|
| 20 |
+
# (usually the ones defined in `install_requires` in setup.py)
|
| 21 |
+
#
|
| 22 |
+
# order specific notes:
|
| 23 |
+
# - tqdm must be checked before tokenizers
|
| 24 |
+
|
| 25 |
+
pkgs_to_check_at_runtime = [
    "python",
    "tqdm",
    "regex",
    "requests",
    "packaging",
    "filelock",
    "numpy",
    "tokenizers",
    "huggingface-hub",
    "safetensors",
    "accelerate",
    "pyyaml",
]

# Verify, at import time, that each core runtime dependency satisfies its version pin.
for pkg in pkgs_to_check_at_runtime:
    # Fail fast on a stale table rather than silently skipping the check.
    if pkg not in deps:
        raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")

    if pkg == "tokenizers":
        # must be loaded here, or else tqdm check may fail
        from .utils import is_tokenizers_available

        if not is_tokenizers_available():
            continue  # not required, check version only if installed
    elif pkg == "accelerate":
        # must be loaded here, or else tqdm check may fail
        from .utils import is_accelerate_available

        # Maybe switch to is_torch_available in the future here so that Accelerate is hard dep of
        # Transformers with PyTorch
        if not is_accelerate_available():
            continue  # not required, check version only if installed

    require_version_core(deps[pkg])
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def dep_version_check(pkg, hint=None):
    """Check that the installed version of `pkg` satisfies its pin from `deps`, with an optional `hint` message."""
    require_version(deps[pkg], hint)
|
.venv/lib/python3.11/site-packages/transformers/dependency_versions_table.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# THIS FILE HAS BEEN AUTOGENERATED. To update:
# 1. modify the `_deps` dict in setup.py
# 2. run `make deps_table_update`
# Maps a bare package name to its full pip requirement specifier (version pin/range included).
deps = {
    "Pillow": "Pillow>=10.0.1,<=15.0",
    "accelerate": "accelerate>=0.26.0",
    "av": "av==9.2.0",
    "beautifulsoup4": "beautifulsoup4",
    "blobfile": "blobfile",
    "codecarbon": "codecarbon>=2.8.1",
    "cookiecutter": "cookiecutter==1.7.3",
    "dataclasses": "dataclasses",
    "datasets": "datasets!=2.5.0",
    "deepspeed": "deepspeed>=0.9.3",
    "diffusers": "diffusers",
    "dill": "dill<0.3.5",
    "evaluate": "evaluate>=0.2.0",
    "faiss-cpu": "faiss-cpu",
    "fastapi": "fastapi",
    "filelock": "filelock",
    "flax": "flax>=0.4.1,<=0.7.0",
    "fsspec": "fsspec<2023.10.0",
    "ftfy": "ftfy",
    "fugashi": "fugashi>=1.0",
    "GitPython": "GitPython<3.1.19",
    "hf-doc-builder": "hf-doc-builder>=0.3.0",
    "huggingface-hub": "huggingface-hub>=0.24.0,<1.0",
    "importlib_metadata": "importlib_metadata",
    "ipadic": "ipadic>=1.0.0,<2.0",
    "isort": "isort>=5.5.4",
    "jax": "jax>=0.4.1,<=0.4.13",
    "jaxlib": "jaxlib>=0.4.1,<=0.4.13",
    "jieba": "jieba",
    "jinja2": "jinja2>=3.1.0",
    "kenlm": "kenlm",
    "keras": "keras>2.9,<2.16",
    "keras-nlp": "keras-nlp>=0.3.1,<0.14.0",
    "librosa": "librosa",
    "nltk": "nltk<=3.8.1",
    "natten": "natten>=0.14.6,<0.15.0",
    "numpy": "numpy>=1.17",
    "onnxconverter-common": "onnxconverter-common",
    "onnxruntime-tools": "onnxruntime-tools>=1.4.2",
    "onnxruntime": "onnxruntime>=1.4.0",
    "opencv-python": "opencv-python",
    "optimum-benchmark": "optimum-benchmark>=0.3.0",
    "optuna": "optuna",
    "optax": "optax>=0.0.8,<=0.1.4",
    "packaging": "packaging>=20.0",
    "parameterized": "parameterized",
    "phonemizer": "phonemizer",
    "protobuf": "protobuf",
    "psutil": "psutil",
    "pyyaml": "pyyaml>=5.1",
    "pydantic": "pydantic",
    "pytest": "pytest>=7.2.0,<8.0.0",
    "pytest-asyncio": "pytest-asyncio",
    "pytest-timeout": "pytest-timeout",
    "pytest-xdist": "pytest-xdist",
    "python": "python>=3.9.0",
    "ray[tune]": "ray[tune]>=2.7.0",
    "regex": "regex!=2019.12.17",
    "requests": "requests",
    "rhoknp": "rhoknp>=1.1.0,<1.3.1",
    "rjieba": "rjieba",
    "rouge-score": "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1",
    "ruff": "ruff==0.5.1",
    "sacrebleu": "sacrebleu>=1.4.12,<2.0.0",
    "sacremoses": "sacremoses",
    "safetensors": "safetensors>=0.4.1",
    "sagemaker": "sagemaker>=2.31.0",
    "schedulefree": "schedulefree>=1.2.6",
    "scikit-learn": "scikit-learn",
    "scipy": "scipy<1.13.0",
    "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
    "sigopt": "sigopt",
    "starlette": "starlette",
    "sudachipy": "sudachipy>=0.6.6",
    "sudachidict_core": "sudachidict_core>=20220729",
    "tensorboard": "tensorboard",
    "tensorflow-cpu": "tensorflow-cpu>2.9,<2.16",
    "tensorflow": "tensorflow>2.9,<2.16",
    "tensorflow-text": "tensorflow-text<2.16",
    "tensorflow-probability": "tensorflow-probability<0.24",
    "tf2onnx": "tf2onnx",
    "timeout-decorator": "timeout-decorator",
    "tiktoken": "tiktoken",
    "timm": "timm<=1.0.11",
    "tokenizers": "tokenizers>=0.21,<0.22",
    "torch": "torch>=2.0",
    "torchaudio": "torchaudio",
    "torchvision": "torchvision",
    "pyctcdecode": "pyctcdecode>=0.4.0",
    "tqdm": "tqdm>=4.27",
    "unidic": "unidic>=1.0.2",
    "unidic_lite": "unidic_lite>=1.0.7",
    "urllib3": "urllib3<2.0.0",
    "uvicorn": "uvicorn",
    "pytest-rich": "pytest-rich",
    "libcst": "libcst",
    "rich": "rich",
}
|
.venv/lib/python3.11/site-packages/transformers/dynamic_module_utils.py
ADDED
|
@@ -0,0 +1,685 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2021 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Utilities to dynamically load objects from the Hub."""
|
| 16 |
+
|
| 17 |
+
import filecmp
|
| 18 |
+
import hashlib
|
| 19 |
+
import importlib
|
| 20 |
+
import importlib.util
|
| 21 |
+
import os
|
| 22 |
+
import re
|
| 23 |
+
import shutil
|
| 24 |
+
import signal
|
| 25 |
+
import sys
|
| 26 |
+
import threading
|
| 27 |
+
import typing
|
| 28 |
+
import warnings
|
| 29 |
+
from pathlib import Path
|
| 30 |
+
from types import ModuleType
|
| 31 |
+
from typing import Any, Dict, List, Optional, Union
|
| 32 |
+
|
| 33 |
+
from huggingface_hub import try_to_load_from_cache
|
| 34 |
+
|
| 35 |
+
from .utils import (
|
| 36 |
+
HF_MODULES_CACHE,
|
| 37 |
+
TRANSFORMERS_DYNAMIC_MODULE_NAME,
|
| 38 |
+
cached_file,
|
| 39 |
+
extract_commit_hash,
|
| 40 |
+
is_offline_mode,
|
| 41 |
+
logging,
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
# Serializes loading/reloading of dynamically fetched remote-code modules (see `get_class_in_module`).
_HF_REMOTE_CODE_LOCK = threading.Lock()
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def init_hf_modules():
    """
    Creates the cache directory for modules with an init, and adds it to the Python path.
    """
    # Being on sys.path means this function already ran; nothing more to do.
    if HF_MODULES_CACHE in sys.path:
        return

    sys.path.append(HF_MODULES_CACHE)
    os.makedirs(HF_MODULES_CACHE, exist_ok=True)
    cache_init = Path(HF_MODULES_CACHE) / "__init__.py"
    if not cache_init.exists():
        cache_init.touch()
        # A new package marker was just created: flush importlib's finder caches so it is seen.
        importlib.invalidate_caches()
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def create_dynamic_module(name: Union[str, os.PathLike]) -> None:
    """
    Creates a dynamic module in the cache directory for modules.

    Args:
        name (`str` or `os.PathLike`):
            The name of the dynamic module to create.
    """
    init_hf_modules()
    target = (Path(HF_MODULES_CACHE) / name).resolve()
    # Build any missing parent package first, so the whole dotted chain is importable.
    if not target.parent.exists():
        create_dynamic_module(target.parent)
    os.makedirs(target, exist_ok=True)
    package_marker = target / "__init__.py"
    if not package_marker.exists():
        package_marker.touch()
    # It is extremely important to invalidate the cache when we change stuff in those modules, or users end up
    # with errors about module that do not exist. Same for all other `invalidate_caches` in this file.
    importlib.invalidate_caches()
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def get_relative_imports(module_file: Union[str, os.PathLike]) -> List[str]:
    """
    Get the list of modules that are relatively imported in a module file.

    Args:
        module_file (`str` or `os.PathLike`): The module file to inspect.

    Returns:
        `List[str]`: The list of relative imports in the module.
    """
    source = Path(module_file).read_text(encoding="utf-8")

    # Match `import .xxx` ...
    found = set(re.findall(r"^\s*import\s+\.(\S+)\s*$", source, flags=re.MULTILINE))
    # ... and `from .xxx import yyy`; the set removes duplicates.
    found.update(re.findall(r"^\s*from\s+\.(\S+)\s+import", source, flags=re.MULTILINE))
    return list(found)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def get_relative_import_files(module_file: Union[str, os.PathLike]) -> List[str]:
    """
    Get the list of all files that are needed for a given module. Note that this function recurses through the relative
    imports (if a imports b and b imports c, it will return module files for b and c).

    Args:
        module_file (`str` or `os.PathLike`): The module file to inspect.

    Returns:
        `List[str]`: The list of all relative imports a given module needs (recursively), which will give us the list
        of module files a given module needs.
    """
    no_change = False
    files_to_check = [module_file]
    all_relative_imports = []

    # Let's recurse through all relative imports
    while not no_change:
        new_imports = []
        for f in files_to_check:
            new_imports.extend(get_relative_imports(f))

        module_path = Path(module_file).parent
        new_import_files = [str(module_path / m) for m in new_imports]
        # Compare with the `.py` suffix, since `all_relative_imports` stores suffixed paths. Without the
        # suffix the filter never matches, and circular relative imports (a imports b, b imports a)
        # would make this loop run forever.
        new_import_files = [f for f in new_import_files if f"{f}.py" not in all_relative_imports]
        files_to_check = [f"{f}.py" for f in new_import_files]

        # Fixed point: no unseen relative import was discovered in this pass.
        no_change = len(new_import_files) == 0
        all_relative_imports.extend(files_to_check)

    return all_relative_imports
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def get_imports(filename: Union[str, os.PathLike]) -> List[str]:
    """
    Extracts all the libraries (not relative imports this time) that are imported in a file.

    Args:
        filename (`str` or `os.PathLike`): The module file to inspect.

    Returns:
        `List[str]`: The list of all packages required to use the input module.
    """
    source = Path(filename).read_text(encoding="utf-8")

    # Strip try/except blocks so custom code can guard optional imports without us requiring them.
    source = re.sub(r"\s*try\s*:.*?except.*?:", "", source, flags=re.DOTALL)

    # Strip imports guarded by an is_flash_attn_*_available check, so CPU-only environments do not
    # get asked for flash_attn.
    source = re.sub(
        r"if is_flash_attn[a-zA-Z0-9_]+available\(\):\s*(from flash_attn\s*.*\s*)+", "", source, flags=re.MULTILINE
    )

    # Collect `import xxx` and `from xxx import yyy` targets.
    modules = re.findall(r"^\s*import\s+(\S+)\s*$", source, flags=re.MULTILINE)
    modules += re.findall(r"^\s*from\s+(\S+)\s+import", source, flags=re.MULTILINE)
    # Drop relative imports and keep only each absolute import's top-level package, deduplicated.
    top_level = {mod.split(".")[0] for mod in modules if not mod.startswith(".")}
    return list(top_level)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def check_imports(filename: Union[str, os.PathLike]) -> List[str]:
    """
    Check if the current Python environment contains all the libraries that are imported in a file. Will raise if a
    library is missing.

    Args:
        filename (`str` or `os.PathLike`): The module file to check.

    Returns:
        `List[str]`: The list of relative imports in the file.
    """
    missing = []
    for package in get_imports(filename):
        try:
            importlib.import_module(package)
        except ImportError as err:
            logger.warning(f"Encountered exception while importing {package}: {err}")
            # Some packages can fail with an ImportError because of a dependency issue.
            # This check avoids hiding such errors.
            # See https://github.com/huggingface/transformers/issues/33604
            if "No module named" in str(err):
                missing.append(package)
            else:
                raise

    if missing:
        raise ImportError(
            "This modeling file requires the following packages that were not found in your environment: "
            f"{', '.join(missing)}. Run `pip install {' '.join(missing)}`"
        )

    return get_relative_imports(filename)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def get_class_in_module(
    class_name: str,
    module_path: Union[str, os.PathLike],
    *,
    force_reload: bool = False,
) -> typing.Type:
    """
    Import a module on the cache directory for modules and extract a class from it.

    Args:
        class_name (`str`): The name of the class to import.
        module_path (`str` or `os.PathLike`): The path to the module to import.
        force_reload (`bool`, *optional*, defaults to `False`):
            Whether to reload the dynamic module from file if it already exists in `sys.modules`.
            Otherwise, the module is only reloaded if the file has changed.

    Returns:
        `typing.Type`: The class looked for.
    """
    # Turn the relative file path into a dotted module name (strip the `.py` suffix first).
    name = os.path.normpath(module_path)
    if name.endswith(".py"):
        name = name[:-3]
    name = name.replace(os.path.sep, ".")
    module_file: Path = Path(HF_MODULES_CACHE) / module_path
    # Serialize the whole load so concurrent threads don't race on sys.modules for the same module.
    with _HF_REMOTE_CODE_LOCK:
        if force_reload:
            sys.modules.pop(name, None)
            importlib.invalidate_caches()
        cached_module: Optional[ModuleType] = sys.modules.get(name)
        module_spec = importlib.util.spec_from_file_location(name, location=module_file)

        # Hash the module file and all its relative imports to check if we need to reload it
        module_files: List[Path] = [module_file] + sorted(map(Path, get_relative_import_files(module_file)))
        # `bytes(f)` is the path itself, so the digest covers both the file names and their contents.
        module_hash: str = hashlib.sha256(b"".join(bytes(f) + f.read_bytes() for f in module_files)).hexdigest()

        module: ModuleType
        if cached_module is None:
            module = importlib.util.module_from_spec(module_spec)
            # insert it into sys.modules before any loading begins
            sys.modules[name] = module
        else:
            module = cached_module
        # reload in both cases, unless the module is already imported and the hash hits
        if getattr(module, "__transformers_module_hash__", "") != module_hash:
            module_spec.loader.exec_module(module)
            module.__transformers_module_hash__ = module_hash
        return getattr(module, class_name)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def get_cached_module_file(
    pretrained_model_name_or_path: Union[str, os.PathLike],
    module_file: str,
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: Optional[bool] = None,
    proxies: Optional[Dict[str, str]] = None,
    token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    repo_type: Optional[str] = None,
    _commit_hash: Optional[str] = None,
    **deprecated_kwargs,
) -> str:
    """
    Prepares and downloads a module from a local folder or a distant repo and returns its path inside the cached
    Transformers module.

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
              huggingface.co.
            - a path to a *directory* containing a configuration file saved using the
              [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.

        module_file (`str`):
            The name of the module file containing the class to look for.
        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force to (re-)download the configuration files and override the cached versions if they
            exist.
        resume_download:
            Deprecated and ignored. All downloads are now resumed by default when possible.
            Will be removed in v5 of Transformers.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `huggingface-cli login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the tokenizer configuration from local files.
        repo_type (`str`, *optional*):
            Specify the repo type (useful when downloading from a space for instance).

    <Tip>

    Passing `token=True` is required when you want to use a private model.

    </Tip>

    Returns:
        `str`: The path to the module inside the cache.
    """
    # Handle the deprecated `use_auth_token` alias for `token`.
    use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
    if use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
            FutureWarning,
        )
        if token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        token = use_auth_token

    if is_offline_mode() and not local_files_only:
        logger.info("Offline mode: forcing local_files_only=True")
        local_files_only = True

    # Download and cache module_file from the repo `pretrained_model_name_or_path` or grab it if it's a local file.
    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
    is_local = os.path.isdir(pretrained_model_name_or_path)
    if is_local:
        submodule = os.path.basename(pretrained_model_name_or_path)
    else:
        submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
        # Look up what is already in the hub cache so we can tell later whether a new file was fetched.
        cached_module = try_to_load_from_cache(
            pretrained_model_name_or_path, module_file, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type
        )

    # Tracks files freshly downloaded in this call, to warn the user about unpinned remote code below.
    new_files = []
    try:
        # Load from URL or cache if already cached
        resolved_module_file = cached_file(
            pretrained_model_name_or_path,
            module_file,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            repo_type=repo_type,
            _commit_hash=_commit_hash,
        )
        if not is_local and cached_module != resolved_module_file:
            new_files.append(module_file)

    except EnvironmentError:
        logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
        raise

    # Check we have all the requirements in our environment
    modules_needed = check_imports(resolved_module_file)

    # Now we move the module inside our cached dynamic modules.
    full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
    create_dynamic_module(full_submodule)
    submodule_path = Path(HF_MODULES_CACHE) / full_submodule
    if submodule == os.path.basename(pretrained_model_name_or_path):
        # Local-folder case (see `is_local` above): copy directly, no commit hash involved.
        # We copy local files to avoid putting too many folders in sys.path. This copy is done when the file is new or
        # has changed since last copy.
        if not (submodule_path / module_file).exists() or not filecmp.cmp(
            resolved_module_file, str(submodule_path / module_file)
        ):
            shutil.copy(resolved_module_file, submodule_path / module_file)
            importlib.invalidate_caches()
        for module_needed in modules_needed:
            module_needed = f"{module_needed}.py"
            module_needed_file = os.path.join(pretrained_model_name_or_path, module_needed)
            if not (submodule_path / module_needed).exists() or not filecmp.cmp(
                module_needed_file, str(submodule_path / module_needed)
            ):
                shutil.copy(module_needed_file, submodule_path / module_needed)
                importlib.invalidate_caches()
    else:
        # Get the commit hash
        commit_hash = extract_commit_hash(resolved_module_file, _commit_hash)

        # The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the
        # benefit of versioning.
        submodule_path = submodule_path / commit_hash
        full_submodule = full_submodule + os.path.sep + commit_hash
        create_dynamic_module(full_submodule)

        if not (submodule_path / module_file).exists():
            shutil.copy(resolved_module_file, submodule_path / module_file)
            importlib.invalidate_caches()
        # Make sure we also have every file with relative imports, fetched recursively.
        for module_needed in modules_needed:
            if not (submodule_path / f"{module_needed}.py").exists():
                get_cached_module_file(
                    pretrained_model_name_or_path,
                    f"{module_needed}.py",
                    cache_dir=cache_dir,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                    token=token,
                    revision=revision,
                    local_files_only=local_files_only,
                    _commit_hash=commit_hash,
                )
                new_files.append(f"{module_needed}.py")

    # Warn about freshly downloaded remote code when no revision pin was given (security nudge).
    if len(new_files) > 0 and revision is None:
        new_files = "\n".join([f"- {f}" for f in new_files])
        repo_type_str = "" if repo_type is None else f"{repo_type}s/"
        url = f"https://huggingface.co/{repo_type_str}{pretrained_model_name_or_path}"
        logger.warning(
            f"A new version of the following files was downloaded from {url}:\n{new_files}"
            "\n. Make sure to double-check they do not contain any added malicious code. To avoid downloading new "
            "versions of the code file, you can pin a revision."
        )

    return os.path.join(full_submodule, module_file)
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
def get_class_from_dynamic_module(
    class_reference: str,
    pretrained_model_name_or_path: Union[str, os.PathLike],
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: Optional[bool] = None,
    proxies: Optional[Dict[str, str]] = None,
    token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    repo_type: Optional[str] = None,
    code_revision: Optional[str] = None,
    **kwargs,
) -> typing.Type:
    """
    Extract a class from a module file, present in the local folder or in a repository on the Hub.

    <Tip warning={true}>

    Calling this function executes the code of the resolved module file, found locally or downloaded from the Hub,
    so it should only be used with trusted repositories.

    </Tip>

    Args:
        class_reference (`str`):
            The full name of the class to load, as `"module.ClassName"`, optionally prefixed by a repo id followed
            by `--` (e.g. `"user/repo--module.ClassName"`).
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            Either the *model id* of a model hosted on huggingface.co, or the path to a *directory* containing the
            saved configuration (e.g. `./my_model_directory/`). Used when `class_reference` does not specify
            another repo.
        cache_dir (`str` or `os.PathLike`, *optional*):
            Directory in which to cache the downloaded configuration, if the standard cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force a (re-)download of the configuration files, overriding any cached versions.
        resume_download:
            Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5
            of Transformers.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        token (`str` or `bool`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `huggingface-cli login` (stored in `~/.huggingface`). Passing `token=True` is required
            when you want to use a private model.
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use: a branch name, a tag name, or a commit id (any identifier allowed
            by git, since models are stored with a git-based system on huggingface.co).
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the configuration from local files.
        repo_type (`str`, *optional*):
            Specify the repo type (useful when downloading from a space for instance).
        code_revision (`str`, *optional*, defaults to `"main"`):
            The specific revision to use for the code on the Hub, if the code lives in a different repository than
            the rest of the model. Any identifier allowed by git.

    Returns:
        `typing.Type`: The class, dynamically imported from the module.

    Examples:

    ```python
    # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this
    # module.
    cls = get_class_from_dynamic_module("modeling.MyBertModel", "sgugger/my-bert-model")

    # Download module `modeling.py` from a given repo and cache then extract the class `MyBertModel` from this
    # module.
    cls = get_class_from_dynamic_module("sgugger/my-bert-model--modeling.MyBertModel", "sgugger/another-bert-model")
    ```"""
    # Honor the deprecated `use_auth_token` kwarg, but refuse ambiguous calls that set both.
    deprecated_token = kwargs.pop("use_auth_token", None)
    if deprecated_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
            FutureWarning,
        )
        if token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        token = deprecated_token

    # A `--` separator means the code lives in a different repo than the model itself.
    if "--" in class_reference:
        repo_id, class_reference = class_reference.split("--")
    else:
        repo_id = pretrained_model_name_or_path
    module_file, class_name = class_reference.split(".")

    # When the code comes from the same repo as the model, default the code revision to the model revision.
    if code_revision is None and pretrained_model_name_or_path == repo_id:
        code_revision = revision

    # Materialize the module file in the dynamic-module cache, then import the class from it.
    cached_module = get_cached_module_file(
        repo_id,
        module_file + ".py",
        cache_dir=cache_dir,
        force_download=force_download,
        resume_download=resume_download,
        proxies=proxies,
        token=token,
        revision=code_revision,
        local_files_only=local_files_only,
        repo_type=repo_type,
    )
    return get_class_in_module(class_name, cached_module, force_reload=force_download)
+
def custom_object_save(obj: Any, folder: Union[str, os.PathLike], config: Optional[Dict] = None) -> List[str]:
    """
    Save the modeling files corresponding to a custom model/configuration/tokenizer etc. in a given folder. Optionally
    adds the proper fields in a config.

    Args:
        obj (`Any`): The object for which to save the module files.
        folder (`str` or `os.PathLike`): The folder where to save.
        config (`PretrainedConfig` or dictionary, `optional`):
            A config in which to register the auto_map corresponding to this custom object.

    Returns:
        `List[str]`: The list of files saved. Empty when `obj` is defined in `__main__` and thus cannot be saved.
    """
    if obj.__module__ == "__main__":
        logger.warning(
            f"We can't save the code defining {obj} in {folder} as it's been defined in __main__. You should put "
            "this code in a separate module so we can include it in the saved folder and make it easier to share via "
            "the Hub."
        )
        # Return an empty list (instead of a bare `return` / None) so the documented `List[str]` contract holds
        # and callers iterating over the result keep working.
        return []

    def _set_auto_map_in_config(_config):
        # Register `obj` under its auto class in the config's `auto_map`, as "module.ClassName".
        module_name = obj.__class__.__module__
        last_module = module_name.split(".")[-1]
        full_name = f"{last_module}.{obj.__class__.__name__}"
        # Special handling for tokenizers: the auto_map entry is a (slow_class, fast_class) pair.
        if "Tokenizer" in full_name:
            slow_tokenizer_class = None
            fast_tokenizer_class = None
            if obj.__class__.__name__.endswith("Fast"):
                # Fast tokenizer: we have the fast tokenizer class and we may have the slow one as an attribute.
                fast_tokenizer_class = f"{last_module}.{obj.__class__.__name__}"
                if getattr(obj, "slow_tokenizer_class", None) is not None:
                    slow_tokenizer = getattr(obj, "slow_tokenizer_class")
                    slow_tok_module_name = slow_tokenizer.__module__
                    last_slow_tok_module = slow_tok_module_name.split(".")[-1]
                    slow_tokenizer_class = f"{last_slow_tok_module}.{slow_tokenizer.__name__}"
            else:
                # Slow tokenizer: no way to have the fast class
                slow_tokenizer_class = f"{last_module}.{obj.__class__.__name__}"

            full_name = (slow_tokenizer_class, fast_tokenizer_class)

        if isinstance(_config, dict):
            auto_map = _config.get("auto_map", {})
            auto_map[obj._auto_class] = full_name
            _config["auto_map"] = auto_map
        elif getattr(_config, "auto_map", None) is not None:
            _config.auto_map[obj._auto_class] = full_name
        else:
            _config.auto_map = {obj._auto_class: full_name}

    # Add object class to the config auto_map (`config` may be a single config or a list/tuple of them).
    if isinstance(config, (list, tuple)):
        for cfg in config:
            _set_auto_map_in_config(cfg)
    elif config is not None:
        _set_auto_map_in_config(config)

    result = []
    # Copy module file to the output folder.
    object_file = sys.modules[obj.__module__].__file__
    dest_file = Path(folder) / (Path(object_file).name)
    shutil.copy(object_file, dest_file)
    # NOTE(review): entries are `Path` objects although the docstring says `List[str]` — confirm callers accept both.
    result.append(dest_file)

    # Gather all relative imports recursively and make sure they are copied as well.
    for needed_file in get_relative_import_files(object_file):
        dest_file = Path(folder) / (Path(needed_file).name)
        shutil.copy(needed_file, dest_file)
        result.append(dest_file)

    return result
+
|
| 632 |
+
def _raise_timeout_error(signum, frame):
|
| 633 |
+
raise ValueError(
|
| 634 |
+
"Loading this model requires you to execute custom code contained in the model repository on your local "
|
| 635 |
+
"machine. Please set the option `trust_remote_code=True` to permit loading of this model."
|
| 636 |
+
)
|
| 637 |
+
|
| 638 |
+
|
| 639 |
+
# Number of seconds to wait for an interactive answer to the trust-remote-code
# prompt before the SIGALRM timeout fires (set to 0, e.g. in CI, the prompt is
# skipped entirely and the timeout error is raised immediately).
TIME_OUT_REMOTE_CODE = 15
+
def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code):
    # Decide whether remote (Hub-hosted) custom code may be executed for `model_name`.
    # When the caller gave no explicit answer, fall back to local code if available;
    # otherwise ask interactively, bounded by a SIGALRM-based timeout.
    if trust_remote_code is None:
        if has_local_code:
            trust_remote_code = False
        elif has_remote_code and TIME_OUT_REMOTE_CODE > 0:
            previous_handler = None
            try:
                previous_handler = signal.signal(signal.SIGALRM, _raise_timeout_error)
                signal.alarm(TIME_OUT_REMOTE_CODE)
                # Keep prompting until the user gives a recognizable yes/no answer (or the alarm fires).
                while trust_remote_code is None:
                    answer = input(
                        f"The repository for {model_name} contains custom code which must be executed to correctly "
                        f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
                        f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
                        f"Do you wish to run the custom code? [y/N] "
                    ).lower()
                    if answer in ("yes", "y", "1"):
                        trust_remote_code = True
                    elif answer in ("no", "n", "0", ""):
                        trust_remote_code = False
                signal.alarm(0)
            except Exception:
                # Either the prompt timed out, or we are on an OS which does not support signal.SIGALRM.
                raise ValueError(
                    f"The repository for {model_name} contains custom code which must be executed to correctly "
                    f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
                    f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
                )
            finally:
                # Restore whatever SIGALRM handler was installed before us and cancel any pending alarm.
                if previous_handler is not None:
                    signal.signal(signal.SIGALRM, previous_handler)
                    signal.alarm(0)
        elif has_remote_code:
            # For the CI which puts the timeout at 0
            _raise_timeout_error(None, None)

    if has_remote_code and not has_local_code and not trust_remote_code:
        raise ValueError(
            f"Loading {model_name} requires you to execute the configuration file in that"
            " repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
            " set the option `trust_remote_code=True` to remove this error."
        )

    return trust_remote_code
|
.venv/lib/python3.11/site-packages/transformers/feature_extraction_sequence_utils.py
ADDED
|
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2021 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""
|
| 16 |
+
Sequence feature extraction class for common feature extractors to preprocess sequences.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from typing import Dict, List, Optional, Union
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
|
| 23 |
+
from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin
|
| 24 |
+
from .utils import PaddingStrategy, TensorType, is_tf_tensor, is_torch_tensor, logging, to_numpy
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
logger = logging.get_logger(__name__)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class SequenceFeatureExtractor(FeatureExtractionMixin):
    """
    This is a general feature extraction class for speech recognition.

    Args:
        feature_size (`int`):
            The feature dimension of the extracted features.
        sampling_rate (`int`):
            The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
        padding_value (`float`):
            The value that is used to fill the padding values / vectors.
    """

    def __init__(self, feature_size: int, sampling_rate: int, padding_value: float, **kwargs):
        self.feature_size = feature_size
        self.sampling_rate = sampling_rate
        self.padding_value = padding_value

        # Side on which sequences are padded ("right" or "left") and whether attention masks are returned by default.
        self.padding_side = kwargs.pop("padding_side", "right")
        self.return_attention_mask = kwargs.pop("return_attention_mask", True)

        super().__init__(**kwargs)

    def pad(
        self,
        processed_features: Union[
            BatchFeature,
            List[BatchFeature],
            Dict[str, BatchFeature],
            Dict[str, List[BatchFeature]],
            List[Dict[str, BatchFeature]],
        ],
        padding: Union[bool, str, PaddingStrategy] = True,
        max_length: Optional[int] = None,
        truncation: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> BatchFeature:
        """
        Pad input values / input vectors or a batch of input values / input vectors up to predefined length or to the
        max sequence length in the batch.

        Padding side (left/right) padding values are defined at the feature extractor level (with `self.padding_side`,
        `self.padding_value`)

        <Tip>

        If the `processed_features` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the
        result will use the same type unless you provide a different tensor type with `return_tensors`. In the case of
        PyTorch tensors, you will lose the specific device of your tensors however.

        </Tip>

        Args:
            processed_features ([`BatchFeature`], list of [`BatchFeature`], `Dict[str, List[float]]`, `Dict[str, List[List[float]]` or `List[Dict[str, List[float]]]`):
                Processed inputs. Can represent one input ([`BatchFeature`] or `Dict[str, List[float]]`) or a batch of
                input values / vectors (list of [`BatchFeature`], *Dict[str, List[List[float]]]* or *List[Dict[str,
                List[float]]]*) so you can use this method during preprocessing as well as in a PyTorch Dataloader
                collate function.

                Instead of `List[float]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors),
                see the note above for the return type.
            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
                Select a strategy to pad the returned sequences (according to the model's padding side and padding
                index) among:

                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
                  sequence if provided).
                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
                  acceptable input length for the model if that argument is not provided.
                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
                  lengths).
            max_length (`int`, *optional*):
                Maximum length of the returned list and optionally padding length (see above).
            truncation (`bool`):
                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
            pad_to_multiple_of (`int`, *optional*):
                If set will pad the sequence to a multiple of the provided value.

                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
                `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
            return_attention_mask (`bool`, *optional*):
                Whether to return the attention mask. If left to the default, will return the attention mask according
                to the specific feature_extractor's default.

                [What are attention masks?](../glossary#attention-mask)
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors instead of list of python integers. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return Numpy `np.ndarray` objects.
        """
        # If we have a list of dicts, let's convert it in a dict of lists
        # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
        if isinstance(processed_features, (list, tuple)) and isinstance(processed_features[0], (dict, BatchFeature)):
            processed_features = {
                key: [example[key] for example in processed_features] for key in processed_features[0].keys()
            }

        # The model's main input name, usually `input_values`, has be passed for padding
        if self.model_input_names[0] not in processed_features:
            raise ValueError(
                "You should supply an instance of `transformers.BatchFeature` or list of `transformers.BatchFeature`"
                f" to this method that includes {self.model_input_names[0]}, but you provided"
                f" {list(processed_features.keys())}"
            )

        required_input = processed_features[self.model_input_names[0]]
        return_attention_mask = (
            return_attention_mask if return_attention_mask is not None else self.return_attention_mask
        )

        if len(required_input) == 0:
            if return_attention_mask:
                processed_features["attention_mask"] = []
            return processed_features

        # If we have PyTorch/TF tensors or lists as inputs, we cast them as Numpy arrays
        # and rebuild them afterwards if no return_tensors is specified
        # Note that we lose the specific device the tensor may be on for PyTorch

        first_element = required_input[0]
        if isinstance(first_element, (list, tuple)):
            # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element.
            index = 0
            # Bound the scan so a batch made entirely of empty sequences falls through to the `index < len(...)`
            # guard below instead of raising an IndexError here.
            while index < len(required_input) and len(required_input[index]) == 0:
                index += 1
            if index < len(required_input):
                first_element = required_input[index][0]

        if return_tensors is None:
            if is_tf_tensor(first_element):
                return_tensors = "tf"
            elif is_torch_tensor(first_element):
                return_tensors = "pt"
            elif isinstance(first_element, (int, float, list, tuple, np.ndarray)):
                return_tensors = "np"
            else:
                raise ValueError(
                    f"type of {first_element} unknown: {type(first_element)}. "
                    "Should be one of a python, numpy, pytorch or tensorflow object."
                )

        for key, value in processed_features.items():
            if isinstance(value[0], (int, float)):
                processed_features[key] = to_numpy(value)
            else:
                processed_features[key] = [to_numpy(v) for v in value]

        # Convert padding_strategy in PaddingStrategy
        padding_strategy = self._get_padding_strategies(padding=padding, max_length=max_length)

        required_input = processed_features[self.model_input_names[0]]

        batch_size = len(required_input)
        if not all(len(v) == batch_size for v in processed_features.values()):
            raise ValueError("Some items in the output dictionary have a different batch size than others.")

        truncated_inputs = []
        for i in range(batch_size):
            inputs = {k: v[i] for k, v in processed_features.items()}
            # truncation
            inputs_slice = self._truncate(
                inputs,
                max_length=max_length,
                pad_to_multiple_of=pad_to_multiple_of,
                truncation=truncation,
            )
            truncated_inputs.append(inputs_slice)

        if padding_strategy == PaddingStrategy.LONGEST:
            # make sure that `max_length` cannot be longer than the longest truncated length
            max_length = max(len(input_slice[self.model_input_names[0]]) for input_slice in truncated_inputs)
            padding_strategy = PaddingStrategy.MAX_LENGTH

        batch_outputs = {}
        for i in range(batch_size):
            # padding
            outputs = self._pad(
                truncated_inputs[i],
                max_length=max_length,
                padding_strategy=padding_strategy,
                pad_to_multiple_of=pad_to_multiple_of,
                return_attention_mask=return_attention_mask,
            )

            for key, value in outputs.items():
                if key not in batch_outputs:
                    batch_outputs[key] = []
                # Downcast float64 arrays to float32 before batching (the common dtype for model inputs).
                if value.dtype is np.dtype(np.float64):
                    value = value.astype(np.float32)
                batch_outputs[key].append(value)

        return BatchFeature(batch_outputs, tensor_type=return_tensors)

    def _pad(
        self,
        processed_features: Union[Dict[str, np.ndarray], BatchFeature],
        max_length: Optional[int] = None,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
    ) -> dict:
        """
        Pad inputs (on left/right and up to predefined length or max length in the batch)

        Args:
            processed_features (`Union[Dict[str, np.ndarray], BatchFeature]`):
                Dictionary of input values (`np.ndarray[float]`) / input vectors (`List[np.ndarray[float]]`) or batch
                of inputs values (`List[np.ndarray[int]]`) / input vectors (`List[np.ndarray[int]]`)
            max_length (`int`, *optional*):
                Maximum length of the returned list and optionally padding length (see below)
            padding_strategy (`PaddingStrategy`, *optional*, default to `PaddingStrategy.DO_NOT_PAD`):
                PaddingStrategy to use for padding.

                - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
                - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
                - PaddingStrategy.DO_NOT_PAD: Do not pad
                The feature_extractor padding sides are defined in self.padding_side:

                - 'left': pads on the left of the sequences
                - 'right': pads on the right of the sequences
            pad_to_multiple_of (`int`, *optional*):
                Integer if set will pad the sequence to a multiple of the provided value. This is especially useful to
                enable the use of Tensor Core on NVIDIA hardware with compute capability `>= 7.5` (Volta), or on TPUs
                which benefit from having sequence lengths be a multiple of 128.
            return_attention_mask (`bool`, *optional*):
                Set to False to avoid returning attention mask (default: set to model specifics)
        """
        required_input = processed_features[self.model_input_names[0]]

        if padding_strategy == PaddingStrategy.LONGEST:
            max_length = len(required_input)

        # Round `max_length` up to the next multiple of `pad_to_multiple_of` when needed.
        if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of

        needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) < max_length

        if return_attention_mask and "attention_mask" not in processed_features:
            processed_features["attention_mask"] = np.ones(len(required_input), dtype=np.int32)

        if needs_to_be_padded:
            difference = max_length - len(required_input)
            if self.padding_side == "right":
                if return_attention_mask:
                    processed_features["attention_mask"] = np.pad(
                        processed_features["attention_mask"], (0, difference)
                    )
                # 2-D padding spec when features are vectors, 1-D spec for scalar-valued sequences.
                padding_shape = ((0, difference), (0, 0)) if self.feature_size > 1 else (0, difference)
                processed_features[self.model_input_names[0]] = np.pad(
                    required_input, padding_shape, "constant", constant_values=self.padding_value
                )
            elif self.padding_side == "left":
                if return_attention_mask:
                    processed_features["attention_mask"] = np.pad(
                        processed_features["attention_mask"], (difference, 0)
                    )
                padding_shape = ((difference, 0), (0, 0)) if self.feature_size > 1 else (difference, 0)
                processed_features[self.model_input_names[0]] = np.pad(
                    required_input, padding_shape, "constant", constant_values=self.padding_value
                )
            else:
                raise ValueError("Invalid padding strategy:" + str(self.padding_side))

        return processed_features

    def _truncate(
        self,
        processed_features: Union[Dict[str, np.ndarray], BatchFeature],
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        truncation: Optional[bool] = None,
    ):
        """
        Truncate inputs to predefined length or max length in the batch

        Args:
            processed_features(`Union[Dict[str, np.ndarray], BatchFeature]`):
                Dictionary of input values (`np.ndarray[float]`) / input vectors (`List[np.ndarray[float]]`) or batch
                of inputs values (`List[np.ndarray[int]]`) / input vectors (`List[np.ndarray[int]]`)
            max_length (`int`, *optional*):
                maximum length of the returned list and optionally padding length (see below)
            pad_to_multiple_of (`int`, *optional*) :
                Integer if set will pad the sequence to a multiple of the provided value. This is especially useful to
                enable the use of Tensor Core on NVIDIA hardware with compute capability `>= 7.5` (Volta), or on TPUs
                which benefit from having sequence lengths be a multiple of 128.
            truncation (`bool`, *optional*):
                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
        """
        if not truncation:
            return processed_features
        elif truncation and max_length is None:
            raise ValueError("When setting ``truncation=True``, make sure that ``max_length`` is defined.")

        required_input = processed_features[self.model_input_names[0]]

        # find `max_length` that fits `pad_to_multiple_of`
        if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of

        needs_to_be_truncated = len(required_input) > max_length

        if needs_to_be_truncated:
            processed_features[self.model_input_names[0]] = processed_features[self.model_input_names[0]][:max_length]
            if "attention_mask" in processed_features:
                processed_features["attention_mask"] = processed_features["attention_mask"][:max_length]

        return processed_features

    def _get_padding_strategies(self, padding=False, max_length=None):
        """
        Find the correct padding strategy
        """

        # Get padding strategy
        if padding is not False:
            if padding is True:
                padding_strategy = PaddingStrategy.LONGEST  # Default to pad to the longest sequence in the batch
            elif not isinstance(padding, PaddingStrategy):
                padding_strategy = PaddingStrategy(padding)
            else:
                # `padding` is already a PaddingStrategy member.
                padding_strategy = padding
        else:
            padding_strategy = PaddingStrategy.DO_NOT_PAD

        # Set max length if needed
        if max_length is None:
            if padding_strategy == PaddingStrategy.MAX_LENGTH:
                raise ValueError(
                    f"When setting ``padding={PaddingStrategy.MAX_LENGTH}``, make sure that max_length is defined"
                )

        # Test if we have a padding value
        if padding_strategy != PaddingStrategy.DO_NOT_PAD and (self.padding_value is None):
            raise ValueError(
                "Asking to pad but the feature_extractor does not have a padding value. Please select a value to use"
                " as `padding_value`. For example: `feature_extractor.padding_value = 0.0`."
            )

        return padding_strategy
|
.venv/lib/python3.11/site-packages/transformers/hf_argparser.py
ADDED
|
@@ -0,0 +1,437 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import dataclasses
|
| 16 |
+
import json
|
| 17 |
+
import os
|
| 18 |
+
import sys
|
| 19 |
+
import types
|
| 20 |
+
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, ArgumentTypeError
|
| 21 |
+
from copy import copy
|
| 22 |
+
from enum import Enum
|
| 23 |
+
from inspect import isclass
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
from typing import Any, Callable, Dict, Iterable, List, Literal, NewType, Optional, Tuple, Union, get_type_hints
|
| 26 |
+
|
| 27 |
+
import yaml
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# Opaque aliases used purely for readability in the signatures below: a
# `DataClass` is an *instance* of a dataclass, a `DataClassType` is the class.
DataClass = NewType("DataClass", Any)
DataClassType = NewType("DataClassType", Any)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# From https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
|
| 35 |
+
def string_to_bool(v):
    """
    Interpret a command-line string as a boolean.

    Accepts an actual `bool` (returned unchanged) or a case-insensitive truthy/falsy
    token; anything else raises `argparse.ArgumentTypeError` so argparse reports a
    clean usage error instead of a traceback.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in {"yes", "true", "t", "y", "1"}:
        return True
    if lowered in {"no", "false", "f", "n", "0"}:
        return False
    raise ArgumentTypeError(
        f"Truthy value expected: got {v} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)."
    )
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def make_choice_type_function(choices: list) -> Callable[[str], Any]:
    """
    Build a converter mapping each choice's string form back to the original value.

    This lets a single argparse argument accept choices of mixed types: the
    command-line token is matched against `str(choice)` and the original object is
    returned. Unknown tokens pass through unchanged so argparse's own `choices`
    validation can reject them.

    Args:
        choices (list): List of choices.

    Returns:
        Callable[[str], Any]: Mapping function from string representation to actual value for each choice.
    """
    by_repr = {str(choice): choice for choice in choices}

    def convert(arg: str) -> Any:
        return by_repr.get(arg, arg)

    return convert
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def HfArg(
    *,
    aliases: Union[str, List[str]] = None,
    help: str = None,
    default: Any = dataclasses.MISSING,
    default_factory: Callable[[], Any] = dataclasses.MISSING,
    metadata: dict = None,
    **kwargs,
) -> dataclasses.Field:
    """Concise helper to declare dataclass fields consumed by `HfArgumentParser`.

    Instead of hand-writing `dataclasses.field(metadata={"aliases": ..., "help": ...})`,
    the argparse-specific bits (`aliases`, `help`) are top-level keyword arguments and
    are folded into the field metadata here.

    Args:
        aliases (Union[str, List[str]], optional):
            Single string or list of strings of aliases to pass on to argparse, e.g. `aliases=["--example", "-e"]`.
            Defaults to None.
        help (str, optional): Help string to pass on to argparse that can be displayed with --help. Defaults to None.
        default (Any, optional):
            Default value for the argument. If neither default nor default_factory is specified, the argument is
            required. Defaults to dataclasses.MISSING.
        default_factory (Callable[[], Any], optional):
            Zero-argument callable producing the field's default; use it for mutable defaults, e.g.
            `default_factory=list`. Mutually exclusive with `default=`. Defaults to dataclasses.MISSING.
        metadata (dict, optional): Further metadata to pass on to `dataclasses.field`. Defaults to None.

    Returns:
        Field: A `dataclasses.Field` with the desired properties.
    """
    if metadata is None:
        # A mutable default parameter would be shared across calls, so the empty
        # dict is created per-call instead.
        metadata = {}
    # Fold the convenience keywords into the metadata dict (in place, matching the
    # historical behavior when the caller supplies `metadata`).
    for key, value in (("aliases", aliases), ("help", help)):
        if value is not None:
            metadata[key] = value

    return dataclasses.field(metadata=metadata, default=default, default_factory=default_factory, **kwargs)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class HfArgumentParser(ArgumentParser):
    """
    This subclass of `argparse.ArgumentParser` uses type hints on dataclasses to generate arguments.

    The class is designed to play well with the native argparse. In particular, you can add more (non-dataclass backed)
    arguments to the parser after initialization and you'll get the output back after parsing as an additional
    namespace. Optional: To create sub argument groups use the `_argument_group_name` attribute in the dataclass.
    """

    # The dataclass types registered at construction time, in order.
    dataclass_types: Iterable[DataClassType]

    def __init__(self, dataclass_types: Union[DataClassType, Iterable[DataClassType]], **kwargs):
        """
        Args:
            dataclass_types:
                Dataclass type, or list of dataclass types for which we will "fill" instances with the parsed args.
            kwargs (`Dict[str, Any]`, *optional*):
                Passed to `argparse.ArgumentParser()` in the regular way.
        """
        # To make the default appear when using --help
        if "formatter_class" not in kwargs:
            kwargs["formatter_class"] = ArgumentDefaultsHelpFormatter
        super().__init__(**kwargs)
        # Accept a single dataclass as well as an iterable of them.
        if dataclasses.is_dataclass(dataclass_types):
            dataclass_types = [dataclass_types]
        self.dataclass_types = list(dataclass_types)
        for dtype in self.dataclass_types:
            self._add_dataclass_arguments(dtype)

    @staticmethod
    def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field):
        """Register one dataclass field as an argparse argument on `parser`."""
        # Long-option strings are conventionally separated by hyphens rather
        # than underscores, e.g., "--long-format" rather than "--long_format".
        # Argparse converts hyphens to underscores so that the destination
        # string is a valid attribute name. Hf_argparser should do the same.
        long_options = [f"--{field.name}"]
        if "_" in field.name:
            long_options.append(f"--{field.name.replace('_', '-')}")

        kwargs = field.metadata.copy()
        # field.metadata is not used at all by Data Classes,
        # it is provided as a third-party extension mechanism.
        if isinstance(field.type, str):
            # A string type means PEP 563 lazy annotations were not resolved by
            # `_add_dataclass_arguments` (via `typing.get_type_hints`).
            raise RuntimeError(
                "Unresolved type detected, which should have been done with the help of "
                "`typing.get_type_hints` method by default"
            )

        aliases = kwargs.pop("aliases", [])
        if isinstance(aliases, str):
            aliases = [aliases]

        origin_type = getattr(field.type, "__origin__", field.type)
        # Handle both `typing.Union` and PEP 604 `X | Y` unions (types.UnionType, 3.10+).
        if origin_type is Union or (hasattr(types, "UnionType") and isinstance(origin_type, types.UnionType)):
            if str not in field.type.__args__ and (
                len(field.type.__args__) != 2 or type(None) not in field.type.__args__
            ):
                raise ValueError(
                    "Only `Union[X, NoneType]` (i.e., `Optional[X]`) is allowed for `Union` because"
                    " the argument parser only supports one type per argument."
                    f" Problem encountered in field '{field.name}'."
                )
            if type(None) not in field.type.__args__:
                # filter `str` in Union
                field.type = field.type.__args__[0] if field.type.__args__[1] is str else field.type.__args__[1]
                origin_type = getattr(field.type, "__origin__", field.type)
            elif bool not in field.type.__args__:
                # filter `NoneType` in Union (except for `Union[bool, NoneType]`)
                field.type = (
                    field.type.__args__[0] if isinstance(None, field.type.__args__[1]) else field.type.__args__[1]
                )
                origin_type = getattr(field.type, "__origin__", field.type)

        # A variable to store kwargs for a boolean field, if needed
        # so that we can init a `no_*` complement argument (see below)
        bool_kwargs = {}
        if origin_type is Literal or (isinstance(field.type, type) and issubclass(field.type, Enum)):
            # Closed choice sets: Literal values directly, or Enum member values.
            if origin_type is Literal:
                kwargs["choices"] = field.type.__args__
            else:
                kwargs["choices"] = [x.value for x in field.type]

            kwargs["type"] = make_choice_type_function(kwargs["choices"])

            if field.default is not dataclasses.MISSING:
                kwargs["default"] = field.default
            else:
                kwargs["required"] = True
        elif field.type is bool or field.type == Optional[bool]:
            # Copy the current kwargs to use to instantiate a `no_*` complement argument below.
            # We do not initialize it here because the `no_*` alternative must be instantiated after the real argument
            bool_kwargs = copy(kwargs)

            # Hack because type=bool in argparse does not behave as we want.
            kwargs["type"] = string_to_bool
            if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
                # Default value is False if we have no default when of type bool.
                default = False if field.default is dataclasses.MISSING else field.default
                # This is the value that will get picked if we don't include --{field.name} in any way
                kwargs["default"] = default
                # This tells argparse we accept 0 or 1 value after --{field.name}
                kwargs["nargs"] = "?"
                # This is the value that will get picked if we do --{field.name} (without value)
                kwargs["const"] = True
        elif isclass(origin_type) and issubclass(origin_type, list):
            # List fields become repeatable positional values: `--foo 1 2 3`.
            kwargs["type"] = field.type.__args__[0]
            kwargs["nargs"] = "+"
            if field.default_factory is not dataclasses.MISSING:
                kwargs["default"] = field.default_factory()
            elif field.default is dataclasses.MISSING:
                kwargs["required"] = True
        else:
            # Plain scalar field: use the annotated type as the argparse converter.
            kwargs["type"] = field.type
            if field.default is not dataclasses.MISSING:
                kwargs["default"] = field.default
            elif field.default_factory is not dataclasses.MISSING:
                kwargs["default"] = field.default_factory()
            else:
                kwargs["required"] = True
        parser.add_argument(*long_options, *aliases, **kwargs)

        # Add a complement `no_*` argument for a boolean field AFTER the initial field has already been added.
        # Order is important for arguments with the same destination!
        # We use a copy of earlier kwargs because the original kwargs have changed a lot before reaching down
        # here and we do not need those changes/additional keys.
        if field.default is True and (field.type is bool or field.type == Optional[bool]):
            bool_kwargs["default"] = False
            parser.add_argument(
                f"--no_{field.name}",
                f"--no-{field.name.replace('_', '-')}",
                action="store_false",
                dest=field.name,
                **bool_kwargs,
            )

    def _add_dataclass_arguments(self, dtype: DataClassType):
        """Resolve `dtype`'s type hints and register all of its init fields as arguments."""
        if hasattr(dtype, "_argument_group_name"):
            # Group this dataclass's arguments under a named section in --help.
            parser = self.add_argument_group(dtype._argument_group_name)
        else:
            parser = self

        try:
            type_hints: Dict[str, type] = get_type_hints(dtype)
        except NameError:
            raise RuntimeError(
                f"Type resolution failed for {dtype}. Try declaring the class in global scope or "
                "removing line of `from __future__ import annotations` which opts in Postponed "
                "Evaluation of Annotations (PEP 563)"
            )
        except TypeError as ex:
            # Remove this block when we drop Python 3.9 support
            if sys.version_info[:2] < (3, 10) and "unsupported operand type(s) for |" in str(ex):
                python_version = ".".join(map(str, sys.version_info[:3]))
                raise RuntimeError(
                    f"Type resolution failed for {dtype} on Python {python_version}. Try removing "
                    "line of `from __future__ import annotations` which opts in union types as "
                    "`X | Y` (PEP 604) via Postponed Evaluation of Annotations (PEP 563). To "
                    "support Python versions that lower than 3.10, you need to use "
                    "`typing.Union[X, Y]` instead of `X | Y` and `typing.Optional[X]` instead of "
                    "`X | None`."
                ) from ex
            raise

        for field in dataclasses.fields(dtype):
            if not field.init:
                # Fields excluded from __init__ cannot be filled from the CLI.
                continue
            field.type = type_hints[field.name]
            self._parse_dataclass_field(parser, field)

    def parse_args_into_dataclasses(
        self,
        args=None,
        return_remaining_strings=False,
        look_for_args_file=True,
        args_filename=None,
        args_file_flag=None,
    ) -> Tuple[DataClass, ...]:
        """
        Parse command-line args into instances of the specified dataclass types.

        This relies on argparse's `ArgumentParser.parse_known_args`. See the doc at:
        docs.python.org/3.7/library/argparse.html#argparse.ArgumentParser.parse_args

        Args:
            args:
                List of strings to parse. The default is taken from sys.argv. (same as argparse.ArgumentParser)
            return_remaining_strings:
                If true, also return a list of remaining argument strings.
            look_for_args_file:
                If true, will look for a ".args" file with the same base name as the entry point script for this
                process, and will append its potential content to the command line args.
            args_filename:
                If not None, will uses this file instead of the ".args" file specified in the previous argument.
            args_file_flag:
                If not None, will look for a file in the command-line args specified with this flag. The flag can be
                specified multiple times and precedence is determined by the order (last one wins).

        Returns:
            Tuple consisting of:

                - the dataclass instances in the same order as they were passed to the initializer.
                - if applicable, an additional namespace for more (non-dataclass backed) arguments added to the parser
                  after initialization.
                - The potential list of remaining argument strings. (same as argparse.ArgumentParser.parse_known_args)
        """

        if args_file_flag or args_filename or (look_for_args_file and len(sys.argv)):
            args_files = []

            if args_filename:
                args_files.append(Path(args_filename))
            elif look_for_args_file and len(sys.argv):
                args_files.append(Path(sys.argv[0]).with_suffix(".args"))

            # args files specified via command line flag should overwrite default args files so we add them last
            if args_file_flag:
                # Create special parser just to extract the args_file_flag values
                args_file_parser = ArgumentParser()
                args_file_parser.add_argument(args_file_flag, type=str, action="append")

                # Use only remaining args for further parsing (remove the args_file_flag)
                cfg, args = args_file_parser.parse_known_args(args=args)
                cmd_args_file_paths = vars(cfg).get(args_file_flag.lstrip("-"), None)

                if cmd_args_file_paths:
                    args_files.extend([Path(p) for p in cmd_args_file_paths])

            file_args = []
            for args_file in args_files:
                if args_file.exists():
                    file_args += args_file.read_text().split()

            # in case of duplicate arguments the last one has precedence
            # args specified via the command line should overwrite args from files, so we add them last
            args = file_args + args if args is not None else file_args + sys.argv[1:]
        namespace, remaining_args = self.parse_known_args(args=args)
        outputs = []
        for dtype in self.dataclass_types:
            keys = {f.name for f in dataclasses.fields(dtype) if f.init}
            inputs = {k: v for k, v in vars(namespace).items() if k in keys}
            # Remove consumed attributes so leftovers form the extra namespace below.
            for k in keys:
                delattr(namespace, k)
            obj = dtype(**inputs)
            outputs.append(obj)
        if len(namespace.__dict__) > 0:
            # additional namespace.
            outputs.append(namespace)
        if return_remaining_strings:
            return (*outputs, remaining_args)
        else:
            if remaining_args:
                raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}")

            return (*outputs,)

    def parse_dict(self, args: Dict[str, Any], allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
        """
        Alternative helper method that does not use `argparse` at all, instead uses a dict and populating the dataclass
        types.

        Args:
            args (`dict`):
                dict containing config values
            allow_extra_keys (`bool`, *optional*, defaults to `False`):
                Defaults to False. If False, will raise an exception if the dict contains keys that are not parsed.

        Returns:
            Tuple consisting of:

                - the dataclass instances in the same order as they were passed to the initializer.
        """
        unused_keys = set(args.keys())
        outputs = []
        for dtype in self.dataclass_types:
            keys = {f.name for f in dataclasses.fields(dtype) if f.init}
            inputs = {k: v for k, v in args.items() if k in keys}
            unused_keys.difference_update(inputs.keys())
            obj = dtype(**inputs)
            outputs.append(obj)
        if not allow_extra_keys and unused_keys:
            raise ValueError(f"Some keys are not used by the HfArgumentParser: {sorted(unused_keys)}")
        return tuple(outputs)

    def parse_json_file(
        self, json_file: Union[str, os.PathLike], allow_extra_keys: bool = False
    ) -> Tuple[DataClass, ...]:
        """
        Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the
        dataclass types.

        Args:
            json_file (`str` or `os.PathLike`):
                File name of the json file to parse
            allow_extra_keys (`bool`, *optional*, defaults to `False`):
                Defaults to False. If False, will raise an exception if the json file contains keys that are not
                parsed.

        Returns:
            Tuple consisting of:

                - the dataclass instances in the same order as they were passed to the initializer.
        """
        with open(Path(json_file), encoding="utf-8") as open_json_file:
            data = json.loads(open_json_file.read())
        outputs = self.parse_dict(data, allow_extra_keys=allow_extra_keys)
        return tuple(outputs)

    def parse_yaml_file(
        self, yaml_file: Union[str, os.PathLike], allow_extra_keys: bool = False
    ) -> Tuple[DataClass, ...]:
        """
        Alternative helper method that does not use `argparse` at all, instead loading a yaml file and populating the
        dataclass types.

        Args:
            yaml_file (`str` or `os.PathLike`):
                File name of the yaml file to parse
            allow_extra_keys (`bool`, *optional*, defaults to `False`):
                Defaults to False. If False, will raise an exception if the yaml file contains keys that are not
                parsed.

        Returns:
            Tuple consisting of:

                - the dataclass instances in the same order as they were passed to the initializer.
        """
        outputs = self.parse_dict(yaml.safe_load(Path(yaml_file).read_text()), allow_extra_keys=allow_extra_keys)
        return tuple(outputs)
|
.venv/lib/python3.11/site-packages/transformers/hyperparameter_search.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023-present the HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from .integrations import (
|
| 17 |
+
is_optuna_available,
|
| 18 |
+
is_ray_tune_available,
|
| 19 |
+
is_sigopt_available,
|
| 20 |
+
is_wandb_available,
|
| 21 |
+
run_hp_search_optuna,
|
| 22 |
+
run_hp_search_ray,
|
| 23 |
+
run_hp_search_sigopt,
|
| 24 |
+
run_hp_search_wandb,
|
| 25 |
+
)
|
| 26 |
+
from .trainer_utils import (
|
| 27 |
+
HPSearchBackend,
|
| 28 |
+
default_hp_space_optuna,
|
| 29 |
+
default_hp_space_ray,
|
| 30 |
+
default_hp_space_sigopt,
|
| 31 |
+
default_hp_space_wandb,
|
| 32 |
+
)
|
| 33 |
+
from .utils import logging
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# Module-level logger following the transformers logging convention.
logger = logging.get_logger(__name__)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class HyperParamSearchBackendBase:
    """Common interface implemented by every hyper-parameter search backend.

    Subclasses set `name` (and optionally `pip_package` when the installable
    package name differs) and implement `is_available`, `run`, and
    `default_hp_space`.
    """

    name: str
    # Package name to suggest in install instructions; falls back to `name`.
    pip_package: str = None

    @staticmethod
    def is_available():
        """Return whether the backend's library is importable."""
        raise NotImplementedError

    def run(self, trainer, n_trials: int, direction: str, **kwargs):
        """Run the hyper-parameter search for `trainer` and return its result."""
        raise NotImplementedError

    def default_hp_space(self, trial):
        """Return the backend's default hyper-parameter search space for `trial`."""
        raise NotImplementedError

    def ensure_available(self):
        """Raise a `RuntimeError` with install instructions if the backend is missing."""
        if self.is_available():
            return
        raise RuntimeError(
            f"You picked the {self.name} backend, but it is not installed. Run {self.pip_install()}."
        )

    @classmethod
    def pip_install(cls):
        """Return the pip command (as a markdown snippet) that installs this backend."""
        return f"`pip install {cls.pip_package or cls.name}`"
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class OptunaBackend(HyperParamSearchBackendBase):
    """Hyper-parameter search backend delegating to the `optuna` integration."""

    name = "optuna"

    @staticmethod
    def is_available():
        """Return whether `optuna` is installed."""
        return is_optuna_available()

    def run(self, trainer, n_trials: int, direction: str, **kwargs):
        """Run the search via `run_hp_search_optuna`."""
        return run_hp_search_optuna(trainer, n_trials, direction, **kwargs)

    def default_hp_space(self, trial):
        """Return the default optuna search space for `trial`."""
        return default_hp_space_optuna(trial)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class RayTuneBackend(HyperParamSearchBackendBase):
    """Hyper-parameter search backend delegating to the Ray Tune integration."""

    name = "ray"
    # The installable extra differs from the backend name; quoted for shell safety.
    pip_package = "'ray[tune]'"

    @staticmethod
    def is_available():
        """Return whether `ray[tune]` is installed."""
        return is_ray_tune_available()

    def run(self, trainer, n_trials: int, direction: str, **kwargs):
        """Run the search via `run_hp_search_ray`."""
        return run_hp_search_ray(trainer, n_trials, direction, **kwargs)

    def default_hp_space(self, trial):
        """Return the default Ray Tune search space for `trial`."""
        return default_hp_space_ray(trial)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class SigOptBackend(HyperParamSearchBackendBase):
    """Hyper-parameter search backend delegating to the SigOpt integration."""

    name = "sigopt"

    @staticmethod
    def is_available():
        """Return whether `sigopt` is installed."""
        return is_sigopt_available()

    def run(self, trainer, n_trials: int, direction: str, **kwargs):
        """Run the search via `run_hp_search_sigopt`."""
        return run_hp_search_sigopt(trainer, n_trials, direction, **kwargs)

    def default_hp_space(self, trial):
        """Return the default SigOpt search space for `trial`."""
        return default_hp_space_sigopt(trial)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class WandbBackend(HyperParamSearchBackendBase):
    """Hyper-parameter search backend delegating to the Weights & Biases integration."""

    name = "wandb"

    @staticmethod
    def is_available():
        """Return whether `wandb` is installed."""
        return is_wandb_available()

    def run(self, trainer, n_trials: int, direction: str, **kwargs):
        """Run the search via `run_hp_search_wandb`."""
        return run_hp_search_wandb(trainer, n_trials, direction, **kwargs)

    def default_hp_space(self, trial):
        """Return the default wandb sweep search space for `trial`."""
        return default_hp_space_wandb(trial)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# Registry mapping each `HPSearchBackend` enum member to its backend class; the
# iteration order below also defines the preference order used when picking a
# default backend.
ALL_HYPERPARAMETER_SEARCH_BACKENDS = {
    HPSearchBackend(backend.name): backend for backend in [OptunaBackend, RayTuneBackend, SigOptBackend, WandbBackend]
}
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def default_hp_search_backend() -> str:
    """
    Pick the default hyper-parameter search backend.

    Returns the name of the first installed backend (in registry order), logging an
    informational message when several are installed. Raises `RuntimeError` with
    per-backend install instructions when none is available.
    """
    installed = [b for b in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values() if b.is_available()]
    if not installed:
        install_hints = "\n".join(
            f" - To install {backend.name} run {backend.pip_install()}"
            for backend in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values()
        )
        raise RuntimeError("No hyperparameter search backend available.\n" + install_hints)
    name = installed[0].name
    if len(installed) > 1:
        logger.info(
            f"{len(installed)} hyperparameter search backends available. Using {name} as the default."
        )
    return name
|
.venv/lib/python3.11/site-packages/transformers/image_processing_base.py
ADDED
|
@@ -0,0 +1,559 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2020 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
import copy
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
import warnings
|
| 21 |
+
from io import BytesIO
|
| 22 |
+
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
import requests
|
| 26 |
+
|
| 27 |
+
from .dynamic_module_utils import custom_object_save
|
| 28 |
+
from .feature_extraction_utils import BatchFeature as BaseBatchFeature
|
| 29 |
+
from .utils import (
|
| 30 |
+
IMAGE_PROCESSOR_NAME,
|
| 31 |
+
PushToHubMixin,
|
| 32 |
+
add_model_info_to_auto_map,
|
| 33 |
+
add_model_info_to_custom_pipelines,
|
| 34 |
+
cached_file,
|
| 35 |
+
copy_func,
|
| 36 |
+
download_url,
|
| 37 |
+
is_offline_mode,
|
| 38 |
+
is_remote_url,
|
| 39 |
+
is_vision_available,
|
| 40 |
+
logging,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
if is_vision_available():
|
| 45 |
+
from PIL import Image
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# Bound TypeVar so `from_pretrained` is typed as returning the invoking subclass.
ImageProcessorType = TypeVar("ImageProcessorType", bound="ImageProcessingMixin")


logger = logging.get_logger(__name__)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# TODO: Move BatchFeature to be imported by both image_processing_utils and image_processing_utils
# (NOTE(review): the duplicated module name above looks like a typo — presumably one of them should be
# `feature_extraction_utils`; confirm before fixing.)
# We override the class string here, but logic is the same.
class BatchFeature(BaseBatchFeature):
    r"""
    Holds the output of the image processor specific `__call__` methods.

    This class is derived from a python dictionary and can be used as a dictionary.

    Args:
        data (`dict`):
            Dictionary of lists/arrays/tensors returned by the __call__ method ('pixel_values', etc.).
        tensor_type (`Union[None, str, TensorType]`, *optional*):
            You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at
            initialization.
    """
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# TODO: (Amy) - factor out the common parts of this and the feature extractor
class ImageProcessingMixin(PushToHubMixin):
    """
    This is an image processor mixin used to provide saving/loading functionality for sequential and image feature
    extractors.
    """

    # Name of the auto class this processor is registered with (set via `register_for_auto_class`);
    # `None` means not registered.
    _auto_class = None
|
| 79 |
+
|
| 80 |
+
def __init__(self, **kwargs):
|
| 81 |
+
"""Set elements of `kwargs` as attributes."""
|
| 82 |
+
# This key was saved while we still used `XXXFeatureExtractor` for image processing. Now we use
|
| 83 |
+
# `XXXImageProcessor`, this attribute and its value are misleading.
|
| 84 |
+
kwargs.pop("feature_extractor_type", None)
|
| 85 |
+
# Pop "processor_class" as it should be saved as private attribute
|
| 86 |
+
self._processor_class = kwargs.pop("processor_class", None)
|
| 87 |
+
# Additional attributes without default values
|
| 88 |
+
for key, value in kwargs.items():
|
| 89 |
+
try:
|
| 90 |
+
setattr(self, key, value)
|
| 91 |
+
except AttributeError as err:
|
| 92 |
+
logger.error(f"Can't set {key} with value {value} for {self}")
|
| 93 |
+
raise err
|
| 94 |
+
|
| 95 |
+
    def _set_processor_class(self, processor_class: str):
        """Sets processor class as an attribute (serialized as "processor_class" in `to_json_string`)."""
        self._processor_class = processor_class
|
| 98 |
+
|
| 99 |
+
@classmethod
|
| 100 |
+
def from_pretrained(
|
| 101 |
+
cls: Type[ImageProcessorType],
|
| 102 |
+
pretrained_model_name_or_path: Union[str, os.PathLike],
|
| 103 |
+
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
| 104 |
+
force_download: bool = False,
|
| 105 |
+
local_files_only: bool = False,
|
| 106 |
+
token: Optional[Union[str, bool]] = None,
|
| 107 |
+
revision: str = "main",
|
| 108 |
+
**kwargs,
|
| 109 |
+
) -> ImageProcessorType:
|
| 110 |
+
r"""
|
| 111 |
+
Instantiate a type of [`~image_processing_utils.ImageProcessingMixin`] from an image processor.
|
| 112 |
+
|
| 113 |
+
Args:
|
| 114 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
| 115 |
+
This can be either:
|
| 116 |
+
|
| 117 |
+
- a string, the *model id* of a pretrained image_processor hosted inside a model repo on
|
| 118 |
+
huggingface.co.
|
| 119 |
+
- a path to a *directory* containing a image processor file saved using the
|
| 120 |
+
[`~image_processing_utils.ImageProcessingMixin.save_pretrained`] method, e.g.,
|
| 121 |
+
`./my_model_directory/`.
|
| 122 |
+
- a path or url to a saved image processor JSON *file*, e.g.,
|
| 123 |
+
`./my_model_directory/preprocessor_config.json`.
|
| 124 |
+
cache_dir (`str` or `os.PathLike`, *optional*):
|
| 125 |
+
Path to a directory in which a downloaded pretrained model image processor should be cached if the
|
| 126 |
+
standard cache should not be used.
|
| 127 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
| 128 |
+
Whether or not to force to (re-)download the image processor files and override the cached versions if
|
| 129 |
+
they exist.
|
| 130 |
+
resume_download:
|
| 131 |
+
Deprecated and ignored. All downloads are now resumed by default when possible.
|
| 132 |
+
Will be removed in v5 of Transformers.
|
| 133 |
+
proxies (`Dict[str, str]`, *optional*):
|
| 134 |
+
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
| 135 |
+
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
|
| 136 |
+
token (`str` or `bool`, *optional*):
|
| 137 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
|
| 138 |
+
the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
|
| 139 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
| 140 |
+
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
| 141 |
+
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
| 142 |
+
identifier allowed by git.
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
<Tip>
|
| 146 |
+
|
| 147 |
+
To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
|
| 148 |
+
|
| 149 |
+
</Tip>
|
| 150 |
+
|
| 151 |
+
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
|
| 152 |
+
If `False`, then this function returns just the final image processor object. If `True`, then this
|
| 153 |
+
functions returns a `Tuple(image_processor, unused_kwargs)` where *unused_kwargs* is a dictionary
|
| 154 |
+
consisting of the key/value pairs whose keys are not image processor attributes: i.e., the part of
|
| 155 |
+
`kwargs` which has not been used to update `image_processor` and is otherwise ignored.
|
| 156 |
+
subfolder (`str`, *optional*, defaults to `""`):
|
| 157 |
+
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
|
| 158 |
+
specify the folder name here.
|
| 159 |
+
kwargs (`Dict[str, Any]`, *optional*):
|
| 160 |
+
The values in kwargs of any keys which are image processor attributes will be used to override the
|
| 161 |
+
loaded values. Behavior concerning key/value pairs whose keys are *not* image processor attributes is
|
| 162 |
+
controlled by the `return_unused_kwargs` keyword parameter.
|
| 163 |
+
|
| 164 |
+
Returns:
|
| 165 |
+
A image processor of type [`~image_processing_utils.ImageProcessingMixin`].
|
| 166 |
+
|
| 167 |
+
Examples:
|
| 168 |
+
|
| 169 |
+
```python
|
| 170 |
+
# We can't instantiate directly the base class *ImageProcessingMixin* so let's show the examples on a
|
| 171 |
+
# derived class: *CLIPImageProcessor*
|
| 172 |
+
image_processor = CLIPImageProcessor.from_pretrained(
|
| 173 |
+
"openai/clip-vit-base-patch32"
|
| 174 |
+
) # Download image_processing_config from huggingface.co and cache.
|
| 175 |
+
image_processor = CLIPImageProcessor.from_pretrained(
|
| 176 |
+
"./test/saved_model/"
|
| 177 |
+
) # E.g. image processor (or model) was saved using *save_pretrained('./test/saved_model/')*
|
| 178 |
+
image_processor = CLIPImageProcessor.from_pretrained("./test/saved_model/preprocessor_config.json")
|
| 179 |
+
image_processor = CLIPImageProcessor.from_pretrained(
|
| 180 |
+
"openai/clip-vit-base-patch32", do_normalize=False, foo=False
|
| 181 |
+
)
|
| 182 |
+
assert image_processor.do_normalize is False
|
| 183 |
+
image_processor, unused_kwargs = CLIPImageProcessor.from_pretrained(
|
| 184 |
+
"openai/clip-vit-base-patch32", do_normalize=False, foo=False, return_unused_kwargs=True
|
| 185 |
+
)
|
| 186 |
+
assert image_processor.do_normalize is False
|
| 187 |
+
assert unused_kwargs == {"foo": False}
|
| 188 |
+
```"""
|
| 189 |
+
kwargs["cache_dir"] = cache_dir
|
| 190 |
+
kwargs["force_download"] = force_download
|
| 191 |
+
kwargs["local_files_only"] = local_files_only
|
| 192 |
+
kwargs["revision"] = revision
|
| 193 |
+
|
| 194 |
+
use_auth_token = kwargs.pop("use_auth_token", None)
|
| 195 |
+
if use_auth_token is not None:
|
| 196 |
+
warnings.warn(
|
| 197 |
+
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
|
| 198 |
+
FutureWarning,
|
| 199 |
+
)
|
| 200 |
+
if token is not None:
|
| 201 |
+
raise ValueError(
|
| 202 |
+
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
|
| 203 |
+
)
|
| 204 |
+
token = use_auth_token
|
| 205 |
+
|
| 206 |
+
if token is not None:
|
| 207 |
+
kwargs["token"] = token
|
| 208 |
+
|
| 209 |
+
image_processor_dict, kwargs = cls.get_image_processor_dict(pretrained_model_name_or_path, **kwargs)
|
| 210 |
+
|
| 211 |
+
return cls.from_dict(image_processor_dict, **kwargs)
|
| 212 |
+
|
| 213 |
+
    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        Save an image processor object to the directory `save_directory`, so that it can be re-loaded using the
        [`~image_processing_utils.ImageProcessingMixin.from_pretrained`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the image processor JSON file will be saved (will be created if it does not exist).
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            kwargs (`Dict[str, Any]`, *optional*):
                Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.

        Returns:
            The list of saved files (the single JSON configuration path).
        """
        # Deprecated `use_auth_token` is folded into `kwargs["token"]`; specifying both is an error.
        use_auth_token = kwargs.pop("use_auth_token", None)

        if use_auth_token is not None:
            warnings.warn(
                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
                FutureWarning,
            )
            if kwargs.get("token", None) is not None:
                raise ValueError(
                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
                )
            kwargs["token"] = use_auth_token

        if os.path.isfile(save_directory):
            raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

        os.makedirs(save_directory, exist_ok=True)

        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
            # Default repo name: the last path component of the save directory.
            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
            repo_id = self._create_repo(repo_id, **kwargs)
            # Snapshot file timestamps so only files modified below get uploaded.
            files_timestamps = self._get_files_timestamps(save_directory)

        # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
        # loaded from the Hub.
        if self._auto_class is not None:
            custom_object_save(self, save_directory, config=self)

        # If we save using the predefined names, we can load using `from_pretrained`
        output_image_processor_file = os.path.join(save_directory, IMAGE_PROCESSOR_NAME)

        self.to_json_file(output_image_processor_file)
        logger.info(f"Image processor saved in {output_image_processor_file}")

        if push_to_hub:
            self._upload_modified_files(
                save_directory,
                repo_id,
                files_timestamps,
                commit_message=commit_message,
                token=kwargs.get("token"),
            )

        return [output_image_processor_file]
|
| 273 |
+
|
| 274 |
+
    @classmethod
    def get_image_processor_dict(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """
        From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
        image processor of type [`~image_processor_utils.ImageProcessingMixin`] using `from_dict`.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
            subfolder (`str`, *optional*, defaults to `""`):
                In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
                specify the folder name here.
            image_processor_filename (`str`, *optional*, defaults to `"config.json"`):
                The name of the file in the model directory to use for the image processor config.

        Returns:
            `Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the image processor object.
        """
        # Download/resolution options are consumed here; whatever remains in `kwargs` is returned to the caller
        # (eventually `from_dict`) untouched.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", None)
        proxies = kwargs.pop("proxies", None)
        token = kwargs.pop("token", None)
        use_auth_token = kwargs.pop("use_auth_token", None)
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", "")
        image_processor_filename = kwargs.pop("image_processor_filename", IMAGE_PROCESSOR_NAME)

        # Internal markers set by pipelines / auto classes; used only for the user-agent payload below.
        from_pipeline = kwargs.pop("_from_pipeline", None)
        from_auto_class = kwargs.pop("_from_auto", False)

        # Deprecated `use_auth_token` is folded into `token`; specifying both is an error.
        if use_auth_token is not None:
            warnings.warn(
                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
                FutureWarning,
            )
            if token is not None:
                raise ValueError(
                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
                )
            token = use_auth_token

        user_agent = {"file_type": "image processor", "from_auto_class": from_auto_class}
        if from_pipeline is not None:
            user_agent["using_pipeline"] = from_pipeline

        if is_offline_mode() and not local_files_only:
            logger.info("Offline mode: forcing local_files_only=True")
            local_files_only = True

        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
        is_local = os.path.isdir(pretrained_model_name_or_path)
        if os.path.isdir(pretrained_model_name_or_path):
            image_processor_file = os.path.join(pretrained_model_name_or_path, image_processor_filename)
        if os.path.isfile(pretrained_model_name_or_path):
            # Direct path to a JSON config file.
            resolved_image_processor_file = pretrained_model_name_or_path
            is_local = True
        elif is_remote_url(pretrained_model_name_or_path):
            image_processor_file = pretrained_model_name_or_path
            resolved_image_processor_file = download_url(pretrained_model_name_or_path)
        else:
            image_processor_file = image_processor_filename
            try:
                # Load from local folder or from cache or download from model Hub and cache
                resolved_image_processor_file = cached_file(
                    pretrained_model_name_or_path,
                    image_processor_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                    token=token,
                    user_agent=user_agent,
                    revision=revision,
                    subfolder=subfolder,
                )
            except EnvironmentError:
                # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
                # the original exception.
                raise
            except Exception:
                # For any other exception, we throw a generic error.
                raise EnvironmentError(
                    f"Can't load image processor for '{pretrained_model_name_or_path}'. If you were trying to load"
                    " it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
                    f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
                    f" directory containing a {image_processor_filename} file"
                )

        try:
            # Load image_processor dict
            with open(resolved_image_processor_file, "r", encoding="utf-8") as reader:
                text = reader.read()
            image_processor_dict = json.loads(text)

        except json.JSONDecodeError:
            raise EnvironmentError(
                f"It looks like the config file at '{resolved_image_processor_file}' is not a valid JSON file."
            )

        if is_local:
            logger.info(f"loading configuration file {resolved_image_processor_file}")
        else:
            logger.info(
                f"loading configuration file {image_processor_file} from cache at {resolved_image_processor_file}"
            )
        # Attach the originating repo information to `auto_map` / `custom_pipelines` entries
        # (see `add_model_info_to_auto_map` / `add_model_info_to_custom_pipelines`).
        if "auto_map" in image_processor_dict:
            image_processor_dict["auto_map"] = add_model_info_to_auto_map(
                image_processor_dict["auto_map"], pretrained_model_name_or_path
            )
        if "custom_pipelines" in image_processor_dict:
            image_processor_dict["custom_pipelines"] = add_model_info_to_custom_pipelines(
                image_processor_dict["custom_pipelines"], pretrained_model_name_or_path
            )

        return image_processor_dict, kwargs
|
| 394 |
+
|
| 395 |
+
@classmethod
|
| 396 |
+
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
|
| 397 |
+
"""
|
| 398 |
+
Instantiates a type of [`~image_processing_utils.ImageProcessingMixin`] from a Python dictionary of parameters.
|
| 399 |
+
|
| 400 |
+
Args:
|
| 401 |
+
image_processor_dict (`Dict[str, Any]`):
|
| 402 |
+
Dictionary that will be used to instantiate the image processor object. Such a dictionary can be
|
| 403 |
+
retrieved from a pretrained checkpoint by leveraging the
|
| 404 |
+
[`~image_processing_utils.ImageProcessingMixin.to_dict`] method.
|
| 405 |
+
kwargs (`Dict[str, Any]`):
|
| 406 |
+
Additional parameters from which to initialize the image processor object.
|
| 407 |
+
|
| 408 |
+
Returns:
|
| 409 |
+
[`~image_processing_utils.ImageProcessingMixin`]: The image processor object instantiated from those
|
| 410 |
+
parameters.
|
| 411 |
+
"""
|
| 412 |
+
image_processor_dict = image_processor_dict.copy()
|
| 413 |
+
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
|
| 414 |
+
|
| 415 |
+
# The `size` parameter is a dict and was previously an int or tuple in feature extractors.
|
| 416 |
+
# We set `size` here directly to the `image_processor_dict` so that it is converted to the appropriate
|
| 417 |
+
# dict within the image processor and isn't overwritten if `size` is passed in as a kwarg.
|
| 418 |
+
if "size" in kwargs and "size" in image_processor_dict:
|
| 419 |
+
image_processor_dict["size"] = kwargs.pop("size")
|
| 420 |
+
if "crop_size" in kwargs and "crop_size" in image_processor_dict:
|
| 421 |
+
image_processor_dict["crop_size"] = kwargs.pop("crop_size")
|
| 422 |
+
|
| 423 |
+
image_processor = cls(**image_processor_dict)
|
| 424 |
+
|
| 425 |
+
# Update image_processor with kwargs if needed
|
| 426 |
+
to_remove = []
|
| 427 |
+
for key, value in kwargs.items():
|
| 428 |
+
if hasattr(image_processor, key):
|
| 429 |
+
setattr(image_processor, key, value)
|
| 430 |
+
to_remove.append(key)
|
| 431 |
+
for key in to_remove:
|
| 432 |
+
kwargs.pop(key, None)
|
| 433 |
+
|
| 434 |
+
logger.info(f"Image processor {image_processor}")
|
| 435 |
+
if return_unused_kwargs:
|
| 436 |
+
return image_processor, kwargs
|
| 437 |
+
else:
|
| 438 |
+
return image_processor
|
| 439 |
+
|
| 440 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 441 |
+
"""
|
| 442 |
+
Serializes this instance to a Python dictionary.
|
| 443 |
+
|
| 444 |
+
Returns:
|
| 445 |
+
`Dict[str, Any]`: Dictionary of all the attributes that make up this image processor instance.
|
| 446 |
+
"""
|
| 447 |
+
output = copy.deepcopy(self.__dict__)
|
| 448 |
+
output["image_processor_type"] = self.__class__.__name__
|
| 449 |
+
|
| 450 |
+
return output
|
| 451 |
+
|
| 452 |
+
@classmethod
|
| 453 |
+
def from_json_file(cls, json_file: Union[str, os.PathLike]):
|
| 454 |
+
"""
|
| 455 |
+
Instantiates a image processor of type [`~image_processing_utils.ImageProcessingMixin`] from the path to a JSON
|
| 456 |
+
file of parameters.
|
| 457 |
+
|
| 458 |
+
Args:
|
| 459 |
+
json_file (`str` or `os.PathLike`):
|
| 460 |
+
Path to the JSON file containing the parameters.
|
| 461 |
+
|
| 462 |
+
Returns:
|
| 463 |
+
A image processor of type [`~image_processing_utils.ImageProcessingMixin`]: The image_processor object
|
| 464 |
+
instantiated from that JSON file.
|
| 465 |
+
"""
|
| 466 |
+
with open(json_file, "r", encoding="utf-8") as reader:
|
| 467 |
+
text = reader.read()
|
| 468 |
+
image_processor_dict = json.loads(text)
|
| 469 |
+
return cls(**image_processor_dict)
|
| 470 |
+
|
| 471 |
+
def to_json_string(self) -> str:
|
| 472 |
+
"""
|
| 473 |
+
Serializes this instance to a JSON string.
|
| 474 |
+
|
| 475 |
+
Returns:
|
| 476 |
+
`str`: String containing all the attributes that make up this feature_extractor instance in JSON format.
|
| 477 |
+
"""
|
| 478 |
+
dictionary = self.to_dict()
|
| 479 |
+
|
| 480 |
+
for key, value in dictionary.items():
|
| 481 |
+
if isinstance(value, np.ndarray):
|
| 482 |
+
dictionary[key] = value.tolist()
|
| 483 |
+
|
| 484 |
+
# make sure private name "_processor_class" is correctly
|
| 485 |
+
# saved as "processor_class"
|
| 486 |
+
_processor_class = dictionary.pop("_processor_class", None)
|
| 487 |
+
if _processor_class is not None:
|
| 488 |
+
dictionary["processor_class"] = _processor_class
|
| 489 |
+
|
| 490 |
+
return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
|
| 491 |
+
|
| 492 |
+
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
|
| 493 |
+
"""
|
| 494 |
+
Save this instance to a JSON file.
|
| 495 |
+
|
| 496 |
+
Args:
|
| 497 |
+
json_file_path (`str` or `os.PathLike`):
|
| 498 |
+
Path to the JSON file in which this image_processor instance's parameters will be saved.
|
| 499 |
+
"""
|
| 500 |
+
with open(json_file_path, "w", encoding="utf-8") as writer:
|
| 501 |
+
writer.write(self.to_json_string())
|
| 502 |
+
|
| 503 |
+
def __repr__(self):
|
| 504 |
+
return f"{self.__class__.__name__} {self.to_json_string()}"
|
| 505 |
+
|
| 506 |
+
@classmethod
|
| 507 |
+
def register_for_auto_class(cls, auto_class="AutoImageProcessor"):
|
| 508 |
+
"""
|
| 509 |
+
Register this class with a given auto class. This should only be used for custom image processors as the ones
|
| 510 |
+
in the library are already mapped with `AutoImageProcessor `.
|
| 511 |
+
|
| 512 |
+
<Tip warning={true}>
|
| 513 |
+
|
| 514 |
+
This API is experimental and may have some slight breaking changes in the next releases.
|
| 515 |
+
|
| 516 |
+
</Tip>
|
| 517 |
+
|
| 518 |
+
Args:
|
| 519 |
+
auto_class (`str` or `type`, *optional*, defaults to `"AutoImageProcessor "`):
|
| 520 |
+
The auto class to register this new image processor with.
|
| 521 |
+
"""
|
| 522 |
+
if not isinstance(auto_class, str):
|
| 523 |
+
auto_class = auto_class.__name__
|
| 524 |
+
|
| 525 |
+
import transformers.models.auto as auto_module
|
| 526 |
+
|
| 527 |
+
if not hasattr(auto_module, auto_class):
|
| 528 |
+
raise ValueError(f"{auto_class} is not a valid auto class.")
|
| 529 |
+
|
| 530 |
+
cls._auto_class = auto_class
|
| 531 |
+
|
| 532 |
+
def fetch_images(self, image_url_or_urls: Union[str, List[str]]):
|
| 533 |
+
"""
|
| 534 |
+
Convert a single or a list of urls into the corresponding `PIL.Image` objects.
|
| 535 |
+
|
| 536 |
+
If a single url is passed, the return value will be a single object. If a list is passed a list of objects is
|
| 537 |
+
returned.
|
| 538 |
+
"""
|
| 539 |
+
headers = {
|
| 540 |
+
"User-Agent": (
|
| 541 |
+
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0"
|
| 542 |
+
" Safari/537.36"
|
| 543 |
+
)
|
| 544 |
+
}
|
| 545 |
+
if isinstance(image_url_or_urls, list):
|
| 546 |
+
return [self.fetch_images(x) for x in image_url_or_urls]
|
| 547 |
+
elif isinstance(image_url_or_urls, str):
|
| 548 |
+
response = requests.get(image_url_or_urls, stream=True, headers=headers)
|
| 549 |
+
response.raise_for_status()
|
| 550 |
+
return Image.open(BytesIO(response.content))
|
| 551 |
+
else:
|
| 552 |
+
raise TypeError(f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}")
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
# `push_to_hub` is inherited from `PushToHubMixin`; copy the function object first so that editing its docstring
# below does not leak into other classes sharing the mixin, then fill in the doc template for image processors.
ImageProcessingMixin.push_to_hub = copy_func(ImageProcessingMixin.push_to_hub)
if ImageProcessingMixin.push_to_hub.__doc__ is not None:
    ImageProcessingMixin.push_to_hub.__doc__ = ImageProcessingMixin.push_to_hub.__doc__.format(
        object="image processor", object_class="AutoImageProcessor", object_files="image processor file"
    )
|
.venv/lib/python3.11/site-packages/transformers/image_processing_utils.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from typing import Dict, Iterable, Optional, Union
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
|
| 20 |
+
from .image_processing_base import BatchFeature, ImageProcessingMixin
|
| 21 |
+
from .image_transforms import center_crop, normalize, rescale
|
| 22 |
+
from .image_utils import ChannelDimension
|
| 23 |
+
from .utils import logging
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
logger = logging.get_logger(__name__)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
INIT_SERVICE_KWARGS = [
|
| 30 |
+
"processor_class",
|
| 31 |
+
"image_processor_type",
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class BaseImageProcessor(ImageProcessingMixin):
    """
    Base class for image processors.

    Provides a `__call__` entry point that delegates to `preprocess` (which concrete
    subclasses must implement), plus thin wrappers around the module-level `rescale`,
    `normalize` and `center_crop` transforms imported from `image_transforms`.
    """

    def __init__(self, **kwargs):
        # All configuration storage/serialization is handled by ImageProcessingMixin.
        super().__init__(**kwargs)

    def __call__(self, images, **kwargs) -> BatchFeature:
        """Preprocess an image or a batch of images."""
        return self.preprocess(images, **kwargs)

    def preprocess(self, images, **kwargs) -> BatchFeature:
        # Abstract hook: each concrete image processor implements its own pipeline.
        raise NotImplementedError("Each image processor must implement its own preprocess method")

    def rescale(
        self,
        image: np.ndarray,
        scale: float,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Rescale an image by a scale factor. image = image * scale.

        Args:
            image (`np.ndarray`):
                Image to rescale.
            scale (`float`):
                The scaling factor to rescale pixel values by.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the output image. If unset, the channel dimension format of the input
                image is used. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.

        Returns:
            `np.ndarray`: The rescaled image.
        """
        # Delegates to the module-level `rescale`; extra kwargs are forwarded untouched.
        return rescale(image, scale=scale, data_format=data_format, input_data_format=input_data_format, **kwargs)

    def normalize(
        self,
        image: np.ndarray,
        mean: Union[float, Iterable[float]],
        std: Union[float, Iterable[float]],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Normalize an image. image = (image - image_mean) / image_std.

        Args:
            image (`np.ndarray`):
                Image to normalize.
            mean (`float` or `Iterable[float]`):
                Image mean to use for normalization.
            std (`float` or `Iterable[float]`):
                Image standard deviation to use for normalization.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the output image. If unset, the channel dimension format of the input
                image is used. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.

        Returns:
            `np.ndarray`: The normalized image.
        """
        # Delegates to the module-level `normalize`.
        return normalize(
            image, mean=mean, std=std, data_format=data_format, input_data_format=input_data_format, **kwargs
        )

    def center_crop(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along
        any edge, the image is padded with 0's and then center cropped.

        Args:
            image (`np.ndarray`):
                Image to center crop.
            size (`Dict[str, int]`):
                Size of the output image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the output image. If unset, the channel dimension format of the input
                image is used. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.

        Raises:
            ValueError: If the normalized `size` dict lacks `"height"` or `"width"` keys.
        """
        # `get_size_dict` (defined later in this module) normalizes legacy int/tuple
        # size arguments into the canonical dict form before validation.
        size = get_size_dict(size)
        if "height" not in size or "width" not in size:
            raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")
        return center_crop(
            image,
            size=(size["height"], size["width"]),
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )

    def to_dict(self):
        # `_valid_processor_keys` is runtime bookkeeping, not configuration:
        # strip it so it never lands in the serialized config.
        encoder_dict = super().to_dict()
        encoder_dict.pop("_valid_processor_keys", None)
        return encoder_dict
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
# The exact key sets a size dict may use; `is_valid_size_dict` accepts a dict
# only if its keys match one of these combinations exactly.
VALID_SIZE_DICT_KEYS = (
    {"height", "width"},
    {"shortest_edge"},
    {"shortest_edge", "longest_edge"},
    {"longest_edge"},
    {"max_height", "max_width"},
)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def is_valid_size_dict(size_dict):
    """Return ``True`` iff ``size_dict`` is a dict whose key set exactly matches one of ``VALID_SIZE_DICT_KEYS``."""
    if not isinstance(size_dict, dict):
        return False
    keys = set(size_dict)
    return any(keys == allowed for allowed in VALID_SIZE_DICT_KEYS)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def convert_to_size_dict(
    size, max_size: Optional[int] = None, default_to_square: bool = True, height_width_order: bool = True
):
    """
    Translate a legacy ``size`` argument (an int or an (h, w) / (w, h) pair) into a size dictionary.

    An int means a square (height, width) when ``default_to_square`` is set, otherwise the
    shortest edge after resizing (optionally paired with ``max_size`` as the longest edge).
    """
    if isinstance(size, int):
        if default_to_square:
            # A square target cannot be combined with a longest-edge cap.
            if max_size is not None:
                raise ValueError("Cannot specify both size as an int, with default_to_square=True and max_size")
            return {"height": size, "width": size}
        # Non-square: the int is the shortest edge after resizing.
        size_dict = {"shortest_edge": size}
        if max_size is not None:
            size_dict["longest_edge"] = max_size
        return size_dict

    if isinstance(size, (tuple, list)):
        # A pair is (height, width) unless the caller says otherwise.
        first, second = size[0], size[1]
        if height_width_order:
            return {"height": first, "width": second}
        return {"height": second, "width": first}

    if size is None and max_size is not None:
        if default_to_square:
            raise ValueError("Cannot specify both default_to_square=True and max_size")
        return {"longest_edge": max_size}

    raise ValueError(f"Could not convert size input to size dict: {size}")
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def get_size_dict(
    size: Union[int, Iterable[int], Dict[str, int]] = None,
    max_size: Optional[int] = None,
    height_width_order: bool = True,
    default_to_square: bool = True,
    param_name="size",
) -> dict:
    """
    Converts the old size parameter in the config into the new dict expected in the config. This is to ensure backwards
    compatibility with the old image processor configs and removes ambiguity over whether the tuple is in (height,
    width) or (width, height) format.

    - If `size` is tuple, it is converted to `{"height": size[0], "width": size[1]}` or `{"height": size[1], "width":
      size[0]}` if `height_width_order` is `False`.
    - If `size` is an int, and `default_to_square` is `True`, it is converted to `{"height": size, "width": size}`.
    - If `size` is an int and `default_to_square` is False, it is converted to `{"shortest_edge": size}`. If `max_size`
      is set, it is added to the dict as `{"longest_edge": max_size}`.

    Args:
        size (`Union[int, Iterable[int], Dict[str, int]]`, *optional*):
            The `size` parameter to be cast into a size dictionary.
        max_size (`Optional[int]`, *optional*):
            The `max_size` parameter to be cast into a size dictionary.
        height_width_order (`bool`, *optional*, defaults to `True`):
            If `size` is a tuple, whether it's in (height, width) or (width, height) order.
        default_to_square (`bool`, *optional*, defaults to `True`):
            If `size` is an int, whether to default to a square image or not.
        param_name (`str`, *optional*, defaults to `"size"`):
            Name of the parameter being converted; used only in log and error messages.

    Raises:
        ValueError: If the resulting dict's keys do not match one of `VALID_SIZE_DICT_KEYS`.
    """
    if not isinstance(size, dict):
        size_dict = convert_to_size_dict(size, max_size, default_to_square, height_width_order)
        # Fixed garbled wording in the log message ("a dictionary on of" -> "a dictionary with one of").
        logger.info(
            f"{param_name} should be a dictionary with one of the following set of keys: {VALID_SIZE_DICT_KEYS},"
            f" got {size}. Converted to {size_dict}.",
        )
    else:
        size_dict = size

    if not is_valid_size_dict(size_dict):
        raise ValueError(
            f"{param_name} must have one of the following set of keys: {VALID_SIZE_DICT_KEYS}, got {size_dict.keys()}"
        )
    return size_dict
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple:
    """
    Selects the best resolution from a list of possible resolutions based on the original size.

    For each candidate, the original image is (conceptually) scaled to fit inside it while
    keeping aspect ratio; the candidate maximizing the effective resolution (capped at the
    original pixel count) and, on ties, minimizing the wasted area wins.

    Args:
        original_size (tuple):
            The original size of the image in the format (height, width).
        possible_resolutions (list):
            A list of possible resolutions in the format [(height1, width1), (height2, width2), ...].

    Returns:
        tuple: The best fit resolution in the format (height, width).
    """
    orig_h, orig_w = original_size
    best_fit = None
    best_effective = 0
    best_waste = float("inf")

    for cand_h, cand_w in possible_resolutions:
        # Largest uniform scale that keeps the image inside the candidate box.
        ratio = min(cand_w / orig_w, cand_h / orig_h)
        scaled_w = int(orig_w * ratio)
        scaled_h = int(orig_h * ratio)
        # Effective pixels never exceed the original pixel count (upscaling adds nothing).
        effective = min(scaled_w * scaled_h, orig_w * orig_h)
        waste = cand_w * cand_h - effective

        # Lexicographic comparison: more effective pixels first, then less waste.
        if (effective, -waste) > (best_effective, -best_waste):
            best_effective = effective
            best_waste = waste
            best_fit = (cand_h, cand_w)

    return best_fit
|
.venv/lib/python3.11/site-packages/transformers/image_processing_utils_fast.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import functools
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
from typing import Any, Iterable, List, Optional, Tuple
|
| 19 |
+
|
| 20 |
+
from .image_processing_utils import BaseImageProcessor
|
| 21 |
+
from .utils.import_utils import is_torch_available, is_torchvision_available
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
if is_torchvision_available():
|
| 25 |
+
from torchvision.transforms import Compose
|
| 26 |
+
|
| 27 |
+
if is_torch_available():
|
| 28 |
+
import torch
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@dataclass(frozen=True)
class SizeDict:
    """
    Hashable dictionary to store image size information.

    Frozen so instances are hashable and can be used as cache keys (e.g. with
    `functools.lru_cache` in `BaseImageProcessorFast.get_transforms`).
    """

    # All fields are optional: only the keys relevant to the chosen resize
    # strategy are populated (mirrors VALID_SIZE_DICT_KEYS combinations).
    height: Optional[int] = None
    width: Optional[int] = None
    longest_edge: Optional[int] = None
    shortest_edge: Optional[int] = None
    max_height: Optional[int] = None
    max_width: Optional[int] = None

    def __getitem__(self, key):
        # Dict-style read access so a SizeDict can stand in for a plain size dict.
        if hasattr(self, key):
            return getattr(self, key)
        raise KeyError(f"Key {key} not found in SizeDict.")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class BaseImageProcessorFast(BaseImageProcessor):
    """
    Base class for fast image processors that build their preprocessing pipeline as a
    torchvision ``Compose``, constructed once per settings combination via ``get_transforms``.
    """

    # Names of the keyword arguments accepted by `_build_transforms`; subclasses
    # must define this so `_validate_params` can reject unknown kwargs.
    _transform_params = None

    def _build_transforms(self, **kwargs) -> "Compose":
        """
        Given the input settings e.g. do_resize, build the image transforms.
        """
        raise NotImplementedError

    def _validate_params(self, **kwargs) -> None:
        # Fail fast on any kwarg that is not a declared transform parameter.
        for k, v in kwargs.items():
            if k not in self._transform_params:
                raise ValueError(f"Invalid transform parameter {k}={v}.")

    @functools.lru_cache(maxsize=1)
    def get_transforms(self, **kwargs) -> "Compose":
        # Caches the last-built pipeline so repeated calls with identical settings
        # skip rebuilding. All kwargs must be hashable (hence SizeDict, not dict).
        # NOTE(review): lru_cache on an instance method keys on `self` and keeps the
        # instance alive for the cache's lifetime (ruff B019) — confirm this is
        # acceptable here before relying on processor garbage collection.
        self._validate_params(**kwargs)
        return self._build_transforms(**kwargs)

    def to_dict(self):
        # `_transform_params` is a service attribute, not configuration: keep it
        # out of the serialized config dict.
        encoder_dict = super().to_dict()
        encoder_dict.pop("_transform_params", None)
        return encoder_dict
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def get_image_size_for_max_height_width(
    image_size: Tuple[int, int],
    max_height: int,
    max_width: int,
) -> Tuple[int, int]:
    """
    Computes the output image size given the input image and the maximum allowed height and width. Keep aspect ratio.
    Important, even if image_height < max_height and image_width < max_width, the image will be resized
    to at least one of the edges be equal to max_height or max_width.

    For example:
        - input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50)
        - input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400)

    Args:
        image_size (`Tuple[int, int]`):
            The image to resize.
        max_height (`int`):
            The maximum allowed height.
        max_width (`int`):
            The maximum allowed width.
    """
    height, width = image_size
    # The limiting edge determines the single uniform scale factor.
    scale = min(max_height / height, max_width / width)
    return int(height * scale), int(width * scale)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def safe_squeeze(tensor: "torch.Tensor", axis: Optional[int] = None) -> "torch.Tensor":
    """
    Squeeze ``tensor``; if squeezing the given ``axis`` is not possible (its size is
    not 1 and a ValueError is raised), return the tensor unchanged.
    """
    if axis is None:
        return tensor.squeeze()

    try:
        squeezed = tensor.squeeze(axis=axis)
    except ValueError:
        squeezed = tensor
    return squeezed
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def max_across_indices(values: Iterable[Any]) -> List[Any]:
    """
    Return, for each index position, the maximum value found at that position
    across all inner sequences of ``values``.
    """
    maxima = []
    for grouped in zip(*values):
        maxima.append(max(grouped))
    return maxima
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def get_max_height_width(images: List["torch.Tensor"]) -> Tuple[int]:
    """
    Get the maximum height and width across all images in a batch.

    Images are read as channels-first (channels, height, width); the per-index
    maximum over the shapes is taken and the channel maximum discarded.
    """
    _, tallest, widest = [max(dims) for dims in zip(*(image.shape for image in images))]
    return (tallest, widest)
|
.venv/lib/python3.11/site-packages/transformers/image_transforms.py
ADDED
|
@@ -0,0 +1,860 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import warnings
|
| 17 |
+
from math import ceil
|
| 18 |
+
from typing import Iterable, List, Optional, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
|
| 22 |
+
from .image_utils import (
|
| 23 |
+
ChannelDimension,
|
| 24 |
+
ImageInput,
|
| 25 |
+
get_channel_dimension_axis,
|
| 26 |
+
get_image_size,
|
| 27 |
+
infer_channel_dimension_format,
|
| 28 |
+
)
|
| 29 |
+
from .utils import ExplicitEnum, TensorType, is_jax_tensor, is_tf_tensor, is_torch_tensor
|
| 30 |
+
from .utils.import_utils import (
|
| 31 |
+
is_flax_available,
|
| 32 |
+
is_tf_available,
|
| 33 |
+
is_torch_available,
|
| 34 |
+
is_torchvision_available,
|
| 35 |
+
is_torchvision_v2_available,
|
| 36 |
+
is_vision_available,
|
| 37 |
+
requires_backends,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
if is_vision_available():
|
| 42 |
+
import PIL
|
| 43 |
+
|
| 44 |
+
from .image_utils import PILImageResampling
|
| 45 |
+
|
| 46 |
+
if is_torch_available():
|
| 47 |
+
import torch
|
| 48 |
+
|
| 49 |
+
if is_tf_available():
|
| 50 |
+
import tensorflow as tf
|
| 51 |
+
|
| 52 |
+
if is_flax_available():
|
| 53 |
+
import jax.numpy as jnp
|
| 54 |
+
|
| 55 |
+
if is_torchvision_v2_available():
|
| 56 |
+
from torchvision.transforms.v2 import functional as F
|
| 57 |
+
elif is_torchvision_available():
|
| 58 |
+
from torchvision.transforms import functional as F
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def to_channel_dimension_format(
    image: np.ndarray,
    channel_dim: Union[ChannelDimension, str],
    input_channel_dim: Optional[Union[ChannelDimension, str]] = None,
) -> np.ndarray:
    """
    Converts `image` to the channel dimension format specified by `channel_dim`.

    Args:
        image (`numpy.ndarray`):
            The image to have its channel dimension set.
        channel_dim (`ChannelDimension`):
            The channel dimension format to use.
        input_channel_dim (`ChannelDimension`, *optional*):
            The channel dimension format of the input image. If not provided, it will be inferred from the input image.

    Returns:
        `np.ndarray`: The image with the channel dimension set to `channel_dim`.
    """
    if not isinstance(image, np.ndarray):
        raise TypeError(f"Input image must be of type np.ndarray, got {type(image)}")

    if input_channel_dim is None:
        input_channel_dim = infer_channel_dimension_format(image)

    target_channel_dim = ChannelDimension(channel_dim)
    if input_channel_dim == target_channel_dim:
        # Already in the requested layout: return untouched.
        return image

    if target_channel_dim == ChannelDimension.FIRST:
        return image.transpose((2, 0, 1))
    if target_channel_dim == ChannelDimension.LAST:
        return image.transpose((1, 2, 0))
    raise ValueError("Unsupported channel dimension format: {}".format(channel_dim))
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def rescale(
    image: np.ndarray,
    scale: float,
    data_format: Optional[ChannelDimension] = None,
    dtype: np.dtype = np.float32,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
    """
    Rescales `image` by `scale`.

    Args:
        image (`np.ndarray`):
            The image to rescale.
        scale (`float`):
            The scale to use for rescaling the image.
        data_format (`ChannelDimension`, *optional*):
            The channel dimension format of the image. If not provided, it will be the same as the input image.
        dtype (`np.dtype`, *optional*, defaults to `np.float32`):
            The dtype of the output image. Defaults to `np.float32`. Used for backwards compatibility with feature
            extractors.
        input_data_format (`ChannelDimension`, *optional*):
            The channel dimension format of the input image. If not provided, it will be inferred from the input image.

    Returns:
        `np.ndarray`: The rescaled image.
    """
    if not isinstance(image, np.ndarray):
        raise TypeError(f"Input image must be of type np.ndarray, got {type(image)}")

    # Numpy type promotion has changed, so always upcast first, then downcast
    # to the requested dtype only at the very end.
    scaled = image.astype(np.float64) * scale
    if data_format is not None:
        scaled = to_channel_dimension_format(scaled, data_format, input_data_format)
    return scaled.astype(dtype)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def _rescale_for_pil_conversion(image):
|
| 139 |
+
"""
|
| 140 |
+
Detects whether or not the image needs to be rescaled before being converted to a PIL image.
|
| 141 |
+
|
| 142 |
+
The assumption is that if the image is of type `np.float` and all values are between 0 and 1, it needs to be
|
| 143 |
+
rescaled.
|
| 144 |
+
"""
|
| 145 |
+
if image.dtype == np.uint8:
|
| 146 |
+
do_rescale = False
|
| 147 |
+
elif np.allclose(image, image.astype(int)):
|
| 148 |
+
if np.all(0 <= image) and np.all(image <= 255):
|
| 149 |
+
do_rescale = False
|
| 150 |
+
else:
|
| 151 |
+
raise ValueError(
|
| 152 |
+
"The image to be converted to a PIL image contains values outside the range [0, 255], "
|
| 153 |
+
f"got [{image.min()}, {image.max()}] which cannot be converted to uint8."
|
| 154 |
+
)
|
| 155 |
+
elif np.all(0 <= image) and np.all(image <= 1):
|
| 156 |
+
do_rescale = True
|
| 157 |
+
else:
|
| 158 |
+
raise ValueError(
|
| 159 |
+
"The image to be converted to a PIL image contains values outside the range [0, 1], "
|
| 160 |
+
f"got [{image.min()}, {image.max()}] which cannot be converted to uint8."
|
| 161 |
+
)
|
| 162 |
+
return do_rescale
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def to_pil_image(
    image: Union[np.ndarray, "PIL.Image.Image", "torch.Tensor", "tf.Tensor", "jnp.ndarray"],
    do_rescale: Optional[bool] = None,
    image_mode: Optional[str] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> "PIL.Image.Image":
    """
    Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
    needed.

    Args:
        image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor` or `tf.Tensor`):
            The image to convert to the `PIL.Image` format.
        do_rescale (`bool`, *optional*):
            Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will default
            to `True` if the image type is a floating type and casting to `int` would result in a loss of precision,
            and `False` otherwise.
        image_mode (`str`, *optional*):
            The mode to use for the PIL image. If unset, will use the default mode for the input image type.
        input_data_format (`ChannelDimension`, *optional*):
            The channel dimension format of the input image. If unset, will use the inferred format from the input.

    Returns:
        `PIL.Image.Image`: The converted image.
    """
    requires_backends(to_pil_image, ["vision"])

    # Already a PIL image: nothing to convert.
    if isinstance(image, PIL.Image.Image):
        return image

    # Convert all tensors to numpy arrays before converting to PIL image
    if is_torch_tensor(image) or is_tf_tensor(image):
        image = image.numpy()
    elif is_jax_tensor(image):
        image = np.array(image)
    elif not isinstance(image, np.ndarray):
        raise ValueError("Input image type not supported: {}".format(type(image)))

    # If the channel has been moved to first dim, we put it back at the end.
    image = to_channel_dimension_format(image, ChannelDimension.LAST, input_data_format)

    # If there is a single channel, we squeeze it, as otherwise PIL can't handle it.
    image = np.squeeze(image, axis=-1) if image.shape[-1] == 1 else image

    # PIL.Image can only store uint8 values so we rescale the image to be between 0 and 255 if needed.
    do_rescale = _rescale_for_pil_conversion(image) if do_rescale is None else do_rescale

    if do_rescale:
        image = rescale(image, 255)

    # Cast after the optional rescale so fractional values land in the uint8 range first.
    image = image.astype(np.uint8)
    return PIL.Image.fromarray(image, mode=image_mode)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
# Logic adapted from torchvision resizing logic: https://github.com/pytorch/vision/blob/511924c1ced4ce0461197e5caa64ce5b9e558aab/torchvision/transforms/functional.py#L366
|
| 220 |
+
def get_resize_output_image_size(
|
| 221 |
+
input_image: np.ndarray,
|
| 222 |
+
size: Union[int, Tuple[int, int], List[int], Tuple[int]],
|
| 223 |
+
default_to_square: bool = True,
|
| 224 |
+
max_size: Optional[int] = None,
|
| 225 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 226 |
+
) -> tuple:
|
| 227 |
+
"""
|
| 228 |
+
Find the target (height, width) dimension of the output image after resizing given the input image and the desired
|
| 229 |
+
size.
|
| 230 |
+
|
| 231 |
+
Args:
|
| 232 |
+
input_image (`np.ndarray`):
|
| 233 |
+
The image to resize.
|
| 234 |
+
size (`int` or `Tuple[int, int]` or List[int] or `Tuple[int]`):
|
| 235 |
+
The size to use for resizing the image. If `size` is a sequence like (h, w), output size will be matched to
|
| 236 |
+
this.
|
| 237 |
+
|
| 238 |
+
If `size` is an int and `default_to_square` is `True`, then image will be resized to (size, size). If
|
| 239 |
+
`size` is an int and `default_to_square` is `False`, then smaller edge of the image will be matched to this
|
| 240 |
+
number. i.e, if height > width, then image will be rescaled to (size * height / width, size).
|
| 241 |
+
default_to_square (`bool`, *optional*, defaults to `True`):
|
| 242 |
+
How to convert `size` when it is a single int. If set to `True`, the `size` will be converted to a square
|
| 243 |
+
(`size`,`size`). If set to `False`, will replicate
|
| 244 |
+
[`torchvision.transforms.Resize`](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Resize)
|
| 245 |
+
with support for resizing only the smallest edge and providing an optional `max_size`.
|
| 246 |
+
max_size (`int`, *optional*):
|
| 247 |
+
The maximum allowed for the longer edge of the resized image: if the longer edge of the image is greater
|
| 248 |
+
than `max_size` after being resized according to `size`, then the image is resized again so that the longer
|
| 249 |
+
edge is equal to `max_size`. As a result, `size` might be overruled, i.e the smaller edge may be shorter
|
| 250 |
+
than `size`. Only used if `default_to_square` is `False`.
|
| 251 |
+
input_data_format (`ChannelDimension`, *optional*):
|
| 252 |
+
The channel dimension format of the input image. If unset, will use the inferred format from the input.
|
| 253 |
+
|
| 254 |
+
Returns:
|
| 255 |
+
`tuple`: The target (height, width) dimension of the output image after resizing.
|
| 256 |
+
"""
|
| 257 |
+
if isinstance(size, (tuple, list)):
|
| 258 |
+
if len(size) == 2:
|
| 259 |
+
return tuple(size)
|
| 260 |
+
elif len(size) == 1:
|
| 261 |
+
# Perform same logic as if size was an int
|
| 262 |
+
size = size[0]
|
| 263 |
+
else:
|
| 264 |
+
raise ValueError("size must have 1 or 2 elements if it is a list or tuple")
|
| 265 |
+
|
| 266 |
+
if default_to_square:
|
| 267 |
+
return (size, size)
|
| 268 |
+
|
| 269 |
+
height, width = get_image_size(input_image, input_data_format)
|
| 270 |
+
short, long = (width, height) if width <= height else (height, width)
|
| 271 |
+
requested_new_short = size
|
| 272 |
+
|
| 273 |
+
new_short, new_long = requested_new_short, int(requested_new_short * long / short)
|
| 274 |
+
|
| 275 |
+
if max_size is not None:
|
| 276 |
+
if max_size <= requested_new_short:
|
| 277 |
+
raise ValueError(
|
| 278 |
+
f"max_size = {max_size} must be strictly greater than the requested "
|
| 279 |
+
f"size for the smaller edge size = {size}"
|
| 280 |
+
)
|
| 281 |
+
if new_long > max_size:
|
| 282 |
+
new_short, new_long = int(max_size * new_short / new_long), max_size
|
| 283 |
+
|
| 284 |
+
return (new_long, new_short) if width <= height else (new_short, new_long)
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def resize(
    image: np.ndarray,
    size: Tuple[int, int],
    resample: "PILImageResampling" = None,
    reducing_gap: Optional[int] = None,
    data_format: Optional[ChannelDimension] = None,
    return_numpy: bool = True,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
    """
    Resizes `image` to `(height, width)` specified by `size` using the PIL library.

    Args:
        image (`np.ndarray`):
            The image to resize.
        size (`Tuple[int, int]`):
            The size to use for resizing the image.
        resample (`int`, *optional*, defaults to `PILImageResampling.BILINEAR`):
            The filter to user for resampling.
        reducing_gap (`int`, *optional*):
            Apply optimization by resizing the image in two steps. The bigger `reducing_gap`, the closer the result to
            the fair resampling. See corresponding Pillow documentation for more details.
        data_format (`ChannelDimension`, *optional*):
            The channel dimension format of the output image. If unset, will use the inferred format from the input.
        return_numpy (`bool`, *optional*, defaults to `True`):
            Whether or not to return the resized image as a numpy array. If False a `PIL.Image.Image` object is
            returned.
        input_data_format (`ChannelDimension`, *optional*):
            The channel dimension format of the input image. If unset, will use the inferred format from the input.

    Returns:
        `np.ndarray`: The resized image.
    """
    requires_backends(resize, ["vision"])

    resample = resample if resample is not None else PILImageResampling.BILINEAR

    if not len(size) == 2:
        raise ValueError("size must have 2 elements")

    # For all transformations, we want to keep the same data format as the input image unless otherwise specified.
    # The resized image from PIL will always have channels last, so find the input format first.
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)
    data_format = input_data_format if data_format is None else data_format

    # To maintain backwards compatibility with the resizing done in previous image feature extractors, we use
    # the pillow library to resize the image and then convert back to numpy
    do_rescale = False
    if not isinstance(image, PIL.Image.Image):
        # Remember whether a [0, 1] -> [0, 255] rescale was applied so it can be undone after resizing.
        do_rescale = _rescale_for_pil_conversion(image)
        image = to_pil_image(image, do_rescale=do_rescale, input_data_format=input_data_format)
    height, width = size
    # PIL images are in the format (width, height)
    resized_image = image.resize((width, height), resample=resample, reducing_gap=reducing_gap)

    if return_numpy:
        resized_image = np.array(resized_image)
        # If the input image channel dimension was of size 1, then it is dropped when converting to a PIL image
        # so we need to add it back if necessary.
        resized_image = np.expand_dims(resized_image, axis=-1) if resized_image.ndim == 2 else resized_image
        # The image is always in channels last format after converting from a PIL image
        resized_image = to_channel_dimension_format(
            resized_image, data_format, input_channel_dim=ChannelDimension.LAST
        )
        # If an image was rescaled to be in the range [0, 255] before converting to a PIL image, then we need to
        # rescale it back to the original range.
        resized_image = rescale(resized_image, 1 / 255) if do_rescale else resized_image
    return resized_image
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def normalize(
    image: np.ndarray,
    mean: Union[float, Iterable[float]],
    std: Union[float, Iterable[float]],
    data_format: Optional[ChannelDimension] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
    """
    Normalizes `image` using the mean and standard deviation specified by `mean` and `std`.

    image = (image - mean) / std

    Args:
        image (`np.ndarray`):
            The image to normalize.
        mean (`float` or `Iterable[float]`):
            The mean to use for normalization.
        std (`float` or `Iterable[float]`):
            The standard deviation to use for normalization.
        data_format (`ChannelDimension`, *optional*):
            The channel dimension format of the output image. If unset, will use the inferred format from the input.
        input_data_format (`ChannelDimension`, *optional*):
            The channel dimension format of the input image. If unset, will use the inferred format from the input.

    Returns:
        `np.ndarray`: The normalized image.
    """
    if not isinstance(image, np.ndarray):
        raise ValueError("image must be a numpy array")

    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)

    channel_axis = get_channel_dimension_axis(image, input_data_format=input_data_format)
    num_channels = image.shape[channel_axis]

    # We cast to float32 to avoid errors that can occur when subtracting uint8 values.
    # We preserve the original dtype if it is a float type to prevent upcasting float16.
    if not np.issubdtype(image.dtype, np.floating):
        image = image.astype(np.float32)

    # A scalar mean is broadcast to one value per channel; an iterable must match the channel count exactly.
    if isinstance(mean, Iterable):
        if len(mean) != num_channels:
            raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}")
    else:
        mean = [mean] * num_channels
    mean = np.array(mean, dtype=image.dtype)

    # Same scalar-vs-iterable handling for std.
    if isinstance(std, Iterable):
        if len(std) != num_channels:
            raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(std)}")
    else:
        std = [std] * num_channels
    std = np.array(std, dtype=image.dtype)

    if input_data_format == ChannelDimension.LAST:
        image = (image - mean) / std
    else:
        # Transposing moves the channel axis to the end so the per-channel mean/std broadcast;
        # transposing back restores the original layout.
        # NOTE(review): this assumes the channel axis of `image.T` is the last axis
        # (e.g. an unbatched channels-first image) — confirm behavior for batched inputs.
        image = ((image.T - mean) / std).T

    image = to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
    return image
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
def center_crop(
    image: np.ndarray,
    size: Tuple[int, int],
    data_format: Optional[Union[str, ChannelDimension]] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
    return_numpy: Optional[bool] = None,
) -> np.ndarray:
    """
    Crops the `image` to the specified `size` using a center crop. Note that if the image is too small to be cropped to
    the size given, it will be padded (so the returned result will always be of size `size`).

    Args:
        image (`np.ndarray`):
            The image to crop.
        size (`Tuple[int, int]`):
            The target size for the cropped image.
        data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            If unset, will use the inferred format of the input image.
        input_data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format for the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            If unset, will use the inferred format of the input image.
        return_numpy (`bool`, *optional*):
            Whether or not to return the cropped image as a numpy array. Used for backwards compatibility with the
            previous ImageFeatureExtractionMixin method.
                - Unset: will return the same type as the input image.
                - `True`: will return a numpy array.
                - `False`: will return a `PIL.Image.Image` object.
    Returns:
        `np.ndarray`: The cropped image.
    """
    requires_backends(center_crop, ["vision"])

    if return_numpy is not None:
        warnings.warn("return_numpy is deprecated and will be removed in v.4.33", FutureWarning)

    return_numpy = True if return_numpy is None else return_numpy

    if not isinstance(image, np.ndarray):
        raise TypeError(f"Input image must be of type np.ndarray, got {type(image)}")

    if not isinstance(size, Iterable) or len(size) != 2:
        raise ValueError("size must have 2 elements representing the height and width of the output image")

    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)
    output_data_format = data_format if data_format is not None else input_data_format

    # We perform the crop in (C, H, W) format and then convert to the output format
    image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_data_format)

    orig_height, orig_width = get_image_size(image, ChannelDimension.FIRST)
    crop_height, crop_width = size
    crop_height, crop_width = int(crop_height), int(crop_width)

    # In case size is odd, (image_shape[0] + size[0]) // 2 won't give the proper result.
    top = (orig_height - crop_height) // 2
    bottom = top + crop_height
    # In case size is odd, (image_shape[1] + size[1]) // 2 won't give the proper result.
    left = (orig_width - crop_width) // 2
    right = left + crop_width

    # Check if cropped area is within image boundaries
    if top >= 0 and bottom <= orig_height and left >= 0 and right <= orig_width:
        # Fast path: the crop fits entirely inside the image, so plain slicing suffices.
        # The ellipsis keeps any leading (e.g. batch) dimensions intact.
        image = image[..., top:bottom, left:right]
        image = to_channel_dimension_format(image, output_data_format, ChannelDimension.FIRST)
        return image

    # Otherwise, we may need to pad if the image is too small. Oh joy...
    new_height = max(crop_height, orig_height)
    new_width = max(crop_width, orig_width)
    new_shape = image.shape[:-2] + (new_height, new_width)
    new_image = np.zeros_like(image, shape=new_shape)

    # If the image is too small, pad it with zeros
    top_pad = ceil((new_height - orig_height) / 2)
    bottom_pad = top_pad + orig_height
    left_pad = ceil((new_width - orig_width) / 2)
    right_pad = left_pad + orig_width
    new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image

    # Shift the crop window into the padded image's coordinate system.
    top += top_pad
    bottom += top_pad
    left += left_pad
    right += left_pad

    # Clamp to the padded image bounds before slicing.
    new_image = new_image[..., max(0, top) : min(new_height, bottom), max(0, left) : min(new_width, right)]
    new_image = to_channel_dimension_format(new_image, output_data_format, ChannelDimension.FIRST)

    if not return_numpy:
        new_image = to_pil_image(new_image)

    return new_image
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
def _center_to_corners_format_torch(bboxes_center: "torch.Tensor") -> "torch.Tensor":
|
| 519 |
+
center_x, center_y, width, height = bboxes_center.unbind(-1)
|
| 520 |
+
bbox_corners = torch.stack(
|
| 521 |
+
# top left x, top left y, bottom right x, bottom right y
|
| 522 |
+
[(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)],
|
| 523 |
+
dim=-1,
|
| 524 |
+
)
|
| 525 |
+
return bbox_corners
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
def _center_to_corners_format_numpy(bboxes_center: np.ndarray) -> np.ndarray:
|
| 529 |
+
center_x, center_y, width, height = bboxes_center.T
|
| 530 |
+
bboxes_corners = np.stack(
|
| 531 |
+
# top left x, top left y, bottom right x, bottom right y
|
| 532 |
+
[center_x - 0.5 * width, center_y - 0.5 * height, center_x + 0.5 * width, center_y + 0.5 * height],
|
| 533 |
+
axis=-1,
|
| 534 |
+
)
|
| 535 |
+
return bboxes_corners
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
def _center_to_corners_format_tf(bboxes_center: "tf.Tensor") -> "tf.Tensor":
    """Convert (center_x, center_y, width, height) boxes to corner coordinates (TensorFlow backend)."""
    cx, cy, w, h = tf.unstack(bboxes_center, axis=-1)
    half_w = 0.5 * w
    half_h = 0.5 * h
    # top left x, top left y, bottom right x, bottom right y
    return tf.stack([cx - half_w, cy - half_h, cx + half_w, cy + half_h], axis=-1)
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
# 2 functions below inspired by https://github.com/facebookresearch/detr/blob/master/util/box_ops.py
|
| 549 |
+
def center_to_corners_format(bboxes_center: TensorType) -> TensorType:
|
| 550 |
+
"""
|
| 551 |
+
Converts bounding boxes from center format to corners format.
|
| 552 |
+
|
| 553 |
+
center format: contains the coordinate for the center of the box and its width, height dimensions
|
| 554 |
+
(center_x, center_y, width, height)
|
| 555 |
+
corners format: contains the coodinates for the top-left and bottom-right corners of the box
|
| 556 |
+
(top_left_x, top_left_y, bottom_right_x, bottom_right_y)
|
| 557 |
+
"""
|
| 558 |
+
# Function is used during model forward pass, so we use the input framework if possible, without
|
| 559 |
+
# converting to numpy
|
| 560 |
+
if is_torch_tensor(bboxes_center):
|
| 561 |
+
return _center_to_corners_format_torch(bboxes_center)
|
| 562 |
+
elif isinstance(bboxes_center, np.ndarray):
|
| 563 |
+
return _center_to_corners_format_numpy(bboxes_center)
|
| 564 |
+
elif is_tf_tensor(bboxes_center):
|
| 565 |
+
return _center_to_corners_format_tf(bboxes_center)
|
| 566 |
+
|
| 567 |
+
raise ValueError(f"Unsupported input type {type(bboxes_center)}")
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
def _corners_to_center_format_torch(bboxes_corners: "torch.Tensor") -> "torch.Tensor":
|
| 571 |
+
top_left_x, top_left_y, bottom_right_x, bottom_right_y = bboxes_corners.unbind(-1)
|
| 572 |
+
b = [
|
| 573 |
+
(top_left_x + bottom_right_x) / 2, # center x
|
| 574 |
+
(top_left_y + bottom_right_y) / 2, # center y
|
| 575 |
+
(bottom_right_x - top_left_x), # width
|
| 576 |
+
(bottom_right_y - top_left_y), # height
|
| 577 |
+
]
|
| 578 |
+
return torch.stack(b, dim=-1)
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
def _corners_to_center_format_numpy(bboxes_corners: np.ndarray) -> np.ndarray:
|
| 582 |
+
top_left_x, top_left_y, bottom_right_x, bottom_right_y = bboxes_corners.T
|
| 583 |
+
bboxes_center = np.stack(
|
| 584 |
+
[
|
| 585 |
+
(top_left_x + bottom_right_x) / 2, # center x
|
| 586 |
+
(top_left_y + bottom_right_y) / 2, # center y
|
| 587 |
+
(bottom_right_x - top_left_x), # width
|
| 588 |
+
(bottom_right_y - top_left_y), # height
|
| 589 |
+
],
|
| 590 |
+
axis=-1,
|
| 591 |
+
)
|
| 592 |
+
return bboxes_center
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
def _corners_to_center_format_tf(bboxes_corners: "tf.Tensor") -> "tf.Tensor":
    """Convert corner-format boxes to (center_x, center_y, width, height) (TensorFlow backend)."""
    x0, y0, x1, y1 = tf.unstack(bboxes_corners, axis=-1)
    components = [
        (x0 + x1) / 2,  # center x
        (y0 + y1) / 2,  # center y
        x1 - x0,  # width
        y1 - y0,  # height
    ]
    return tf.stack(components, axis=-1)
|
| 607 |
+
|
| 608 |
+
|
| 609 |
+
def corners_to_center_format(bboxes_corners: TensorType) -> TensorType:
    """
    Converts bounding boxes from corners format to center format.

    corners format: contains the coordinates for the top-left and bottom-right corners of the box
        (top_left_x, top_left_y, bottom_right_x, bottom_right_y)
    center format: contains the coordinate for the center of the box and its the width, height dimensions
        (center_x, center_y, width, height)
    """
    # Dispatch on the input framework, mirroring `center_to_corners_format`.
    if is_torch_tensor(bboxes_corners):
        return _corners_to_center_format_torch(bboxes_corners)
    if isinstance(bboxes_corners, np.ndarray):
        return _corners_to_center_format_numpy(bboxes_corners)
    if is_tf_tensor(bboxes_corners):
        return _corners_to_center_format_tf(bboxes_corners)

    raise ValueError(f"Unsupported input type {type(bboxes_corners)}")
|
| 627 |
+
|
| 628 |
+
|
| 629 |
+
# 2 functions below copied from https://github.com/cocodataset/panopticapi/blob/master/panopticapi/utils.py
|
| 630 |
+
# Copyright (c) 2018, Alexander Kirillov
|
| 631 |
+
# All rights reserved.
|
| 632 |
+
def rgb_to_id(color):
|
| 633 |
+
"""
|
| 634 |
+
Converts RGB color to unique ID.
|
| 635 |
+
"""
|
| 636 |
+
if isinstance(color, np.ndarray) and len(color.shape) == 3:
|
| 637 |
+
if color.dtype == np.uint8:
|
| 638 |
+
color = color.astype(np.int32)
|
| 639 |
+
return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2]
|
| 640 |
+
return int(color[0] + 256 * color[1] + 256 * 256 * color[2])
|
| 641 |
+
|
| 642 |
+
|
| 643 |
+
def id_to_rgb(id_map):
    """
    Converts unique ID to RGB color.
    """
    if isinstance(id_map, np.ndarray):
        # Work on a copy: the digit extraction below is destructive.
        remaining = id_map.copy()
        rgb_map = np.zeros((*id_map.shape, 3), dtype=np.uint8)
        for channel in range(3):
            # Peel off one base-256 digit per channel (R first, then G, then B).
            rgb_map[..., channel] = remaining % 256
            remaining //= 256
        return rgb_map
    color = []
    for _ in range(3):
        color.append(id_map % 256)
        id_map //= 256
    return color
|
| 660 |
+
|
| 661 |
+
|
| 662 |
+
class PaddingMode(ExplicitEnum):
    """
    Enum class for the different padding modes to use when padding images.
    """

    # Pad with a constant value; mapped to `np.pad(mode="constant")` by `pad`.
    CONSTANT = "constant"
    # Pad with the mirrored reflection of the image; mapped to `np.pad(mode="reflect")` by `pad`.
    REFLECT = "reflect"
    # Pad by replicating the edge values; mapped to `np.pad(mode="edge")` by `pad`.
    REPLICATE = "replicate"
    # Pad with the reflection mirrored along the array edge; mapped to `np.pad(mode="symmetric")` by `pad`.
    SYMMETRIC = "symmetric"
|
| 671 |
+
|
| 672 |
+
|
| 673 |
+
def pad(
    image: np.ndarray,
    padding: Union[int, Tuple[int, int], Iterable[Tuple[int, int]]],
    mode: PaddingMode = PaddingMode.CONSTANT,
    constant_values: Union[float, Iterable[float]] = 0.0,
    data_format: Optional[Union[str, ChannelDimension]] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
    """
    Pads the `image` with the specified (height, width) `padding` and `mode`.

    Args:
        image (`np.ndarray`):
            The image to pad.
        padding (`int` or `Tuple[int, int]` or `Iterable[Tuple[int, int]]`):
            Padding to apply to the edges of the height, width axes. Can be one of three formats:
            - `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis.
            - `((before, after),)` yields same before and after pad for height and width.
            - `(pad,)` or int is a shortcut for before = after = pad width for all axes.
        mode (`PaddingMode`):
            The padding mode to use. Can be one of:
                - `"constant"`: pads with a constant value.
                - `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the
                  vector along each axis.
                - `"replicate"`: pads with the replication of the last value on the edge of the array along each axis.
                - `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array.
        constant_values (`float` or `Iterable[float]`, *optional*):
            The value to use for the padding if `mode` is `"constant"`.
        data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            If unset, will use same as the input image.
        input_data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format for the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            If unset, will use the inferred format of the input image.

    Returns:
        `np.ndarray`: The padded image.

    """
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)

    def _expand_for_data_format(values):
        """
        Convert values to be in the format expected by np.pad based on the data format.
        """
        # Normalize every accepted shorthand to ((before_h, after_h), (before_w, after_w)).
        if isinstance(values, (int, float)):
            values = ((values, values), (values, values))
        elif isinstance(values, tuple) and len(values) == 1:
            values = ((values[0], values[0]), (values[0], values[0]))
        elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], int):
            values = (values, values)
        elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], tuple):
            values = values
        else:
            raise ValueError(f"Unsupported format: {values}")

        # add 0 for channel dimension
        values = ((0, 0), *values) if input_data_format == ChannelDimension.FIRST else (*values, (0, 0))

        # Add additional padding if there's a batch dimension
        # NOTE(review): this prepends a bare `0` rather than a `(0, 0)` pair for the batch axis,
        # which mixes formats within the pad-width tuple — confirm `np.pad` accepts this for 4D inputs.
        values = (0, *values) if image.ndim == 4 else values
        return values

    padding = _expand_for_data_format(padding)

    if mode == PaddingMode.CONSTANT:
        # `constant_values` accepts the same shorthands as `padding`, so expand it identically.
        constant_values = _expand_for_data_format(constant_values)
        image = np.pad(image, padding, mode="constant", constant_values=constant_values)
    elif mode == PaddingMode.REFLECT:
        image = np.pad(image, padding, mode="reflect")
    elif mode == PaddingMode.REPLICATE:
        image = np.pad(image, padding, mode="edge")
    elif mode == PaddingMode.SYMMETRIC:
        image = np.pad(image, padding, mode="symmetric")
    else:
        raise ValueError(f"Invalid padding mode: {mode}")

    image = to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
    return image
|
| 757 |
+
|
| 758 |
+
|
| 759 |
+
# TODO (Amy): Accept 1/3/4 channel numpy array as input and return np.array as default
|
| 760 |
+
def convert_to_rgb(image: ImageInput) -> ImageInput:
    """
    Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image
    as is.
    Args:
        image (Image):
            The image to convert.
    """
    requires_backends(convert_to_rgb, ["vision"])

    # Non-PIL inputs (and PIL images already in RGB mode) pass through untouched.
    if not isinstance(image, PIL.Image.Image) or image.mode == "RGB":
        return image

    return image.convert("RGB")
|
| 778 |
+
|
| 779 |
+
|
| 780 |
+
def flip_channel_order(
    image: np.ndarray,
    data_format: Optional[ChannelDimension] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
    """
    Flips the channel order of the image (RGB -> BGR and vice versa).

    Args:
        image (`np.ndarray`):
            The image to flip.
        data_format (`ChannelDimension`, *optional*):
            The channel dimension format for the output image (`FIRST` or `LAST`). If unset, the input format is kept.
        input_data_format (`ChannelDimension`, *optional*):
            The channel dimension format of the input image. If unset, it is inferred from the image.

    Returns:
        `np.ndarray`: The image with reversed channel order.
    """
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)

    # NOTE(review): the FIRST branch reverses axis 0, so this assumes an unbatched
    # (channels, height, width) array — for a 4D batch it would flip the batch axis. TODO confirm callers.
    if input_data_format == ChannelDimension.LAST:
        flipped = image[..., ::-1]
    elif input_data_format == ChannelDimension.FIRST:
        flipped = image[::-1, ...]
    else:
        raise ValueError(f"Unsupported channel dimension: {input_data_format}")

    if data_format is not None:
        flipped = to_channel_dimension_format(flipped, data_format, input_channel_dim=input_data_format)
    return flipped
|
| 816 |
+
|
| 817 |
+
|
| 818 |
+
def _cast_tensor_to_float(x):
|
| 819 |
+
if x.is_floating_point():
|
| 820 |
+
return x
|
| 821 |
+
return x.float()
|
| 822 |
+
|
| 823 |
+
|
| 824 |
+
class FusedRescaleNormalize:
    """
    Rescale and normalize the input image in one step.

    Folding the rescale factor into the mean and std lets a single `F.normalize`
    call compute `(image * rescale_factor - mean) / std`.
    """

    def __init__(self, mean, std, rescale_factor: float = 1.0, inplace: bool = False):
        # Pre-divide mean/std by the rescale factor so normalize() absorbs the rescaling.
        inv_scale = 1.0 / rescale_factor
        self.mean = torch.tensor(mean) * inv_scale
        self.std = torch.tensor(std) * inv_scale
        self.inplace = inplace

    def __call__(self, image: "torch.Tensor"):
        # normalize() requires a floating-point input.
        return F.normalize(_cast_tensor_to_float(image), self.mean, self.std, inplace=self.inplace)
|
| 837 |
+
|
| 838 |
+
|
| 839 |
+
class Rescale:
    """
    Rescale the input image by rescale factor: image *= rescale_factor.
    """

    def __init__(self, rescale_factor: float = 1.0):
        # Multiplicative factor applied to every pixel value.
        self.rescale_factor = rescale_factor

    def __call__(self, image: "torch.Tensor"):
        scaled = image * self.rescale_factor
        return scaled
|
| 850 |
+
|
| 851 |
+
|
| 852 |
+
class NumpyToTensor:
    """
    Convert a numpy array to a PyTorch tensor.

    Incoming numpy images are assumed to be HWC, matching torchvision's convention;
    the output is a contiguous CHW tensor.
    """

    def __call__(self, image: np.ndarray):
        # c.f. https://github.com/pytorch/vision/blob/61d97f41bc209e1407dcfbd685d2ee2da9c1cdad/torchvision/transforms/functional.py#L154
        chw = image.transpose(2, 0, 1)
        return torch.from_numpy(chw).contiguous()
|
.venv/lib/python3.11/site-packages/transformers/image_utils.py
ADDED
|
@@ -0,0 +1,871 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2021 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import base64
|
| 17 |
+
import os
|
| 18 |
+
from io import BytesIO
|
| 19 |
+
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import requests
|
| 23 |
+
from packaging import version
|
| 24 |
+
|
| 25 |
+
from .utils import (
|
| 26 |
+
ExplicitEnum,
|
| 27 |
+
TensorType,
|
| 28 |
+
is_jax_tensor,
|
| 29 |
+
is_numpy_array,
|
| 30 |
+
is_tf_tensor,
|
| 31 |
+
is_torch_available,
|
| 32 |
+
is_torch_tensor,
|
| 33 |
+
is_torchvision_available,
|
| 34 |
+
is_vision_available,
|
| 35 |
+
logging,
|
| 36 |
+
requires_backends,
|
| 37 |
+
to_numpy,
|
| 38 |
+
)
|
| 39 |
+
from .utils.constants import ( # noqa: F401
|
| 40 |
+
IMAGENET_DEFAULT_MEAN,
|
| 41 |
+
IMAGENET_DEFAULT_STD,
|
| 42 |
+
IMAGENET_STANDARD_MEAN,
|
| 43 |
+
IMAGENET_STANDARD_STD,
|
| 44 |
+
OPENAI_CLIP_MEAN,
|
| 45 |
+
OPENAI_CLIP_STD,
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
if is_vision_available():
    import PIL.Image
    import PIL.ImageOps

    # Pillow 9.1.0 moved the resampling filters into the `PIL.Image.Resampling` enum;
    # older versions expose them directly on the `PIL.Image` module.
    if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
        PILImageResampling = PIL.Image.Resampling
    else:
        PILImageResampling = PIL.Image

if is_torchvision_available():
    from torchvision.transforms import InterpolationMode

    # Maps each PIL resampling filter to its torchvision interpolation equivalent.
    pil_torch_interpolation_mapping = {
        PILImageResampling.NEAREST: InterpolationMode.NEAREST,
        PILImageResampling.BOX: InterpolationMode.BOX,
        PILImageResampling.BILINEAR: InterpolationMode.BILINEAR,
        PILImageResampling.HAMMING: InterpolationMode.HAMMING,
        PILImageResampling.BICUBIC: InterpolationMode.BICUBIC,
        PILImageResampling.LANCZOS: InterpolationMode.LANCZOS,
    }


# `torch` is only needed for type annotations; import it lazily for type checkers.
if TYPE_CHECKING:
    if is_torch_available():
        import torch


logger = logging.get_logger(__name__)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# Anything accepted as a single image or a flat batch of images.
ImageInput = Union[
    "PIL.Image.Image", np.ndarray, "torch.Tensor", List["PIL.Image.Image"], List[np.ndarray], List["torch.Tensor"]
]  # noqa


# A video (sequence of frames) or a batch of videos.
VideoInput = Union[
    List["PIL.Image.Image"],
    "np.ndarray",
    "torch.Tensor",
    List["np.ndarray"],
    List["torch.Tensor"],
    List[List["PIL.Image.Image"]],
    List[List["np.ndarray"]],  # fixed typo: forward ref was "np.ndarrray"
    List[List["torch.Tensor"]],
]  # noqa
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class ChannelDimension(ExplicitEnum):
    # Channel axis first: (num_channels, height, width).
    FIRST = "channels_first"
    # Channel axis last: (height, width, num_channels).
    LAST = "channels_last"
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class AnnotationFormat(ExplicitEnum):
    # Supported annotation schemas for object-detection style processors.
    COCO_DETECTION = "coco_detection"
    COCO_PANOPTIC = "coco_panoptic"
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class AnnotionFormat(ExplicitEnum):
    # NOTE: misspelled alias of `AnnotationFormat`; the name is kept as-is for backward compatibility.
    COCO_DETECTION = AnnotationFormat.COCO_DETECTION.value
    COCO_PANOPTIC = AnnotationFormat.COCO_PANOPTIC.value
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
# A single annotation dict, e.g. COCO-style {"image_id": ..., "annotations": [...]}.
AnnotationType = Dict[str, Union[int, str, List[Dict]]]
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def is_pil_image(img):
    """Return True if `img` is a `PIL.Image.Image` (requires the vision backend to be installed)."""
    # Availability is checked first so `PIL` is only referenced when it is importable.
    return is_vision_available() and isinstance(img, PIL.Image.Image)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class ImageType(ExplicitEnum):
    # The framework/library an image object originates from.
    PIL = "pillow"
    TORCH = "torch"
    NUMPY = "numpy"
    TENSORFLOW = "tensorflow"
    JAX = "jax"
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def get_image_type(image):
    """Classify `image` as one of the `ImageType` members, raising `ValueError` for unsupported types."""
    # Check order matters only for exotic objects that satisfy several predicates; it
    # mirrors the historical order: PIL, torch, numpy, tensorflow, jax.
    for predicate, image_type in (
        (is_pil_image, ImageType.PIL),
        (is_torch_tensor, ImageType.TORCH),
        (is_numpy_array, ImageType.NUMPY),
        (is_tf_tensor, ImageType.TENSORFLOW),
        (is_jax_tensor, ImageType.JAX),
    ):
        if predicate(image):
            return image_type
    raise ValueError(f"Unrecognised image type {type(image)}")
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def is_valid_image(img):
    """Return True if `img` is any supported single-image type (PIL, numpy, torch, TF or JAX)."""
    return any(
        check(img) for check in (is_pil_image, is_numpy_array, is_torch_tensor, is_tf_tensor, is_jax_tensor)
    )
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def valid_images(imgs):
    """Recursively validate `imgs`: a single image, a batched tensor, or an arbitrarily nested list/tuple of images."""
    if isinstance(imgs, (list, tuple)):
        # Every element of a (possibly nested) sequence must itself be valid.
        return all(valid_images(img) for img in imgs)
    # Not a sequence: a single image or batched tensor of images.
    return is_valid_image(imgs)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def is_batched(img):
    """Return True if `img` is a list/tuple whose first element is a valid image (i.e. a batch)."""
    # NOTE: like the original, this indexes img[0] and so assumes a non-empty sequence.
    return isinstance(img, (list, tuple)) and is_valid_image(img[0])
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def is_scaled_image(image: np.ndarray) -> bool:
    """
    Checks to see whether the pixel values have already been rescaled to [0, 1].
    """
    # uint8 arrays are raw [0, 255] pixel data by definition.
    if image.dtype == np.uint8:
        return False

    # It's possible the image has pixel values in [0, 255] but is of floating type
    lo, hi = np.min(image), np.max(image)
    return lo >= 0 and hi <= 1
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def make_list_of_images(images, expected_ndims: int = 3) -> List[ImageInput]:
    """
    Ensure that the input is a list of images. If the input is a single image, it is converted to a list of length 1.
    If the input is a batch of images, it is converted to a list of images.

    Args:
        images (`ImageInput`):
            Image of images to turn into a list of images.
        expected_ndims (`int`, *optional*, defaults to 3):
            Expected number of dimensions for a single input image. If the input image has a different number of
            dimensions, an error is raised.
    """
    # Already a list/tuple of images: return unchanged.
    if is_batched(images):
        return images

    # A PIL image is always a single image (PIL has no batch concept).
    if isinstance(images, PIL.Image.Image):
        return [images]

    if not is_valid_image(images):
        raise ValueError(
            "Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or "
            f"jax.ndarray, but got {type(images)}."
        )

    if images.ndim == expected_ndims + 1:
        # Batched array/tensor: split along the leading (batch) axis.
        return list(images)
    if images.ndim == expected_ndims:
        # Single unbatched image.
        return [images]
    raise ValueError(
        f"Invalid image shape. Expected either {expected_ndims + 1} or {expected_ndims} dimensions, but got"
        f" {images.ndim} dimensions."
    )
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def to_numpy_array(img) -> np.ndarray:
    """Convert any supported image type (PIL, numpy, torch, TF, JAX) to a numpy array."""
    if not is_valid_image(img):
        raise ValueError(f"Invalid image type: {type(img)}")

    # PIL images need `np.array`; the framework tensors go through the generic converter.
    if is_vision_available() and isinstance(img, PIL.Image.Image):
        return np.array(img)
    return to_numpy(img)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def infer_channel_dimension_format(
    image: np.ndarray, num_channels: Optional[Union[int, Tuple[int, ...]]] = None
) -> ChannelDimension:
    """
    Infers the channel dimension format of `image`.

    Args:
        image (`np.ndarray`):
            The image to infer the channel dimension of.
        num_channels (`int` or `Tuple[int, ...]`, *optional*, defaults to `(1, 3)`):
            The number of channels of the image.

    Returns:
        The channel dimension of the image.
    """
    if num_channels is None:
        num_channels = (1, 3)
    elif isinstance(num_channels, int):
        num_channels = (num_channels,)

    # 3D arrays are (C,H,W)/(H,W,C); 4D arrays carry a leading batch axis.
    if image.ndim == 3:
        first_dim, last_dim = 0, 2
    elif image.ndim == 4:
        first_dim, last_dim = 1, 3
    else:
        raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")

    first_matches = image.shape[first_dim] in num_channels
    last_matches = image.shape[last_dim] in num_channels

    if first_matches and last_matches:
        # e.g. a 3x3 image — both axes look like channels, so fall back to channels-first.
        logger.warning(
            f"The channel dimension is ambiguous. Got image shape {image.shape}. Assuming channels are the first dimension."
        )
        return ChannelDimension.FIRST
    if first_matches:
        return ChannelDimension.FIRST
    if last_matches:
        return ChannelDimension.LAST
    raise ValueError("Unable to infer channel dimension format")
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def get_channel_dimension_axis(
    image: np.ndarray, input_data_format: Optional[Union[ChannelDimension, str]] = None
) -> int:
    """
    Returns the channel dimension axis of the image.

    Args:
        image (`np.ndarray`):
            The image to get the channel dimension axis of.
        input_data_format (`ChannelDimension` or `str`, *optional*):
            The channel dimension format of the image. If `None`, will infer the channel dimension from the image.

    Returns:
        The channel dimension axis of the image.
    """
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)
    # Expressed relative to ndim so the same formula covers batched (4D) and unbatched (3D) images.
    if input_data_format == ChannelDimension.FIRST:
        return image.ndim - 3
    if input_data_format == ChannelDimension.LAST:
        return image.ndim - 1
    raise ValueError(f"Unsupported data format: {input_data_format}")
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def get_image_size(image: np.ndarray, channel_dim: ChannelDimension = None) -> Tuple[int, int]:
    """
    Returns the (height, width) dimensions of the image.

    Args:
        image (`np.ndarray`):
            The image to get the dimensions of.
        channel_dim (`ChannelDimension`, *optional*):
            Which dimension the channel dimension is in. If `None`, will infer the channel dimension from the image.

    Returns:
        A tuple of the image's height and width.
    """
    channel_dim = infer_channel_dimension_format(image) if channel_dim is None else channel_dim

    # Indexing from the end handles both batched and unbatched arrays.
    if channel_dim == ChannelDimension.FIRST:
        return image.shape[-2], image.shape[-1]
    if channel_dim == ChannelDimension.LAST:
        return image.shape[-3], image.shape[-2]
    raise ValueError(f"Unsupported data format: {channel_dim}")
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def is_valid_annotation_coco_detection(annotation: Dict[str, Union[List, Tuple]]) -> bool:
    """Return True if `annotation` is a well-formed COCO-detection annotation dict."""
    if not isinstance(annotation, dict):
        return False
    if "image_id" not in annotation or "annotations" not in annotation:
        return False
    anns = annotation["annotations"]
    if not isinstance(anns, (list, tuple)):
        return False
    # An image can have no annotations; otherwise the entries must be dicts
    # (only the first entry is inspected, matching the historical behavior).
    return len(anns) == 0 or isinstance(anns[0], dict)
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def is_valid_annotation_coco_panoptic(annotation: Dict[str, Union[List, Tuple]]) -> bool:
    """Return True if `annotation` is a well-formed COCO-panoptic annotation dict."""
    if not isinstance(annotation, dict):
        return False
    if any(key not in annotation for key in ("image_id", "segments_info", "file_name")):
        return False
    segments = annotation["segments_info"]
    if not isinstance(segments, (list, tuple)):
        return False
    # An image can have no segments; otherwise the entries must be dicts
    # (only the first entry is inspected, matching the historical behavior).
    return len(segments) == 0 or isinstance(segments[0], dict)
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def valid_coco_detection_annotations(annotations: Iterable[Dict[str, Union[List, Tuple]]]) -> bool:
    """Return True only if every entry in `annotations` is a valid COCO-detection annotation."""
    for annotation in annotations:
        if not is_valid_annotation_coco_detection(annotation):
            return False
    return True
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def valid_coco_panoptic_annotations(annotations: Iterable[Dict[str, Union[List, Tuple]]]) -> bool:
    """Return True only if every entry in `annotations` is a valid COCO-panoptic annotation."""
    for annotation in annotations:
        if not is_valid_annotation_coco_panoptic(annotation):
            return False
    return True
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] = None) -> "PIL.Image.Image":
    """
    Loads `image` to a PIL Image.

    Args:
        image (`str` or `PIL.Image.Image`):
            The image to convert to the PIL Image format.
        timeout (`float`, *optional*):
            The timeout value in seconds for the URL request.

    Returns:
        `PIL.Image.Image`: A PIL Image.
    """
    requires_backends(load_image, ["vision"])
    if isinstance(image, str):
        if image.startswith("http://") or image.startswith("https://"):
            # We need to actually check for a real protocol, otherwise it's impossible to use a local file
            # like http_huggingface_co.png
            image = PIL.Image.open(BytesIO(requests.get(image, timeout=timeout).content))
        elif os.path.isfile(image):
            image = PIL.Image.open(image)
        else:
            # Strip a data-URI prefix ("data:image/...;base64,") so only the payload remains.
            if image.startswith("data:image/"):
                image = image.split(",")[1]

            # Try to load as base64
            try:
                b64 = base64.decodebytes(image.encode())
                image = PIL.Image.open(BytesIO(b64))
            except Exception as e:
                raise ValueError(
                    f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path to an image file, or a base64 encoded string. Got {image}. Failed with {e}"
                )
    elif isinstance(image, PIL.Image.Image):
        image = image
    else:
        raise TypeError(
            "Incorrect format used for image. Should be an url linking to an image, a base64 string, a local path, or a PIL image."
        )
    # Apply the EXIF orientation tag, then normalize to RGB before returning.
    image = PIL.ImageOps.exif_transpose(image)
    image = image.convert("RGB")
    return image
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def load_images(
    images: Union[List, Tuple, str, "PIL.Image.Image"], timeout: Optional[float] = None
) -> Union["PIL.Image.Image", List["PIL.Image.Image"], List[List["PIL.Image.Image"]]]:
    """Loads images, handling different levels of nesting.

    Args:
        images: A single image, a list of images, or a list of lists of images to load.
        timeout: Timeout for loading images.

    Returns:
        A single image, a list of images, a list of lists of images.
    """
    # Not a sequence at all: load the single image directly.
    if not isinstance(images, (list, tuple)):
        return load_image(images, timeout=timeout)
    # Doubly-nested sequence: load each group independently.
    if len(images) and isinstance(images[0], (list, tuple)):
        return [[load_image(image, timeout=timeout) for image in group] for group in images]
    return [load_image(image, timeout=timeout) for image in images]
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
def validate_preprocess_arguments(
    do_rescale: Optional[bool] = None,
    rescale_factor: Optional[float] = None,
    do_normalize: Optional[bool] = None,
    image_mean: Optional[Union[float, List[float]]] = None,
    image_std: Optional[Union[float, List[float]]] = None,
    do_pad: Optional[bool] = None,
    size_divisibility: Optional[int] = None,
    do_center_crop: Optional[bool] = None,
    crop_size: Optional[Dict[str, int]] = None,
    do_resize: Optional[bool] = None,
    size: Optional[Dict[str, int]] = None,
    resample: Optional["PILImageResampling"] = None,
):
    """
    Checks validity of typically used arguments in an `ImageProcessor` `preprocess` method.
    Raises `ValueError` if arguments incompatibility is caught.
    Many incompatibilities are model-specific. `do_pad` sometimes needs `size_divisor`,
    sometimes `size_divisibility`, and sometimes `size`. New models and processors added should follow
    existing arguments when possible.

    """
    # Each entry pairs a failed-precondition flag with its error message; checks run
    # in the same order as the historical if-chain so the first failure raised is identical.
    checks = (
        (do_rescale and rescale_factor is None, "`rescale_factor` must be specified if `do_rescale` is `True`."),
        (
            # Here, size_divisor might be passed as the value of size
            do_pad and size_divisibility is None,
            "Depending on the model, `size_divisibility`, `size_divisor`, `pad_size` or `size` must be specified if `do_pad` is `True`.",
        ),
        (
            do_normalize and (image_mean is None or image_std is None),
            "`image_mean` and `image_std` must both be specified if `do_normalize` is `True`.",
        ),
        (do_center_crop and crop_size is None, "`crop_size` must be specified if `do_center_crop` is `True`."),
        (
            do_resize and (size is None or resample is None),
            "`size` and `resample` must be specified if `do_resize` is `True`.",
        ),
    )
    for failed, message in checks:
        if failed:
            raise ValueError(message)
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
def validate_fast_preprocess_arguments(
    do_rescale: Optional[bool] = None,
    rescale_factor: Optional[float] = None,
    do_normalize: Optional[bool] = None,
    image_mean: Optional[Union[float, List[float]]] = None,
    image_std: Optional[Union[float, List[float]]] = None,
    do_pad: Optional[bool] = None,
    size_divisibility: Optional[int] = None,
    do_center_crop: Optional[bool] = None,
    crop_size: Optional[Dict[str, int]] = None,
    do_resize: Optional[bool] = None,
    size: Optional[Dict[str, int]] = None,
    resample: Optional["PILImageResampling"] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
):
    """
    Checks validity of typically used arguments in an `ImageProcessorFast` `preprocess` method.
    Raises `ValueError` if arguments incompatibility is caught.
    """
    # Delegate the shared checks; note that `do_pad`/`size_divisibility` and
    # `do_center_crop`/`crop_size` are accepted here but NOT forwarded below.
    validate_preprocess_arguments(
        do_rescale=do_rescale,
        rescale_factor=rescale_factor,
        do_normalize=do_normalize,
        image_mean=image_mean,
        image_std=image_std,
        do_resize=do_resize,
        size=size,
        resample=resample,
    )
    # Extra checks for ImageProcessorFast
    if return_tensors != "pt":
        raise ValueError("Only returning PyTorch tensors is currently supported.")

    if data_format != ChannelDimension.FIRST:
        raise ValueError("Only channel first data format is currently supported.")
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
# In the future we can add a TF implementation here when we have TF models.
|
| 490 |
+
class ImageFeatureExtractionMixin:
|
| 491 |
+
"""
|
| 492 |
+
Mixin that contain utilities for preparing image features.
|
| 493 |
+
"""
|
| 494 |
+
|
| 495 |
+
def _ensure_format_supported(self, image):
    """Raise `ValueError` unless `image` is a PIL image, numpy array or torch tensor."""
    if not isinstance(image, (PIL.Image.Image, np.ndarray)) and not is_torch_tensor(image):
        raise ValueError(
            f"Got type {type(image)} which is not supported, only `PIL.Image.Image`, `np.array` and "
            "`torch.Tensor` are."
        )
|
| 501 |
+
|
| 502 |
+
def to_pil_image(self, image, rescale=None):
    """
    Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
    needed.

    Args:
        image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor`):
            The image to convert to the PIL Image format.
        rescale (`bool`, *optional*):
            Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will
            default to `True` if the image type is a floating type, `False` otherwise.
    """
    self._ensure_format_supported(image)

    # Torch tensors are converted through numpy; a PIL input falls through unchanged at the end.
    if is_torch_tensor(image):
        image = image.numpy()

    if isinstance(image, np.ndarray):
        if rescale is None:
            # rescale default to the array being of floating type.
            rescale = isinstance(image.flat[0], np.floating)
        # If the channel as been moved to first dim, we put it back at the end.
        if image.ndim == 3 and image.shape[0] in [1, 3]:
            image = image.transpose(1, 2, 0)
        if rescale:
            # assumes float pixel values are in [0, 1] so *255 maps into uint8 range — TODO confirm
            image = image * 255
        image = image.astype(np.uint8)
        return PIL.Image.fromarray(image)
    return image
|
| 531 |
+
|
| 532 |
+
def convert_rgb(self, image):
    """
    Converts `PIL.Image.Image` to RGB format.

    Args:
        image (`PIL.Image.Image`):
            The image to convert.
    """
    self._ensure_format_supported(image)
    # Non-PIL inputs (numpy arrays, torch tensors) pass through untouched.
    if isinstance(image, PIL.Image.Image):
        return image.convert("RGB")
    return image
|
| 545 |
+
|
| 546 |
+
def rescale(self, image: np.ndarray, scale: Union[float, int]) -> np.ndarray:
    """
    Rescale a numpy image by scale amount
    """
    self._ensure_format_supported(image)
    scaled = image * scale
    return scaled
|
| 552 |
+
|
| 553 |
+
def to_numpy_array(self, image, rescale=None, channel_first=True):
    """
    Converts `image` to a numpy array. Optionally rescales it and puts the channel dimension as the first
    dimension.

    Args:
        image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
            The image to convert to a NumPy array.
        rescale (`bool`, *optional*):
            Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.). Will
            default to `True` if the image is a PIL Image or an array/tensor of integers, `False` otherwise.
        channel_first (`bool`, *optional*, defaults to `True`):
            Whether or not to permute the dimensions of the image to put the channel dimension first.
    """
    self._ensure_format_supported(image)

    # Normalize the input to a NumPy array first.
    if isinstance(image, PIL.Image.Image):
        image = np.array(image)
    elif is_torch_tensor(image):
        image = image.numpy()

    # By default, integer images are assumed to be 0-255 and get rescaled to 0-1 floats.
    if rescale is None:
        rescale = isinstance(image.flat[0], np.integer)

    if rescale:
        image = self.rescale(image.astype(np.float32), 1 / 255.0)

    # Move the channel axis to the front for 3D (H, W, C) data when requested.
    if channel_first and image.ndim == 3:
        image = image.transpose(2, 0, 1)

    return image
|
| 584 |
+
|
| 585 |
+
def expand_dims(self, image):
    """
    Expands 2-dimensional `image` to 3 dimensions.

    Args:
        image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
            The image to expand.
    """
    self._ensure_format_supported(image)

    # PIL images are returned untouched.
    if isinstance(image, PIL.Image.Image):
        return image

    # Tensors and arrays gain a leading axis.
    if is_torch_tensor(image):
        return image.unsqueeze(0)
    return np.expand_dims(image, axis=0)
|
| 604 |
+
|
| 605 |
+
def normalize(self, image, mean, std, rescale=False):
    """
    Normalizes `image` with `mean` and `std`. Note that this will trigger a conversion of `image` to a NumPy array
    if it's a PIL Image.

    Args:
        image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
            The image to normalize.
        mean (`List[float]` or `np.ndarray` or `torch.Tensor`):
            The mean (per channel) to use for normalization.
        std (`List[float]` or `np.ndarray` or `torch.Tensor`):
            The standard deviation (per channel) to use for normalization.
        rescale (`bool`, *optional*, defaults to `False`):
            Whether or not to rescale the image to be between 0 and 1. If a PIL image is provided, scaling will
            happen automatically.
    """
    self._ensure_format_supported(image)

    if isinstance(image, PIL.Image.Image):
        # PIL inputs are always converted to a rescaled float NumPy array.
        image = self.to_numpy_array(image, rescale=True)
    # If the input image is a PIL image, it automatically gets rescaled. If it's another
    # type it may need rescaling.
    elif rescale:
        if isinstance(image, np.ndarray):
            image = self.rescale(image.astype(np.float32), 1 / 255.0)
        elif is_torch_tensor(image):
            image = self.rescale(image.float(), 1 / 255.0)

    # Convert mean/std to the same container type as the image so arithmetic broadcasts cleanly.
    if isinstance(image, np.ndarray):
        mean = mean if isinstance(mean, np.ndarray) else np.array(mean).astype(image.dtype)
        std = std if isinstance(std, np.ndarray) else np.array(std).astype(image.dtype)
    elif is_torch_tensor(image):
        import torch

        if not isinstance(mean, torch.Tensor):
            mean = torch.from_numpy(mean) if isinstance(mean, np.ndarray) else torch.tensor(mean)
        if not isinstance(std, torch.Tensor):
            std = torch.from_numpy(std) if isinstance(std, np.ndarray) else torch.tensor(std)

    # Channel-first 3D images broadcast mean/std over the two spatial dimensions.
    if image.ndim == 3 and image.shape[0] in [1, 3]:
        return (image - mean[:, None, None]) / std[:, None, None]
    return (image - mean) / std
|
| 656 |
+
|
| 657 |
+
def resize(self, image, size, resample=None, default_to_square=True, max_size=None):
    """
    Resizes `image`. Enforces conversion of input to PIL.Image.

    Args:
        image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
            The image to resize.
        size (`int` or `Tuple[int, int]`):
            The size to use for resizing the image. If `size` is a sequence like (h, w), output size will be
            matched to this.

            If `size` is an int and `default_to_square` is `True`, then image will be resized to (size, size). If
            `size` is an int and `default_to_square` is `False`, then smaller edge of the image will be matched to
            this number. i.e, if height > width, then image will be rescaled to (size * height / width, size).
        resample (`int`, *optional*, defaults to `PILImageResampling.BILINEAR`):
            The filter to user for resampling.
        default_to_square (`bool`, *optional*, defaults to `True`):
            How to convert `size` when it is a single int. If set to `True`, the `size` will be converted to a
            square (`size`,`size`). If set to `False`, will replicate
            [`torchvision.transforms.Resize`](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Resize)
            with support for resizing only the smallest edge and providing an optional `max_size`.
        max_size (`int`, *optional*, defaults to `None`):
            The maximum allowed for the longer edge of the resized image: if the longer edge of the image is
            greater than `max_size` after being resized according to `size`, then the image is resized again so
            that the longer edge is equal to `max_size`. As a result, `size` might be overruled, i.e the smaller
            edge may be shorter than `size`. Only used if `default_to_square` is `False`.

    Returns:
        image: A resized `PIL.Image.Image`.
    """
    resample = PILImageResampling.BILINEAR if resample is None else resample

    self._ensure_format_supported(image)

    if not isinstance(image, PIL.Image.Image):
        image = self.to_pil_image(image)

    if isinstance(size, list):
        size = tuple(size)

    if isinstance(size, int) or len(size) == 1:
        if default_to_square:
            # A single int (or 1-tuple) becomes a square target size.
            edge = size if isinstance(size, int) else size[0]
            size = (edge, edge)
        else:
            width, height = image.size
            # specified size only for the smallest edge
            short, long = (width, height) if width <= height else (height, width)
            requested_new_short = size if isinstance(size, int) else size[0]

            # Already at the requested size: nothing to do.
            if short == requested_new_short:
                return image

            new_short = requested_new_short
            new_long = int(requested_new_short * long / short)

            if max_size is not None:
                if max_size <= requested_new_short:
                    raise ValueError(
                        f"max_size = {max_size} must be strictly greater than the requested "
                        f"size for the smaller edge size = {size}"
                    )
                # Shrink both edges so the long one does not exceed max_size.
                if new_long > max_size:
                    new_short, new_long = int(max_size * new_short / new_long), max_size

            size = (new_short, new_long) if width <= height else (new_long, new_short)

    return image.resize(size, resample=resample)
|
| 723 |
+
|
| 724 |
+
def center_crop(self, image, size):
    """
    Crops `image` to the given size using a center crop. Note that if the image is too small to be cropped to the
    size given, it will be padded (so the returned result has the size asked).

    Args:
        image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape (n_channels, height, width) or (height, width, n_channels)):
            The image to resize.
        size (`int` or `Tuple[int, int]`):
            The size to which crop the image.

    Returns:
        new_image: A center cropped `PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape: (n_channels,
        height, width).
    """
    self._ensure_format_supported(image)

    # A single int means a square crop.
    if not isinstance(size, tuple):
        size = (size, size)

    # PIL Image.size is (width, height) but NumPy array and torch Tensors have (height, width)
    if is_torch_tensor(image) or isinstance(image, np.ndarray):
        if image.ndim == 2:
            image = self.expand_dims(image)
        # A leading dim of 1 or 3 is assumed to be the channel dim — TODO confirm for exotic channel counts.
        image_shape = image.shape[1:] if image.shape[0] in [1, 3] else image.shape[:2]
    else:
        image_shape = (image.size[1], image.size[0])

    # Crop box in (height, width) coordinates; may fall outside the image for small inputs.
    top = (image_shape[0] - size[0]) // 2
    bottom = top + size[0]  # In case size is odd, (image_shape[0] + size[0]) // 2 won't give the proper result.
    left = (image_shape[1] - size[1]) // 2
    right = left + size[1]  # In case size is odd, (image_shape[1] + size[1]) // 2 won't give the proper result.

    # For PIL Images we have a method to crop directly.
    if isinstance(image, PIL.Image.Image):
        return image.crop((left, top, right, bottom))

    # Check if image is in (n_channels, height, width) or (height, width, n_channels) format
    channel_first = True if image.shape[0] in [1, 3] else False

    # Transpose (height, width, n_channels) format images
    if not channel_first:
        if isinstance(image, np.ndarray):
            image = image.transpose(2, 0, 1)
        if is_torch_tensor(image):
            image = image.permute(2, 0, 1)

    # Check if cropped area is within image boundaries
    if top >= 0 and bottom <= image_shape[0] and left >= 0 and right <= image_shape[1]:
        return image[..., top:bottom, left:right]

    # Otherwise, we may need to pad if the image is too small. Oh joy...
    new_shape = image.shape[:-2] + (max(size[0], image_shape[0]), max(size[1], image_shape[1]))
    if isinstance(image, np.ndarray):
        new_image = np.zeros_like(image, shape=new_shape)
    elif is_torch_tensor(image):
        new_image = image.new_zeros(new_shape)

    # Paste the original image centered in the (possibly larger) zero canvas.
    top_pad = (new_shape[-2] - image_shape[0]) // 2
    bottom_pad = top_pad + image_shape[0]
    left_pad = (new_shape[-1] - image_shape[1]) // 2
    right_pad = left_pad + image_shape[1]
    new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image

    # Shift the crop box into the padded canvas coordinates.
    top += top_pad
    bottom += top_pad
    left += left_pad
    right += left_pad

    new_image = new_image[
        ..., max(0, top) : min(new_image.shape[-2], bottom), max(0, left) : min(new_image.shape[-1], right)
    ]

    return new_image
|
| 798 |
+
|
| 799 |
+
def flip_channel_order(self, image):
    """
    Flips the channel order of `image` from RGB to BGR, or vice versa. Note that this will trigger a conversion of
    `image` to a NumPy array if it's a PIL Image.

    Args:
        image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
            The image whose color channels to flip. If `np.ndarray` or `torch.Tensor`, the channel dimension should
            be first.
    """
    self._ensure_format_supported(image)

    if isinstance(image, PIL.Image.Image):
        image = self.to_numpy_array(image)

    # Reversing the leading (channel) axis swaps RGB <-> BGR.
    flipped = image[::-1, :, :]
    return flipped
|
| 815 |
+
|
| 816 |
+
def rotate(self, image, angle, resample=None, expand=0, center=None, translate=None, fillcolor=None):
    """
    Returns a rotated copy of `image`. This method returns a copy of `image`, rotated the given number of degrees
    counter clockwise around its centre.

    Args:
        image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
            The image to rotate. If `np.ndarray` or `torch.Tensor`, will be converted to `PIL.Image.Image` before
            rotating.

    Returns:
        image: A rotated `PIL.Image.Image`.
    """
    if resample is None:
        resample = PIL.Image.NEAREST

    self._ensure_format_supported(image)

    # Rotation is delegated to PIL, so coerce arrays/tensors first.
    if not isinstance(image, PIL.Image.Image):
        image = self.to_pil_image(image)

    return image.rotate(
        angle, resample=resample, expand=expand, center=center, translate=translate, fillcolor=fillcolor
    )
|
| 839 |
+
|
| 840 |
+
|
| 841 |
+
def validate_annotations(
    annotation_format: AnnotationFormat,
    supported_annotation_formats: Tuple[AnnotationFormat, ...],
    annotations: List[Dict],
) -> None:
    """
    Check that `annotations` match `annotation_format` and that the format is supported.

    Args:
        annotation_format (`AnnotationFormat`):
            The declared format of the annotations.
        supported_annotation_formats (`Tuple[AnnotationFormat, ...]`):
            The formats accepted by the caller.
        annotations (`List[Dict]`):
            The annotations to validate.

    Raises:
        ValueError: If the format is unsupported or the annotations do not match the declared format.
    """
    if annotation_format not in supported_annotation_formats:
        # Fix: the message previously interpolated `{format}` — the Python builtin — instead of the
        # `annotation_format` argument, producing "<built-in function format>" in the error text.
        raise ValueError(
            f"Unsupported annotation format: {annotation_format} must be one of {supported_annotation_formats}"
        )

    if annotation_format is AnnotationFormat.COCO_DETECTION:
        if not valid_coco_detection_annotations(annotations):
            raise ValueError(
                "Invalid COCO detection annotations. Annotations must a dict (single image) or list of dicts "
                "(batch of images) with the following keys: `image_id` and `annotations`, with the latter "
                "being a list of annotations in the COCO format."
            )

    if annotation_format is AnnotationFormat.COCO_PANOPTIC:
        if not valid_coco_panoptic_annotations(annotations):
            raise ValueError(
                "Invalid COCO panoptic annotations. Annotations must a dict (single image) or list of dicts "
                "(batch of images) with the following keys: `image_id`, `file_name` and `segments_info`, with "
                "the latter being a list of annotations in the COCO format."
            )
|
| 864 |
+
|
| 865 |
+
|
| 866 |
+
def validate_kwargs(valid_processor_keys: List[str], captured_kwargs: List[str]):
    """
    Log a warning for any captured keyword argument that the processor does not recognize.

    Args:
        valid_processor_keys (`List[str]`): The kwarg names the processor accepts.
        captured_kwargs (`List[str]`): The kwarg names that were actually passed in.
    """
    unused_keys = set(captured_kwargs) - set(valid_processor_keys)
    if not unused_keys:
        return
    unused_key_str = ", ".join(unused_keys)
    # TODO raise a warning here instead of simply logging?
    logger.warning(f"Unused or unrecognized kwargs: {unused_key_str}.")
|
.venv/lib/python3.11/site-packages/transformers/keras_callbacks.py
ADDED
|
@@ -0,0 +1,413 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from time import sleep
|
| 5 |
+
from typing import Callable, List, Optional, Union
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import tensorflow as tf
|
| 9 |
+
from huggingface_hub import Repository, create_repo
|
| 10 |
+
from packaging.version import parse
|
| 11 |
+
|
| 12 |
+
from . import IntervalStrategy, PreTrainedTokenizerBase
|
| 13 |
+
from .modelcard import TrainingSummary
|
| 14 |
+
from .modeling_tf_utils import keras
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class KerasMetricCallback(keras.callbacks.Callback):
    """
    Callback to compute metrics at the end of every epoch. Unlike normal Keras metrics, these do not need to be
    compilable by TF. It is particularly useful for common NLP metrics like BLEU and ROUGE that require string
    operations or generation loops that cannot be compiled. Predictions (or generations) will be computed on the
    `eval_dataset` before being passed to the `metric_fn` in `np.ndarray` format. The `metric_fn` should compute
    metrics and return a dict mapping metric names to metric values.

    We provide an example of a suitable metric_fn that computes ROUGE scores for a summarization model below. Note that
    this example skips some post-processing for readability and simplicity, and should probably not be used as-is!

    ```py
    from datasets import load_metric

    rouge_metric = load_metric("rouge")


    def rouge_fn(predictions, labels):
        decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
        result = rouge_metric.compute(predictions=decoded_predictions, references=decoded_labels)
        return {key: value.mid.fmeasure * 100 for key, value in result.items()}
    ```

    The above function will return a dict containing values which will be logged like any other Keras metric:

    ```
    {'rouge1': 37.4199, 'rouge2': 13.9768, 'rougeL': 34.361, 'rougeLsum': 35.0781}
    ```

    Args:
        metric_fn (`Callable`):
            Metric function provided by the user. It will be called with two arguments - `predictions` and `labels`.
            These contain the model's outputs and matching labels from the dataset. It should return a dict mapping
            metric names to numerical values.
        eval_dataset (`tf.data.Dataset` or `dict` or `tuple` or `np.ndarray` or `tf.Tensor`):
            Validation data to be used to generate predictions for the `metric_fn`.
        output_cols (`List[str]`, *optional*):
            A list of columns to be retained from the model output as the predictions. Defaults to all.
        label_cols (`List[str]`, *optional*):
            A list of columns to be retained from the input dataset as the labels. Will be autodetected if this is not
            supplied.
        batch_size (`int`, *optional*):
            Batch size. Only used when the data is not a pre-batched `tf.data.Dataset`.
        predict_with_generate (`bool`, *optional*, defaults to `False`):
            Whether we should use `model.generate()` to get outputs for the model.
        use_xla_generation (`bool`, *optional*, defaults to `False`):
            If we're generating, whether to compile model generation with XLA. This can massively increase the speed of
            generation (up to 100X speedup) but will require a new XLA compilation for each input shape. When using XLA
            generation, it's a good idea to pad your inputs to the same size, or to use the `pad_to_multiple_of`
            argument in your `tokenizer` or `DataCollator`, which will reduce the number of unique input shapes and
            save a lot of compilation time. This option has no effect if `predict_with_generate` is `False`.
        generate_kwargs (`dict`, *optional*):
            Keyword arguments to pass to `model.generate()` when generating. Has no effect if `predict_with_generate`
            is `False`.

    """

    def __init__(
        self,
        metric_fn: Callable,
        eval_dataset: Union[tf.data.Dataset, np.ndarray, tf.Tensor, tuple, dict],
        output_cols: Optional[List[str]] = None,
        label_cols: Optional[List[str]] = None,
        batch_size: Optional[int] = None,
        predict_with_generate: bool = False,
        use_xla_generation: bool = False,
        generate_kwargs: Optional[dict] = None,
    ):
        super().__init__()
        self.metric_fn = metric_fn
        self.batch_size = batch_size
        if not isinstance(eval_dataset, tf.data.Dataset):
            if batch_size is None:
                raise ValueError(
                    "When passing data to KerasMetricCallback that is not a pre-batched tf.data.Dataset "
                    "the batch_size argument must be set."
                )
            # Wrap a tf.data.Dataset around it
            eval_dataset = tf.data.Dataset.from_tensor_slices(eval_dataset).batch(batch_size, drop_remainder=False)
        self.eval_dataset = eval_dataset
        self.predict_with_generate = predict_with_generate
        self.output_cols = output_cols

        # This next block attempts to parse out which elements of the dataset should be appended to the labels list
        # that is passed to the metric_fn
        if isinstance(eval_dataset.element_spec, tuple) and len(eval_dataset.element_spec) == 2:
            input_spec, label_spec = eval_dataset.element_spec
        else:
            input_spec = eval_dataset.element_spec
            label_spec = None
        if label_cols is not None:
            for label in label_cols:
                if label not in input_spec:
                    raise ValueError(f"Label {label} is in label_cols but could not be found in the dataset inputs!")
            self.label_cols = label_cols
            self.use_keras_label = False
        elif label_spec is not None:
            # If the dataset inputs are split into a 2-tuple of inputs and labels,
            # assume the second element is the labels
            self.label_cols = None
            self.use_keras_label = True
        elif "labels" in input_spec:
            self.label_cols = ["labels"]
            self.use_keras_label = False
            # Fix: use the module-level `logger` (consistent with the rest of the module) instead of the
            # root logger via `logging.warning`, so these messages respect the package logging configuration.
            logger.warning("No label_cols specified for KerasMetricCallback, assuming you want the 'labels' key.")
        elif "start_positions" in input_spec and "end_positions" in input_spec:
            self.label_cols = ["start_positions", "end_positions"]
            self.use_keras_label = False
            logger.warning(
                "No label_cols specified for KerasMetricCallback, assuming you want the "
                "start_positions and end_positions keys."
            )
        else:
            raise ValueError("Could not autodetect label_cols for KerasMetricCallback, please specify them!")
        if parse(tf.__version__) < parse("2.7"):
            logger.warning("TF versions less than 2.7 may encounter issues with KerasMetricCallback!")

        self.use_xla_generation = use_xla_generation
        self.generate_kwargs = {} if generate_kwargs is None else generate_kwargs

        # Lazily-built XLA-compiled generation function (see on_epoch_end).
        self.generation_function = None

    @staticmethod
    def _concatenate_batches(batches, padding_index=-100):
        """Concatenate per-batch arrays along axis 0, right-padding with `padding_index` when lengths differ."""
        # If all batches are unidimensional or same length, do a simple concatenation
        if batches[0].ndim == 1 or all(batch.shape[1] == batches[0].shape[1] for batch in batches):
            return np.concatenate(batches, axis=0)

        # Welp, they're not the same length. Let's do some padding
        max_len = max([batch.shape[1] for batch in batches])
        num_samples = sum([batch.shape[0] for batch in batches])
        output = np.full_like(
            batches[0], fill_value=padding_index, shape=[num_samples, max_len] + list(batches[0].shape[2:])
        )
        # i keeps track of which part of the concatenated array we're writing the next batch to
        i = 0
        for batch in batches:
            output[i : i + len(batch), : batch.shape[1]] = batch
            i += len(batch)
        return output

    def _postprocess_predictions_or_labels(self, inputs):
        """Merge a list of per-batch outputs (dicts, lists/tuples, arrays or tensors) into single arrays."""
        if isinstance(inputs[0], dict):
            outputs = {}
            for key in inputs[0].keys():
                outputs[key] = self._concatenate_batches([batch[key] for batch in inputs])
            # If it's a dict with only one key, just return the array
            if len(outputs) == 1:
                outputs = list(outputs.values())[0]
        elif isinstance(inputs[0], list) or isinstance(inputs[0], tuple):
            outputs = []
            for input_list in zip(*inputs):
                outputs.append(self._concatenate_batches(input_list))
            if len(outputs) == 1:
                outputs = outputs[0]  # If it's a list with only one element, just return the array
        elif isinstance(inputs[0], np.ndarray):
            outputs = self._concatenate_batches(inputs)
        elif isinstance(inputs[0], tf.Tensor):
            outputs = self._concatenate_batches([tensor.numpy() for tensor in inputs])
        else:
            raise TypeError(f"Couldn't handle batch of type {type(inputs[0])}!")
        return outputs

    def on_epoch_end(self, epoch, logs=None):
        """Run predict/generate over `eval_dataset`, compute `metric_fn` and write its outputs into `logs`."""
        if hasattr(self.model, "config"):
            ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
        else:
            ignore_keys = []

        main_input_name = None
        if self.predict_with_generate:
            # This dense conditional recognizes the case where we have an encoder-decoder model, but
            # avoids getting tangled up when we just have a model with a layer called 'encoder'
            if hasattr(self.model, "encoder") and hasattr(self.model.encoder, "main_input_name"):
                main_input_name = self.model.encoder.main_input_name
            else:
                main_input_name = getattr(self.model, "main_input_name", "input_ids")

            if self.use_xla_generation and self.generation_function is None:

                def generation_function(inputs, attention_mask):
                    return self.model.generate(inputs, attention_mask=attention_mask, **self.generate_kwargs)

                self.generation_function = tf.function(generation_function, jit_compile=True)

        prediction_list = []
        label_list = []

        # The whole predict/generate loop is handled inside this method
        for batch in self.eval_dataset:
            if isinstance(batch, tuple):
                batch, labels = batch
            else:
                labels = None
            if self.predict_with_generate:
                if isinstance(batch, dict):
                    generation_inputs = batch[main_input_name]
                    attention_mask = batch.get("attention_mask", None)
                else:
                    generation_inputs = batch
                    attention_mask = None
                if self.use_xla_generation:
                    predictions = self.generation_function(generation_inputs, attention_mask=attention_mask)
                else:
                    predictions = self.model.generate(
                        generation_inputs, attention_mask=attention_mask, **self.generate_kwargs
                    )
            else:
                predictions = self.model.predict_on_batch(batch)
            if isinstance(predictions, dict):
                # This converts any dict-subclass to a regular dict
                # Keras REALLY doesn't like it when we pass around a BatchEncoding or other derived class
                predictions = dict(predictions)
                if self.output_cols is not None:
                    predictions = {key: predictions[key] for key in self.output_cols}
                else:
                    predictions = {
                        key: val for key, val in predictions.items() if key not in ignore_keys + ["loss"]
                    }
            prediction_list.append(predictions)
            if not self.use_keras_label:
                labels = {key: batch[key].numpy() for key in self.label_cols}
            elif isinstance(labels, dict):
                labels = {key: array.numpy() for key, array in labels.items()}
            elif isinstance(labels, list) or isinstance(labels, tuple):
                labels = [array.numpy() for array in labels]
            elif isinstance(labels, tf.Tensor):
                labels = labels.numpy()
            else:
                raise TypeError(f"Confused by labels of type {type(labels)}")
            label_list.append(labels)

        all_preds = self._postprocess_predictions_or_labels(prediction_list)
        all_labels = self._postprocess_predictions_or_labels(label_list)

        metric_output = self.metric_fn((all_preds, all_labels))
        if not isinstance(metric_output, dict):
            raise TypeError(
                f"metric_fn should return a dict mapping metric names to values but instead returned {metric_output}"
            )
        # This is the critical bit - Keras passes a dict containing the loss and standard metric values for this epoch
        # in the logs argument. Ordinarily, this is so the callback can read them, but in this case we write a bunch of
        # new keys in there, which will then get read by the History callback and treated like any other metric value.
        # I promise that I have it in writing from Chollet that this is okay.
        logs.update(metric_output)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
class PushToHubCallback(keras.callbacks.Callback):
    """
    Callback that will save and push the model to the Hub regularly. By default, it pushes once per epoch, but this can
    be changed with the `save_strategy` argument. Pushed models can be accessed like any other model on the hub, such
    as with the `from_pretrained` method.

    ```py
    from transformers.keras_callbacks import PushToHubCallback

    push_to_hub_callback = PushToHubCallback(
        output_dir="./model_save",
        tokenizer=tokenizer,
        hub_model_id="gpt5-7xlarge",
    )

    model.fit(train_dataset, callbacks=[push_to_hub_callback])
    ```

    Args:
        output_dir (`str`):
            The output directory where the model predictions and checkpoints will be written and synced with the
            repository on the Hub.
        save_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"epoch"`):
            The checkpoint save strategy to adopt during training. Possible values are:

            - `"no"`: Save is done at the end of training.
            - `"epoch"`: Save is done at the end of each epoch.
            - `"steps"`: Save is done every `save_steps`
        save_steps (`int`, *optional*):
            The number of steps between saves when using the "steps" `save_strategy`.
        tokenizer (`PreTrainedTokenizerBase`, *optional*):
            The tokenizer used by the model. If supplied, will be uploaded to the repo alongside the weights.
        hub_model_id (`str`, *optional*):
            The name of the repository to keep in sync with the local `output_dir`. It can be a simple model ID in
            which case the model will be pushed in your namespace. Otherwise it should be the whole repository name,
            for instance `"user_name/model"`, which allows you to push to an organization you are a member of with
            `"organization_name/model"`.

            Will default to the name of `output_dir`.
        hub_token (`str`, *optional*):
            The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
            `huggingface-cli login`.
        checkpoint (`bool`, *optional*, defaults to `False`):
            Whether to save full training checkpoints (including epoch and optimizer state) to allow training to be
            resumed. Only usable when `save_strategy` is `"epoch"`.
    """

    def __init__(
        self,
        output_dir: Union[str, Path],
        save_strategy: Union[str, IntervalStrategy] = "epoch",
        save_steps: Optional[int] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        hub_model_id: Optional[str] = None,
        hub_token: Optional[str] = None,
        checkpoint: bool = False,
        **model_card_args,
    ):
        super().__init__()
        # Full checkpoints are only written at epoch boundaries below, so the two options are incompatible.
        if checkpoint and save_strategy != "epoch":
            raise ValueError("Cannot save checkpoints when save_strategy is not 'epoch'!")
        if isinstance(save_strategy, str):
            save_strategy = IntervalStrategy(save_strategy.lower())
        self.save_strategy = save_strategy
        # "steps" strategy is meaningless without a positive save_steps interval.
        if self.save_strategy == IntervalStrategy.STEPS and (not isinstance(save_steps, int) or save_steps <= 0):
            raise ValueError("Please supply a positive integer argument for save_steps when save_strategy == 'steps'!")
        self.save_steps = save_steps
        output_dir = Path(output_dir)

        # Create repo and retrieve repo_id
        if hub_model_id is None:
            # Default the repo name to the final path component of the output directory.
            hub_model_id = output_dir.absolute().name
        self.hub_model_id = create_repo(repo_id=hub_model_id, exist_ok=True, token=hub_token).repo_id

        self.output_dir = output_dir
        self.repo = Repository(str(self.output_dir), clone_from=self.hub_model_id, token=hub_token)

        self.tokenizer = tokenizer
        # Handle of the most recent asynchronous push, used to avoid overlapping uploads.
        self.last_job = None
        self.checkpoint = checkpoint
        # Per-run list of epoch logs, used to build the model card at the end of training.
        self.training_history = None
        self.model_card_args = model_card_args

    def on_train_begin(self, logs=None):
        """Reset the per-run history used later to build the model card."""
        # Although we can access model.history, we have no guarantees that the History callback will fire before this
        # one, so we keep track of it here too
        self.training_history = []

    def on_train_batch_end(self, batch, logs=None):
        """With the "steps" strategy, save and asynchronously push the model every `save_steps` batches."""
        if self.save_strategy == IntervalStrategy.STEPS and (batch + 1) % self.save_steps == 0:
            if self.last_job is not None and not self.last_job.is_done:
                return  # The last upload is still running, don't start another
            self.model.save_pretrained(self.output_dir)
            if self.tokenizer is not None:
                self.tokenizer.save_pretrained(self.output_dir)
            _, self.last_job = self.repo.push_to_hub(
                commit_message=f"Training in progress steps {batch}", blocking=False
            )

    def on_epoch_end(self, epoch, logs=None):
        """Record this epoch's logs and, with the "epoch" strategy, save and asynchronously push the model."""
        logs = logs.copy()  # Don't accidentally write things that Keras will read later
        if "epoch" not in logs:
            logs["epoch"] = epoch
        self.training_history.append(logs)
        if self.save_strategy == IntervalStrategy.EPOCH:
            if self.last_job is not None and not self.last_job.is_done:
                return  # The last upload is still running, don't start another
            self.model.save_pretrained(self.output_dir)
            if self.tokenizer is not None:
                self.tokenizer.save_pretrained(self.output_dir)
            if self.checkpoint:
                # Full training checkpoint (epoch + optimizer state) so training can be resumed.
                checkpoint_dir = os.path.join(self.output_dir, "checkpoint")
                self.model._save_checkpoint(checkpoint_dir, epoch)
            # Regenerate the model card from the history collected so far.
            train_summary = TrainingSummary.from_keras(
                model=self.model,
                model_name=self.hub_model_id,
                keras_history=self.training_history,
                **self.model_card_args,
            )
            model_card = train_summary.to_model_card()
            with (self.output_dir / "README.md").open("w") as f:
                f.write(model_card)
            _, self.last_job = self.repo.push_to_hub(
                commit_message=f"Training in progress epoch {epoch}", blocking=False
            )

    def on_train_end(self, logs=None):
        """Block until any in-flight push finishes, or do a final synchronous save-and-push."""
        # Makes sure the latest version of the model is uploaded
        if self.last_job is not None and not self.last_job.is_done:
            logging.info("Pushing the last epoch to the Hub, this may take a while...")
            while not self.last_job.is_done:
                sleep(1)
        else:
            # No pending push: write the final weights, model card, and push synchronously.
            self.model.save_pretrained(self.output_dir)
            if self.tokenizer is not None:
                self.tokenizer.save_pretrained(self.output_dir)
            train_summary = TrainingSummary.from_keras(
                model=self.model,
                model_name=self.hub_model_id,
                keras_history=self.training_history,
                **self.model_card_args,
            )
            model_card = train_summary.to_model_card()
            with (self.output_dir / "README.md").open("w") as f:
                f.write(model_card)
            self.repo.push_to_hub(commit_message="End of training", blocking=True)
|
.venv/lib/python3.11/site-packages/transformers/kernels/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/transformers/kernels/deformable_detr/cuda/ms_deform_im2col_cuda.cuh
ADDED
|
@@ -0,0 +1,1327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*!
|
| 2 |
+
**************************************************************************
|
| 3 |
+
* Deformable DETR
|
| 4 |
+
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 5 |
+
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 6 |
+
**************************************************************************
|
| 7 |
+
* Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
|
| 8 |
+
* Copyright (c) 2018 Microsoft
|
| 9 |
+
**************************************************************************
|
| 10 |
+
*/
|
| 11 |
+
|
| 12 |
+
#include <cstdio>
|
| 13 |
+
#include <algorithm>
|
| 14 |
+
#include <cstring>
|
| 15 |
+
|
| 16 |
+
#include <ATen/ATen.h>
|
| 17 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 18 |
+
|
| 19 |
+
#include <THC/THCAtomics.cuh>
|
| 20 |
+
|
| 21 |
+
// Grid-stride loop: thread (blockIdx.x, threadIdx.x) processes indices
// i, i + blockDim.x * gridDim.x, ... while i < n.
#define CUDA_KERNEL_LOOP(i, n)                          \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x;   \
      i < (n);                                          \
      i += blockDim.x * gridDim.x)

// Default number of threads per block.
const int CUDA_NUM_THREADS = 1024;
|
| 27 |
+
// Number of thread blocks of `num_threads` threads needed to cover N work items
// (integer ceiling division).
inline int GET_BLOCKS(const int N, const int num_threads)
{
  const int rounded_up = N + num_threads - 1;
  return rounded_up / num_threads;
}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
template <typename scalar_t>
__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
                                                   const int &height, const int &width, const int &nheads, const int &channels,
                                                   const scalar_t &h, const scalar_t &w, const int &m, const int &c)
{
  // Bilinearly interpolate bottom_data at fractional location (h, w) for attention
  // head m and channel c. The data is indexed as (height, width, nheads, channels)
  // flattened; corners outside the feature map contribute zero (implicit zero padding).
  const int h_low = floor(h);
  const int w_low = floor(w);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;

  // Fractional distances to the low corner; (hh, hw) are the complementary weights.
  const scalar_t lh = h - h_low;
  const scalar_t lw = w - w_low;
  const scalar_t hh = 1 - lh, hw = 1 - lw;

  // Flattened strides for the (height, width, nheads, channels) layout.
  const int w_stride = nheads * channels;
  const int h_stride = width * w_stride;
  const int h_low_ptr_offset = h_low * h_stride;
  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
  const int w_low_ptr_offset = w_low * w_stride;
  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
  const int base_ptr = m * channels + c;

  // Fetch the four corner values, each guarded against out-of-bounds indices.
  scalar_t v1 = 0;
  if (h_low >= 0 && w_low >= 0)
  {
    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
    v1 = bottom_data[ptr1];
  }
  scalar_t v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
  {
    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
    v2 = bottom_data[ptr2];
  }
  scalar_t v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
  {
    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
    v3 = bottom_data[ptr3];
  }
  scalar_t v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
  {
    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
    v4 = bottom_data[ptr4];
  }

  // Standard bilinear combination of the four corner values.
  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;

  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  return val;
}
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
template <typename scalar_t>
__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
                                               const int &height, const int &width, const int &nheads, const int &channels,
                                               const scalar_t &h, const scalar_t &w, const int &m, const int &c,
                                               const scalar_t &top_grad,
                                               const scalar_t &attn_weight,
                                               scalar_t* &grad_value,
                                               scalar_t* grad_sampling_loc,
                                               scalar_t* grad_attn_weight)
{
  // Backward of the bilinear sampling above, for one sampled point:
  //  - gradients w.r.t. the value tensor are accumulated into grad_value via
  //    atomicAdd (several threads can hit the same value element);
  //  - gradients w.r.t. the sampling location (w first, then h) and the attention
  //    weight are written with plain stores, so those two pointers must be private
  //    to this thread (e.g. per-thread slots in shared memory — see the reduce kernels).
  const int h_low = floor(h);
  const int w_low = floor(w);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;

  const scalar_t lh = h - h_low;
  const scalar_t lw = w - w_low;
  const scalar_t hh = 1 - lh, hw = 1 - lw;

  // Flattened strides for the (height, width, nheads, channels) layout.
  const int w_stride = nheads * channels;
  const int h_stride = width * w_stride;
  const int h_low_ptr_offset = h_low * h_stride;
  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
  const int w_low_ptr_offset = w_low * w_stride;
  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
  const int base_ptr = m * channels + c;

  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
  const scalar_t top_grad_value = top_grad * attn_weight;
  // Accumulators for d(val)/d(h) and d(val)/d(w), built from the corner values.
  scalar_t grad_h_weight = 0, grad_w_weight = 0;

  scalar_t v1 = 0;
  if (h_low >= 0 && w_low >= 0)
  {
    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
    v1 = bottom_data[ptr1];
    grad_h_weight -= hw * v1;
    grad_w_weight -= hh * v1;
    atomicAdd(grad_value+ptr1, w1*top_grad_value);
  }
  scalar_t v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
  {
    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
    v2 = bottom_data[ptr2];
    grad_h_weight -= lw * v2;
    grad_w_weight += hh * v2;
    atomicAdd(grad_value+ptr2, w2*top_grad_value);
  }
  scalar_t v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
  {
    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
    v3 = bottom_data[ptr3];
    grad_h_weight += hw * v3;
    grad_w_weight -= lh * v3;
    atomicAdd(grad_value+ptr3, w3*top_grad_value);
  }
  scalar_t v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
  {
    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
    v4 = bottom_data[ptr4];
    grad_h_weight += lw * v4;
    grad_w_weight += lh * v4;
    atomicAdd(grad_value+ptr4, w4*top_grad_value);
  }

  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  // d(output)/d(attn_weight) is the interpolated value itself.
  *grad_attn_weight = top_grad * val;
  // The width/height factors are the chain rule through the caller's mapping
  // loc * spatial_size - 0.5 from normalized to pixel coordinates.
  *grad_sampling_loc = width * grad_w_weight * top_grad_value;
  *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
}
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
template <typename scalar_t>
__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
                                                  const int &height, const int &width, const int &nheads, const int &channels,
                                                  const scalar_t &h, const scalar_t &w, const int &m, const int &c,
                                                  const scalar_t &top_grad,
                                                  const scalar_t &attn_weight,
                                                  scalar_t* &grad_value,
                                                  scalar_t* grad_sampling_loc,
                                                  scalar_t* grad_attn_weight)
{
  // "gm" (global-memory) variant of ms_deform_attn_col2im_bilinear: identical math,
  // but ALL three gradient outputs (value, sampling location, attention weight) are
  // accumulated with atomicAdd, so the output pointers may be shared across threads.
  const int h_low = floor(h);
  const int w_low = floor(w);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;

  const scalar_t lh = h - h_low;
  const scalar_t lw = w - w_low;
  const scalar_t hh = 1 - lh, hw = 1 - lw;

  // Flattened strides for the (height, width, nheads, channels) layout.
  const int w_stride = nheads * channels;
  const int h_stride = width * w_stride;
  const int h_low_ptr_offset = h_low * h_stride;
  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
  const int w_low_ptr_offset = w_low * w_stride;
  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
  const int base_ptr = m * channels + c;

  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
  const scalar_t top_grad_value = top_grad * attn_weight;
  // Accumulators for d(val)/d(h) and d(val)/d(w), built from the corner values.
  scalar_t grad_h_weight = 0, grad_w_weight = 0;

  scalar_t v1 = 0;
  if (h_low >= 0 && w_low >= 0)
  {
    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
    v1 = bottom_data[ptr1];
    grad_h_weight -= hw * v1;
    grad_w_weight -= hh * v1;
    atomicAdd(grad_value+ptr1, w1*top_grad_value);
  }
  scalar_t v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
  {
    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
    v2 = bottom_data[ptr2];
    grad_h_weight -= lw * v2;
    grad_w_weight += hh * v2;
    atomicAdd(grad_value+ptr2, w2*top_grad_value);
  }
  scalar_t v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
  {
    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
    v3 = bottom_data[ptr3];
    grad_h_weight += hw * v3;
    grad_w_weight -= lh * v3;
    atomicAdd(grad_value+ptr3, w3*top_grad_value);
  }
  scalar_t v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
  {
    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
    v4 = bottom_data[ptr4];
    grad_h_weight += lw * v4;
    grad_w_weight += lh * v4;
    atomicAdd(grad_value+ptr4, w4*top_grad_value);
  }

  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  // d(output)/d(attn_weight) is the interpolated value itself.
  atomicAdd(grad_attn_weight, top_grad * val);
  // The width/height factors are the chain rule through the caller's mapping
  // loc * spatial_size - 0.5 from normalized to pixel coordinates.
  atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
  atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
}
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
template <typename scalar_t>
__global__ void ms_deformable_im2col_gpu_kernel(const int n,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *data_col)
{
  // Forward pass of multi-scale deformable attention: each thread computes one
  // output element data_col[index] as the sum, over all feature levels and
  // sampling points, of attn_weight * bilinear_sample(value, sampling_loc).
  CUDA_KERNEL_LOOP(index, n)
  {
    // Decompose the flat index into (batch b_col, query q_col, head m_col, channel c_col).
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    scalar_t *data_col_ptr = data_col + index;
    // One weight per (level, point); locations store interleaved (w, h) pairs, hence << 1.
    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
    scalar_t col = 0;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      // Each level's feature map starts at data_level_start_index[l_col] and has
      // shape (spatial_h, spatial_w) per data_spatial_shapes.
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Map the (presumably [0, 1]-normalized — TODO confirm against callers)
        // location to pixel coordinates, aligned to pixel centers via the -0.5.
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;

        // Points entirely outside the feature map contribute nothing.
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
        }

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
      }
    }
    *data_col_ptr = col;
  }
}
|
| 300 |
+
|
| 301 |
+
template <typename scalar_t, unsigned int blockSize>
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  // Backward pass of multi-scale deformable attention using a shared-memory
  // reduction: each thread writes its partial gradients for the sampling location
  // and attention weight into its own shared-memory slots, then thread 0 sums all
  // blockSize slots serially and stores the result.
  // NOTE(review): the plain stores by thread 0 assume all threads in the block
  // handle the same sampling point (differing only in channel) — confirm against
  // the launch configuration in the host code.
  CUDA_KERNEL_LOOP(index, n)
  {
    // One (w, h) pair and one attention-weight slot per thread.
    __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
    __shared__ scalar_t cache_grad_attn_weight[blockSize];
    unsigned int tid = threadIdx.x;
    // Decompose the flat index into (batch b_col, query q_col, head m_col, channel c_col).
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    // Incoming gradient for this output element.
    const scalar_t top_grad = grad_col[index];

    // One weight per (level, point); locations store interleaved (w, h) pairs, hence << 1.
    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    // Advance the output gradient pointers to this sample's first slot.
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Zero this thread's shared-memory slots before accumulating this point.
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          // grad_value is accumulated atomically inside; loc/weight grads go to
          // this thread's private shared-memory slots.
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        // Wait until every thread has written its partial gradients.
        __syncthreads();
        if (tid == 0)
        {
          scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
          int sid=2;
          // NOTE: this loop variable shadows the outer `tid`; it exists only for the reduction.
          for (unsigned int tid = 1; tid < blockSize; ++tid)
          {
            _grad_w += cache_grad_sampling_loc[sid];
            _grad_h += cache_grad_sampling_loc[sid + 1];
            _grad_a += cache_grad_attn_weight[tid];
            sid += 2;
          }


          *grad_sampling_loc = _grad_w;
          *(grad_sampling_loc + 1) = _grad_h;
          *grad_attn_weight = _grad_a;
        }
        // Ensure the reduction is done before the caches are reused for the next point.
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
// Backward (col2im) kernel for multi-scale deformable attention with a
// compile-time block size. Each thread handles one element of the flat
// gradient `grad_col` and scatters gradients into `grad_value`,
// `grad_sampling_loc` and `grad_attn_weight`.
// Per-thread partial gradients for the sampling location (x, y) and the
// attention weight are accumulated in statically sized shared memory and
// combined with a halving tree reduction.
// NOTE(review): the reduction loop `s = blockSize/2; s > 0; s >>= 1` only
// sums every slot when blockSize is a power of two — the launcher only
// instantiates this for 64/128/256/512, which satisfies that.
template <typename scalar_t, unsigned int blockSize>
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Shared scratch: 2 slots (x, y) per thread for the location gradient,
    // 1 slot per thread for the attention-weight gradient.
    __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
    __shared__ scalar_t cache_grad_attn_weight[blockSize];
    unsigned int tid = threadIdx.x;
    // Decompose the flat index into (batch, query, head, channel):
    // index = ((b_col * num_query + q_col) * num_heads + m_col) * channels + c_col.
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;   // (b, q, m) combined — indexes loc/weight tables
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    // Upstream gradient for this output element.
    const scalar_t top_grad = grad_col[index];

    // Per-(b,q,m) base offsets into the sampling-location / attention-weight
    // tables; locations are interleaved (x, y), hence the << 1.
    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;   // stride of one spatial position in `data_value`
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;   // shapes stored as (h, w) pairs per level
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Map normalized location to image coordinates (align-corners=False
        // style -0.5 shift).
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Zero this thread's shared slots before (possibly) accumulating,
        // so out-of-bounds samples contribute nothing to the reduction.
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          // Accumulates into grad_value (directly) and into the shared
          // per-thread slots (location / weight gradients).
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();   // all partials written before the reduction reads them

        // Tree reduction over the block: halve the active-thread count each
        // step; a barrier after each step orders reads against writes.
        for (unsigned int s=blockSize/2; s>0; s>>=1)
        {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
          }
          __syncthreads();
        }

        // Thread 0 holds the block-wide sums in slot 0; write them out.
        if (tid == 0)
        {
          *grad_sampling_loc = cache_grad_sampling_loc[0];
          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
          *grad_attn_weight = cache_grad_attn_weight[0];
        }
        __syncthreads();   // keep all threads in lock-step before reusing the cache

        // Advance to the next sampling point.
        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
// Backward (col2im) kernel, shared-memory variant with a *serial* reduction:
// shared scratch is sized at launch time (extern __shared__), each thread
// writes its partial gradients, then thread 0 walks all blockDim.x slots and
// sums them sequentially. Simpler than the tree-reduction variants but the
// reduction is O(blockDim.x) on a single thread.
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Dynamic shared memory: first 2*blockDim.x scalars are the per-thread
    // (x, y) location-gradient slots, next blockDim.x are the weight slots.
    extern __shared__ int _s[];
    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;
    // Decompose the flat index into (batch, query, head, channel).
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    // Base offsets for this (b,q,m) into the loc/weight tables; (x, y)
    // locations are interleaved, hence << 1.
    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;   // (h, w) pair per level
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Normalized location -> image coordinates (-0.5 shift).
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Clear this thread's slots so out-of-bounds samples add zero.
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();   // all partials visible before thread 0 reduces
        if (tid == 0)
        {
          scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
          int sid=2;
          // NOTE(review): this loop variable shadows the outer `tid`
          // (harmless here since the outer tid is 0 in this branch, but a
          // rename would be clearer).
          for (unsigned int tid = 1; tid < blockDim.x; ++tid)
          {
            _grad_w += cache_grad_sampling_loc[sid];
            _grad_h += cache_grad_sampling_loc[sid + 1];
            _grad_a += cache_grad_attn_weight[tid];
            sid += 2;
          }


          *grad_sampling_loc = _grad_w;
          *(grad_sampling_loc + 1) = _grad_h;
          *grad_attn_weight = _grad_a;
        }
        __syncthreads();   // don't reuse the cache until the sums are stored

        // Advance to the next sampling point.
        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
|
| 617 |
+
|
| 618 |
+
// Backward (col2im) kernel, shared-memory variant with a parallel tree
// reduction over dynamically sized shared memory. Unlike the blockSize-aware
// variant, blockDim.x is a runtime value; the extra `spre` carry term inside
// the reduction folds in the dangling element when the active range has odd
// length, so it does not require a power-of-two block size.
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Dynamic shared memory layout: [2*blockDim.x loc slots | blockDim.x weight slots].
    extern __shared__ int _s[];
    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;
    // Decompose flat index into (batch, query, head, channel).
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    // Per-(b,q,m) base offsets; (x, y) locations interleaved (<< 1).
    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;   // (h, w) pair per level
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Normalized location -> image coordinates (-0.5 shift).
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Clear this thread's slots so out-of-bounds samples add zero.
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();   // all partials written before reduction

        // Tree reduction; `spre` tracks the previous active width so the
        // last thread can also fold in the odd leftover element (tid + 2s).
        for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
        {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
            if (tid + (s << 1) < spre)
            {
              // Odd-width carry: absorb the element past the halved range.
              cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
              cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
              cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
            }
          }
          __syncthreads();
        }

        // Block-wide sums live in slot 0; thread 0 writes them out.
        if (tid == 0)
        {
          *grad_sampling_loc = cache_grad_sampling_loc[0];
          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
          *grad_attn_weight = cache_grad_attn_weight[0];
        }
        __syncthreads();   // cache must not be reused before the store

        // Advance to the next sampling point.
        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
|
| 730 |
+
|
| 731 |
+
// Backward (col2im) kernel, "multi blocks" variant: identical reduction to
// ms_deformable_col2im_gpu_kernel_shm_reduce_v2, but the final per-point
// gradients are accumulated into global memory with atomicAdd instead of a
// plain store — safe when more than one block contributes to the same
// (sampling-location / attention-weight) output slot.
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Dynamic shared memory layout: [2*blockDim.x loc slots | blockDim.x weight slots].
    extern __shared__ int _s[];
    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;
    // Decompose flat index into (batch, query, head, channel).
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    // Per-(b,q,m) base offsets; (x, y) locations interleaved (<< 1).
    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;   // (h, w) pair per level
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Normalized location -> image coordinates (-0.5 shift).
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Clear this thread's slots so out-of-bounds samples add zero.
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();   // all partials written before reduction

        // Tree reduction with odd-width carry (see shm_reduce_v2).
        for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
        {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
            if (tid + (s << 1) < spre)
            {
              cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
              cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
              cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
            }
          }
          __syncthreads();
        }

        // Atomic accumulation: multiple blocks may target the same output.
        if (tid == 0)
        {
          atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
          atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
          atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
        }
        __syncthreads();   // cache must not be reused before the adds

        // Advance to the next sampling point.
        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
|
| 843 |
+
|
| 844 |
+
|
| 845 |
+
// Backward (col2im) kernel, global-memory fallback: no shared memory and no
// block-level reduction. Each thread hands its gradient contributions
// straight to ms_deform_attn_col2im_bilinear_gm, which writes them to the
// global-memory gradient buffers directly (used by the launcher when
// `channels` is too large / irregular for the shared-memory variants).
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Decompose flat index into (batch, query, head, channel).
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    // Per-(b,q,m) base offsets; (x, y) locations interleaved (<< 1).
    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;   // (h, w) pair per level
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Normalized location -> image coordinates (-0.5 shift).
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Out-of-bounds samples contribute no gradient at all.
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          // _gm variant: accumulates loc/weight gradients directly in
          // global memory (no shared-memory staging).
          ms_deform_attn_col2im_bilinear_gm(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            grad_sampling_loc, grad_attn_weight);
        }
        // Advance to the next sampling point.
        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
|
| 921 |
+
|
| 922 |
+
|
| 923 |
+
// Host-side launcher for the forward (im2col) pass of multi-scale
// deformable attention. Computes the flat work size (one thread per
// batch x query x head x channel output element), launches
// ms_deformable_im2col_gpu_kernel on `stream`, and reports any launch
// error via printf. Output is written to `data_col`.
template <typename scalar_t>
void ms_deformable_im2col_cuda(cudaStream_t stream,
                              const scalar_t* data_value,
                              const int64_t* data_spatial_shapes,
                              const int64_t* data_level_start_index,
                              const scalar_t* data_sampling_loc,
                              const scalar_t* data_attn_weight,
                              const int batch_size,
                              const int spatial_size,
                              const int num_heads,
                              const int channels,
                              const int num_levels,
                              const int num_query,
                              const int num_point,
                              scalar_t* data_col)
{
  // One kernel-loop element per output scalar.
  const int total = batch_size * num_query * num_heads * channels;
  const int num_kernels = total;
  const int num_actual_kernels = total;
  const int num_threads = CUDA_NUM_THREADS;
  const int num_blocks = GET_BLOCKS(num_actual_kernels, num_threads);

  // No dynamic shared memory needed for the forward pass.
  ms_deformable_im2col_gpu_kernel<scalar_t>
      <<<num_blocks, num_threads, 0, stream>>>(
          num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
          batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);

  // Launch-time error check (asynchronous execution errors surface later).
  const cudaError_t launch_err = cudaGetLastError();
  if (launch_err != cudaSuccess)
  {
    printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(launch_err));
  }

}
|
| 955 |
+
|
| 956 |
+
template <typename scalar_t>
|
| 957 |
+
void ms_deformable_col2im_cuda(cudaStream_t stream,
|
| 958 |
+
const scalar_t* grad_col,
|
| 959 |
+
const scalar_t* data_value,
|
| 960 |
+
const int64_t * data_spatial_shapes,
|
| 961 |
+
const int64_t * data_level_start_index,
|
| 962 |
+
const scalar_t * data_sampling_loc,
|
| 963 |
+
const scalar_t * data_attn_weight,
|
| 964 |
+
const int batch_size,
|
| 965 |
+
const int spatial_size,
|
| 966 |
+
const int num_heads,
|
| 967 |
+
const int channels,
|
| 968 |
+
const int num_levels,
|
| 969 |
+
const int num_query,
|
| 970 |
+
const int num_point,
|
| 971 |
+
scalar_t* grad_value,
|
| 972 |
+
scalar_t* grad_sampling_loc,
|
| 973 |
+
scalar_t* grad_attn_weight)
|
| 974 |
+
{
|
| 975 |
+
const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
|
| 976 |
+
const int num_kernels = batch_size * num_query * num_heads * channels;
|
| 977 |
+
const int num_actual_kernels = batch_size * num_query * num_heads * channels;
|
| 978 |
+
if (channels > 1024)
|
| 979 |
+
{
|
| 980 |
+
if ((channels & 1023) == 0)
|
| 981 |
+
{
|
| 982 |
+
ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
|
| 983 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 984 |
+
num_threads*3*sizeof(scalar_t), stream>>>(
|
| 985 |
+
num_kernels,
|
| 986 |
+
grad_col,
|
| 987 |
+
data_value,
|
| 988 |
+
data_spatial_shapes,
|
| 989 |
+
data_level_start_index,
|
| 990 |
+
data_sampling_loc,
|
| 991 |
+
data_attn_weight,
|
| 992 |
+
batch_size,
|
| 993 |
+
spatial_size,
|
| 994 |
+
num_heads,
|
| 995 |
+
channels,
|
| 996 |
+
num_levels,
|
| 997 |
+
num_query,
|
| 998 |
+
num_point,
|
| 999 |
+
grad_value,
|
| 1000 |
+
grad_sampling_loc,
|
| 1001 |
+
grad_attn_weight);
|
| 1002 |
+
}
|
| 1003 |
+
else
|
| 1004 |
+
{
|
| 1005 |
+
ms_deformable_col2im_gpu_kernel_gm<scalar_t>
|
| 1006 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 1007 |
+
0, stream>>>(
|
| 1008 |
+
num_kernels,
|
| 1009 |
+
grad_col,
|
| 1010 |
+
data_value,
|
| 1011 |
+
data_spatial_shapes,
|
| 1012 |
+
data_level_start_index,
|
| 1013 |
+
data_sampling_loc,
|
| 1014 |
+
data_attn_weight,
|
| 1015 |
+
batch_size,
|
| 1016 |
+
spatial_size,
|
| 1017 |
+
num_heads,
|
| 1018 |
+
channels,
|
| 1019 |
+
num_levels,
|
| 1020 |
+
num_query,
|
| 1021 |
+
num_point,
|
| 1022 |
+
grad_value,
|
| 1023 |
+
grad_sampling_loc,
|
| 1024 |
+
grad_attn_weight);
|
| 1025 |
+
}
|
| 1026 |
+
}
|
| 1027 |
+
else{
|
| 1028 |
+
switch(channels)
|
| 1029 |
+
{
|
| 1030 |
+
case 1:
|
| 1031 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
|
| 1032 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 1033 |
+
0, stream>>>(
|
| 1034 |
+
num_kernels,
|
| 1035 |
+
grad_col,
|
| 1036 |
+
data_value,
|
| 1037 |
+
data_spatial_shapes,
|
| 1038 |
+
data_level_start_index,
|
| 1039 |
+
data_sampling_loc,
|
| 1040 |
+
data_attn_weight,
|
| 1041 |
+
batch_size,
|
| 1042 |
+
spatial_size,
|
| 1043 |
+
num_heads,
|
| 1044 |
+
channels,
|
| 1045 |
+
num_levels,
|
| 1046 |
+
num_query,
|
| 1047 |
+
num_point,
|
| 1048 |
+
grad_value,
|
| 1049 |
+
grad_sampling_loc,
|
| 1050 |
+
grad_attn_weight);
|
| 1051 |
+
break;
|
| 1052 |
+
case 2:
|
| 1053 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
|
| 1054 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 1055 |
+
0, stream>>>(
|
| 1056 |
+
num_kernels,
|
| 1057 |
+
grad_col,
|
| 1058 |
+
data_value,
|
| 1059 |
+
data_spatial_shapes,
|
| 1060 |
+
data_level_start_index,
|
| 1061 |
+
data_sampling_loc,
|
| 1062 |
+
data_attn_weight,
|
| 1063 |
+
batch_size,
|
| 1064 |
+
spatial_size,
|
| 1065 |
+
num_heads,
|
| 1066 |
+
channels,
|
| 1067 |
+
num_levels,
|
| 1068 |
+
num_query,
|
| 1069 |
+
num_point,
|
| 1070 |
+
grad_value,
|
| 1071 |
+
grad_sampling_loc,
|
| 1072 |
+
grad_attn_weight);
|
| 1073 |
+
break;
|
| 1074 |
+
case 4:
|
| 1075 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
|
| 1076 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 1077 |
+
0, stream>>>(
|
| 1078 |
+
num_kernels,
|
| 1079 |
+
grad_col,
|
| 1080 |
+
data_value,
|
| 1081 |
+
data_spatial_shapes,
|
| 1082 |
+
data_level_start_index,
|
| 1083 |
+
data_sampling_loc,
|
| 1084 |
+
data_attn_weight,
|
| 1085 |
+
batch_size,
|
| 1086 |
+
spatial_size,
|
| 1087 |
+
num_heads,
|
| 1088 |
+
channels,
|
| 1089 |
+
num_levels,
|
| 1090 |
+
num_query,
|
| 1091 |
+
num_point,
|
| 1092 |
+
grad_value,
|
| 1093 |
+
grad_sampling_loc,
|
| 1094 |
+
grad_attn_weight);
|
| 1095 |
+
break;
|
| 1096 |
+
case 8:
|
| 1097 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
|
| 1098 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 1099 |
+
0, stream>>>(
|
| 1100 |
+
num_kernels,
|
| 1101 |
+
grad_col,
|
| 1102 |
+
data_value,
|
| 1103 |
+
data_spatial_shapes,
|
| 1104 |
+
data_level_start_index,
|
| 1105 |
+
data_sampling_loc,
|
| 1106 |
+
data_attn_weight,
|
| 1107 |
+
batch_size,
|
| 1108 |
+
spatial_size,
|
| 1109 |
+
num_heads,
|
| 1110 |
+
channels,
|
| 1111 |
+
num_levels,
|
| 1112 |
+
num_query,
|
| 1113 |
+
num_point,
|
| 1114 |
+
grad_value,
|
| 1115 |
+
grad_sampling_loc,
|
| 1116 |
+
grad_attn_weight);
|
| 1117 |
+
break;
|
| 1118 |
+
case 16:
|
| 1119 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
|
| 1120 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 1121 |
+
0, stream>>>(
|
| 1122 |
+
num_kernels,
|
| 1123 |
+
grad_col,
|
| 1124 |
+
data_value,
|
| 1125 |
+
data_spatial_shapes,
|
| 1126 |
+
data_level_start_index,
|
| 1127 |
+
data_sampling_loc,
|
| 1128 |
+
data_attn_weight,
|
| 1129 |
+
batch_size,
|
| 1130 |
+
spatial_size,
|
| 1131 |
+
num_heads,
|
| 1132 |
+
channels,
|
| 1133 |
+
num_levels,
|
| 1134 |
+
num_query,
|
| 1135 |
+
num_point,
|
| 1136 |
+
grad_value,
|
| 1137 |
+
grad_sampling_loc,
|
| 1138 |
+
grad_attn_weight);
|
| 1139 |
+
break;
|
| 1140 |
+
case 32:
|
| 1141 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
|
| 1142 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 1143 |
+
0, stream>>>(
|
| 1144 |
+
num_kernels,
|
| 1145 |
+
grad_col,
|
| 1146 |
+
data_value,
|
| 1147 |
+
data_spatial_shapes,
|
| 1148 |
+
data_level_start_index,
|
| 1149 |
+
data_sampling_loc,
|
| 1150 |
+
data_attn_weight,
|
| 1151 |
+
batch_size,
|
| 1152 |
+
spatial_size,
|
| 1153 |
+
num_heads,
|
| 1154 |
+
channels,
|
| 1155 |
+
num_levels,
|
| 1156 |
+
num_query,
|
| 1157 |
+
num_point,
|
| 1158 |
+
grad_value,
|
| 1159 |
+
grad_sampling_loc,
|
| 1160 |
+
grad_attn_weight);
|
| 1161 |
+
break;
|
| 1162 |
+
case 64:
|
| 1163 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
|
| 1164 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 1165 |
+
0, stream>>>(
|
| 1166 |
+
num_kernels,
|
| 1167 |
+
grad_col,
|
| 1168 |
+
data_value,
|
| 1169 |
+
data_spatial_shapes,
|
| 1170 |
+
data_level_start_index,
|
| 1171 |
+
data_sampling_loc,
|
| 1172 |
+
data_attn_weight,
|
| 1173 |
+
batch_size,
|
| 1174 |
+
spatial_size,
|
| 1175 |
+
num_heads,
|
| 1176 |
+
channels,
|
| 1177 |
+
num_levels,
|
| 1178 |
+
num_query,
|
| 1179 |
+
num_point,
|
| 1180 |
+
grad_value,
|
| 1181 |
+
grad_sampling_loc,
|
| 1182 |
+
grad_attn_weight);
|
| 1183 |
+
break;
|
| 1184 |
+
case 128:
|
| 1185 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
|
| 1186 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 1187 |
+
0, stream>>>(
|
| 1188 |
+
num_kernels,
|
| 1189 |
+
grad_col,
|
| 1190 |
+
data_value,
|
| 1191 |
+
data_spatial_shapes,
|
| 1192 |
+
data_level_start_index,
|
| 1193 |
+
data_sampling_loc,
|
| 1194 |
+
data_attn_weight,
|
| 1195 |
+
batch_size,
|
| 1196 |
+
spatial_size,
|
| 1197 |
+
num_heads,
|
| 1198 |
+
channels,
|
| 1199 |
+
num_levels,
|
| 1200 |
+
num_query,
|
| 1201 |
+
num_point,
|
| 1202 |
+
grad_value,
|
| 1203 |
+
grad_sampling_loc,
|
| 1204 |
+
grad_attn_weight);
|
| 1205 |
+
break;
|
| 1206 |
+
case 256:
|
| 1207 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
|
| 1208 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 1209 |
+
0, stream>>>(
|
| 1210 |
+
num_kernels,
|
| 1211 |
+
grad_col,
|
| 1212 |
+
data_value,
|
| 1213 |
+
data_spatial_shapes,
|
| 1214 |
+
data_level_start_index,
|
| 1215 |
+
data_sampling_loc,
|
| 1216 |
+
data_attn_weight,
|
| 1217 |
+
batch_size,
|
| 1218 |
+
spatial_size,
|
| 1219 |
+
num_heads,
|
| 1220 |
+
channels,
|
| 1221 |
+
num_levels,
|
| 1222 |
+
num_query,
|
| 1223 |
+
num_point,
|
| 1224 |
+
grad_value,
|
| 1225 |
+
grad_sampling_loc,
|
| 1226 |
+
grad_attn_weight);
|
| 1227 |
+
break;
|
| 1228 |
+
case 512:
|
| 1229 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
|
| 1230 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 1231 |
+
0, stream>>>(
|
| 1232 |
+
num_kernels,
|
| 1233 |
+
grad_col,
|
| 1234 |
+
data_value,
|
| 1235 |
+
data_spatial_shapes,
|
| 1236 |
+
data_level_start_index,
|
| 1237 |
+
data_sampling_loc,
|
| 1238 |
+
data_attn_weight,
|
| 1239 |
+
batch_size,
|
| 1240 |
+
spatial_size,
|
| 1241 |
+
num_heads,
|
| 1242 |
+
channels,
|
| 1243 |
+
num_levels,
|
| 1244 |
+
num_query,
|
| 1245 |
+
num_point,
|
| 1246 |
+
grad_value,
|
| 1247 |
+
grad_sampling_loc,
|
| 1248 |
+
grad_attn_weight);
|
| 1249 |
+
break;
|
| 1250 |
+
case 1024:
|
| 1251 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
|
| 1252 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 1253 |
+
0, stream>>>(
|
| 1254 |
+
num_kernels,
|
| 1255 |
+
grad_col,
|
| 1256 |
+
data_value,
|
| 1257 |
+
data_spatial_shapes,
|
| 1258 |
+
data_level_start_index,
|
| 1259 |
+
data_sampling_loc,
|
| 1260 |
+
data_attn_weight,
|
| 1261 |
+
batch_size,
|
| 1262 |
+
spatial_size,
|
| 1263 |
+
num_heads,
|
| 1264 |
+
channels,
|
| 1265 |
+
num_levels,
|
| 1266 |
+
num_query,
|
| 1267 |
+
num_point,
|
| 1268 |
+
grad_value,
|
| 1269 |
+
grad_sampling_loc,
|
| 1270 |
+
grad_attn_weight);
|
| 1271 |
+
break;
|
| 1272 |
+
default:
|
| 1273 |
+
if (channels < 64)
|
| 1274 |
+
{
|
| 1275 |
+
ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
|
| 1276 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 1277 |
+
num_threads*3*sizeof(scalar_t), stream>>>(
|
| 1278 |
+
num_kernels,
|
| 1279 |
+
grad_col,
|
| 1280 |
+
data_value,
|
| 1281 |
+
data_spatial_shapes,
|
| 1282 |
+
data_level_start_index,
|
| 1283 |
+
data_sampling_loc,
|
| 1284 |
+
data_attn_weight,
|
| 1285 |
+
batch_size,
|
| 1286 |
+
spatial_size,
|
| 1287 |
+
num_heads,
|
| 1288 |
+
channels,
|
| 1289 |
+
num_levels,
|
| 1290 |
+
num_query,
|
| 1291 |
+
num_point,
|
| 1292 |
+
grad_value,
|
| 1293 |
+
grad_sampling_loc,
|
| 1294 |
+
grad_attn_weight);
|
| 1295 |
+
}
|
| 1296 |
+
else
|
| 1297 |
+
{
|
| 1298 |
+
ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
|
| 1299 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 1300 |
+
num_threads*3*sizeof(scalar_t), stream>>>(
|
| 1301 |
+
num_kernels,
|
| 1302 |
+
grad_col,
|
| 1303 |
+
data_value,
|
| 1304 |
+
data_spatial_shapes,
|
| 1305 |
+
data_level_start_index,
|
| 1306 |
+
data_sampling_loc,
|
| 1307 |
+
data_attn_weight,
|
| 1308 |
+
batch_size,
|
| 1309 |
+
spatial_size,
|
| 1310 |
+
num_heads,
|
| 1311 |
+
channels,
|
| 1312 |
+
num_levels,
|
| 1313 |
+
num_query,
|
| 1314 |
+
num_point,
|
| 1315 |
+
grad_value,
|
| 1316 |
+
grad_sampling_loc,
|
| 1317 |
+
grad_attn_weight);
|
| 1318 |
+
}
|
| 1319 |
+
}
|
| 1320 |
+
}
|
| 1321 |
+
cudaError_t err = cudaGetLastError();
|
| 1322 |
+
if (err != cudaSuccess)
|
| 1323 |
+
{
|
| 1324 |
+
printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
|
| 1325 |
+
}
|
| 1326 |
+
|
| 1327 |
+
}
|
.venv/lib/python3.11/site-packages/transformers/kernels/deformable_detr/ms_deform_attn.h
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*!
|
| 2 |
+
**************************************************************************************************
|
| 3 |
+
* Deformable DETR
|
| 4 |
+
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 5 |
+
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 6 |
+
**************************************************************************************************
|
| 7 |
+
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
| 8 |
+
**************************************************************************************************
|
| 9 |
+
*/
|
| 10 |
+
|
| 11 |
+
#pragma once
|
| 12 |
+
|
| 13 |
+
#include "cpu/ms_deform_attn_cpu.h"
|
| 14 |
+
|
| 15 |
+
#ifdef WITH_CUDA
|
| 16 |
+
#include "cuda/ms_deform_attn_cuda.h"
|
| 17 |
+
#endif
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
at::Tensor
|
| 21 |
+
ms_deform_attn_forward(
|
| 22 |
+
const at::Tensor &value,
|
| 23 |
+
const at::Tensor &spatial_shapes,
|
| 24 |
+
const at::Tensor &level_start_index,
|
| 25 |
+
const at::Tensor &sampling_loc,
|
| 26 |
+
const at::Tensor &attn_weight,
|
| 27 |
+
const int im2col_step)
|
| 28 |
+
{
|
| 29 |
+
if (value.type().is_cuda())
|
| 30 |
+
{
|
| 31 |
+
#ifdef WITH_CUDA
|
| 32 |
+
return ms_deform_attn_cuda_forward(
|
| 33 |
+
value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
|
| 34 |
+
#else
|
| 35 |
+
AT_ERROR("Not compiled with GPU support");
|
| 36 |
+
#endif
|
| 37 |
+
}
|
| 38 |
+
AT_ERROR("Not implemented on the CPU");
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
std::vector<at::Tensor>
|
| 42 |
+
ms_deform_attn_backward(
|
| 43 |
+
const at::Tensor &value,
|
| 44 |
+
const at::Tensor &spatial_shapes,
|
| 45 |
+
const at::Tensor &level_start_index,
|
| 46 |
+
const at::Tensor &sampling_loc,
|
| 47 |
+
const at::Tensor &attn_weight,
|
| 48 |
+
const at::Tensor &grad_output,
|
| 49 |
+
const int im2col_step)
|
| 50 |
+
{
|
| 51 |
+
if (value.type().is_cuda())
|
| 52 |
+
{
|
| 53 |
+
#ifdef WITH_CUDA
|
| 54 |
+
return ms_deform_attn_cuda_backward(
|
| 55 |
+
value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
|
| 56 |
+
#else
|
| 57 |
+
AT_ERROR("Not compiled with GPU support");
|
| 58 |
+
#endif
|
| 59 |
+
}
|
| 60 |
+
AT_ERROR("Not implemented on the CPU");
|
| 61 |
+
}
|
.venv/lib/python3.11/site-packages/transformers/kernels/deta/cpu/ms_deform_attn_cpu.cpp
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*!
|
| 2 |
+
**************************************************************************************************
|
| 3 |
+
* Deformable DETR
|
| 4 |
+
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 5 |
+
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 6 |
+
**************************************************************************************************
|
| 7 |
+
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
| 8 |
+
**************************************************************************************************
|
| 9 |
+
*/
|
| 10 |
+
|
| 11 |
+
#include <vector>
|
| 12 |
+
|
| 13 |
+
#include <ATen/ATen.h>
|
| 14 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
at::Tensor
|
| 18 |
+
ms_deform_attn_cpu_forward(
|
| 19 |
+
const at::Tensor &value,
|
| 20 |
+
const at::Tensor &spatial_shapes,
|
| 21 |
+
const at::Tensor &level_start_index,
|
| 22 |
+
const at::Tensor &sampling_loc,
|
| 23 |
+
const at::Tensor &attn_weight,
|
| 24 |
+
const int im2col_step)
|
| 25 |
+
{
|
| 26 |
+
AT_ERROR("Not implement on cpu");
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
std::vector<at::Tensor>
|
| 30 |
+
ms_deform_attn_cpu_backward(
|
| 31 |
+
const at::Tensor &value,
|
| 32 |
+
const at::Tensor &spatial_shapes,
|
| 33 |
+
const at::Tensor &level_start_index,
|
| 34 |
+
const at::Tensor &sampling_loc,
|
| 35 |
+
const at::Tensor &attn_weight,
|
| 36 |
+
const at::Tensor &grad_output,
|
| 37 |
+
const int im2col_step)
|
| 38 |
+
{
|
| 39 |
+
AT_ERROR("Not implement on cpu");
|
| 40 |
+
}
|
.venv/lib/python3.11/site-packages/transformers/kernels/deta/cpu/ms_deform_attn_cpu.h
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*!
|
| 2 |
+
**************************************************************************************************
|
| 3 |
+
* Deformable DETR
|
| 4 |
+
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 5 |
+
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 6 |
+
**************************************************************************************************
|
| 7 |
+
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
| 8 |
+
**************************************************************************************************
|
| 9 |
+
*/
|
| 10 |
+
|
| 11 |
+
#pragma once
|
| 12 |
+
#include <torch/extension.h>
|
| 13 |
+
|
| 14 |
+
at::Tensor
|
| 15 |
+
ms_deform_attn_cpu_forward(
|
| 16 |
+
const at::Tensor &value,
|
| 17 |
+
const at::Tensor &spatial_shapes,
|
| 18 |
+
const at::Tensor &level_start_index,
|
| 19 |
+
const at::Tensor &sampling_loc,
|
| 20 |
+
const at::Tensor &attn_weight,
|
| 21 |
+
const int im2col_step);
|
| 22 |
+
|
| 23 |
+
std::vector<at::Tensor>
|
| 24 |
+
ms_deform_attn_cpu_backward(
|
| 25 |
+
const at::Tensor &value,
|
| 26 |
+
const at::Tensor &spatial_shapes,
|
| 27 |
+
const at::Tensor &level_start_index,
|
| 28 |
+
const at::Tensor &sampling_loc,
|
| 29 |
+
const at::Tensor &attn_weight,
|
| 30 |
+
const at::Tensor &grad_output,
|
| 31 |
+
const int im2col_step);
|
| 32 |
+
|
.venv/lib/python3.11/site-packages/transformers/kernels/deta/cuda/ms_deform_attn_cuda.cu
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*!
|
| 2 |
+
**************************************************************************************************
|
| 3 |
+
* Deformable DETR
|
| 4 |
+
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 5 |
+
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 6 |
+
**************************************************************************************************
|
| 7 |
+
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
| 8 |
+
**************************************************************************************************
|
| 9 |
+
*/
|
| 10 |
+
|
| 11 |
+
#include <vector>
|
| 12 |
+
#include "cuda/ms_deform_im2col_cuda.cuh"
|
| 13 |
+
|
| 14 |
+
#include <ATen/ATen.h>
|
| 15 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 16 |
+
#include <cuda.h>
|
| 17 |
+
#include <cuda_runtime.h>
|
| 18 |
+
|
| 19 |
+
#pragma once
|
| 20 |
+
#include <torch/extension.h>
|
| 21 |
+
|
| 22 |
+
|
// Forward pass of multi-scale deformable attention on CUDA.
//
// value:             (batch, spatial_size, num_heads, channels) flattened features
// spatial_shapes:    int64 tensor, one (H, W) row per feature level
// level_start_index: int64 offsets of each level inside `value`
// sampling_loc:      sampling coordinates; dim 1 is num_query, dim 4 is num_point
//                    (presumably laid out as (batch, num_query, num_heads,
//                    num_levels, num_point, 2) — matches the per-chunk strides below)
// attn_weight:       attention weights with matching layout (minus the final 2)
// im2col_step:       number of batch samples processed per kernel launch
//
// Returns a (batch, num_query, num_heads*channels) tensor.
//
// Fix: the divisibility assertion message had the relationship backwards —
// the condition checks that im2col_step divides batch, not the other way round.
at::Tensor ms_deform_attn_cuda_forward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const int im2col_step)
{
    AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
    AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
    AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
    AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");

    AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
    AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
    AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
    AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");

    const int batch = value.size(0);
    const int spatial_size = value.size(1);
    const int num_heads = value.size(2);
    const int channels = value.size(3);

    const int num_levels = spatial_shapes.size(0);

    const int num_query = sampling_loc.size(1);
    const int num_point = sampling_loc.size(4);

    const int im2col_step_ = std::min(batch, im2col_step);

    // The batch is processed in chunks of im2col_step_ samples, so the chunk
    // size must divide the batch evenly.
    AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must be divisible by im2col_step(%d)", batch, im2col_step_);

    auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());

    const int batch_n = im2col_step_;
    auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
    // Per-sample strides used to advance the raw data pointers chunk by chunk.
    auto per_value_size = spatial_size * num_heads * channels;
    auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
    auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
    for (int n = 0; n < batch/im2col_step_; ++n)
    {
        // View of the output slice this chunk writes into.
        auto columns = output_n.select(0, n);
        AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
            ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
                value.data<scalar_t>() + n * im2col_step_ * per_value_size,
                spatial_shapes.data<int64_t>(),
                level_start_index.data<int64_t>(),
                sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
                attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
                batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
                columns.data<scalar_t>());

        }));
    }

    // Collapse heads and channels into the final feature dimension.
    output = output.view({batch, num_query, num_heads*channels});

    return output;
}
| 84 |
+
|
| 85 |
+
|
// Backward pass of multi-scale deformable attention on CUDA.
//
// Takes the same inputs as the forward pass plus grad_output (the gradient of
// the forward result) and returns {grad_value, grad_sampling_loc,
// grad_attn_weight}, each shaped like the corresponding input.
//
// Fix: the divisibility assertion message had the relationship backwards —
// the condition checks that im2col_step divides batch, not the other way round.
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const at::Tensor &grad_output,
    const int im2col_step)
{

    AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
    AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
    AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
    AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
    AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");

    AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
    AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
    AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
    AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
    AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");

    const int batch = value.size(0);
    const int spatial_size = value.size(1);
    const int num_heads = value.size(2);
    const int channels = value.size(3);

    const int num_levels = spatial_shapes.size(0);

    const int num_query = sampling_loc.size(1);
    const int num_point = sampling_loc.size(4);

    const int im2col_step_ = std::min(batch, im2col_step);

    // The batch is processed in chunks of im2col_step_ samples, so the chunk
    // size must divide the batch evenly.
    AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must be divisible by im2col_step(%d)", batch, im2col_step_);

    // Gradients are accumulated into zero-initialized tensors shaped like the inputs.
    auto grad_value = at::zeros_like(value);
    auto grad_sampling_loc = at::zeros_like(sampling_loc);
    auto grad_attn_weight = at::zeros_like(attn_weight);

    const int batch_n = im2col_step_;
    // Per-sample strides used to advance the raw data pointers chunk by chunk.
    auto per_value_size = spatial_size * num_heads * channels;
    auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
    auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
    auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});

    for (int n = 0; n < batch/im2col_step_; ++n)
    {
        // Gradient slice for this chunk of the batch.
        auto grad_output_g = grad_output_n.select(0, n);
        AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
            ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
                grad_output_g.data<scalar_t>(),
                value.data<scalar_t>() + n * im2col_step_ * per_value_size,
                spatial_shapes.data<int64_t>(),
                level_start_index.data<int64_t>(),
                sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
                attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
                batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
                grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
                grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
                grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);

        }));
    }

    return {
        grad_value, grad_sampling_loc, grad_attn_weight
    };
}
.venv/lib/python3.11/site-packages/transformers/kernels/deta/cuda/ms_deform_attn_cuda.cuh
ADDED
|
@@ -0,0 +1,1467 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*!
|
| 2 |
+
**************************************************************************************************
|
| 3 |
+
* Deformable DETR
|
| 4 |
+
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 5 |
+
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 6 |
+
**************************************************************************************************
|
| 7 |
+
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
| 8 |
+
**************************************************************************************************
|
| 9 |
+
*/
|
| 10 |
+
|
| 11 |
+
#include <vector>
|
| 12 |
+
|
| 13 |
+
#include <cuda.h>
|
| 14 |
+
#include <cuda_runtime.h>
|
| 15 |
+
|
| 16 |
+
#include <cstdio>
|
| 17 |
+
#include <algorithm>
|
| 18 |
+
#include <cstring>
|
| 19 |
+
|
| 20 |
+
#include <ATen/ATen.h>
|
| 21 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 22 |
+
|
| 23 |
+
#include <THC/THCAtomics.cuh>
|
| 24 |
+
|
| 25 |
+
#define CUDA_KERNEL_LOOP(i, n) \
|
| 26 |
+
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
|
| 27 |
+
i < (n); \
|
| 28 |
+
i += blockDim.x * gridDim.x)
|
| 29 |
+
|
| 30 |
+
|
// Forward pass of multi-scale deformable attention on CUDA (header variant;
// mirrors ms_deform_attn_cuda.cu).
//
// value:             (batch, spatial_size, num_heads, channels) flattened features
// spatial_shapes:    int64 tensor, one (H, W) row per feature level
// level_start_index: int64 offsets of each level inside `value`
// sampling_loc:      sampling coordinates; dim 1 is num_query, dim 4 is num_point
// attn_weight:       attention weights with matching layout
// im2col_step:       number of batch samples processed per kernel launch
//
// Returns a (batch, num_query, num_heads*channels) tensor.
//
// Fix: the divisibility assertion message had the relationship backwards —
// the condition checks that im2col_step divides batch, not the other way round.
at::Tensor ms_deform_attn_cuda_forward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const int im2col_step)
{
    AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
    AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
    AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
    AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");

    AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
    AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
    AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
    AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");

    const int batch = value.size(0);
    const int spatial_size = value.size(1);
    const int num_heads = value.size(2);
    const int channels = value.size(3);

    const int num_levels = spatial_shapes.size(0);

    const int num_query = sampling_loc.size(1);
    const int num_point = sampling_loc.size(4);

    const int im2col_step_ = std::min(batch, im2col_step);

    // The batch is processed in chunks of im2col_step_ samples, so the chunk
    // size must divide the batch evenly.
    AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must be divisible by im2col_step(%d)", batch, im2col_step_);

    auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());

    const int batch_n = im2col_step_;
    auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
    // Per-sample strides used to advance the raw data pointers chunk by chunk.
    auto per_value_size = spatial_size * num_heads * channels;
    auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
    auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
    for (int n = 0; n < batch/im2col_step_; ++n)
    {
        // View of the output slice this chunk writes into.
        auto columns = output_n.select(0, n);
        AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
            ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
                value.data<scalar_t>() + n * im2col_step_ * per_value_size,
                spatial_shapes.data<int64_t>(),
                level_start_index.data<int64_t>(),
                sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
                attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
                batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
                columns.data<scalar_t>());

        }));
    }

    // Collapse heads and channels into the final feature dimension.
    output = output.view({batch, num_query, num_heads*channels});

    return output;
}
| 92 |
+
|
| 93 |
+
|
| 94 |
+
// Host-side launcher for the backward pass of multi-scale deformable attention.
// Mirrors the forward launcher: iterates over batch chunks of size
// im2col_step_ and calls the col2im kernel dispatcher
// (ms_deformable_col2im_cuda, defined elsewhere in this file), accumulating
// into zero-initialized gradient tensors.
//
// grad_output: upstream gradient w.r.t. the attention output; viewed below as
//              (batch/im2col_step_, batch_n, num_query, num_heads, channels)
// Remaining inputs: same meaning/layout as in ms_deform_attn_cuda_forward.
// Returns {grad_value, grad_sampling_loc, grad_attn_weight}.
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
  const at::Tensor &value,
  const at::Tensor &spatial_shapes,
  const at::Tensor &level_start_index,
  const at::Tensor &sampling_loc,
  const at::Tensor &attn_weight,
  const at::Tensor &grad_output,
  const int im2col_step)
{

  // Raw pointer arithmetic below requires contiguous CUDA tensors throughout.
  AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
  AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
  AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
  AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
  AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
  AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");

  AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
  AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
  AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
  AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
  AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
  AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");

  const int batch = value.size(0);
  const int spatial_size = value.size(1);
  const int num_heads = value.size(2);
  const int channels = value.size(3);

  const int num_levels = spatial_shapes.size(0);

  const int num_query = sampling_loc.size(1);
  const int num_point = sampling_loc.size(4);

  const int im2col_step_ = std::min(batch, im2col_step);

  // NOTE(review): message wording is reversed relative to the check (the check
  // requires im2col_step_ to divide batch).
  AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);

  // Gradients start at zero; the kernels accumulate into them.
  auto grad_value = at::zeros_like(value);
  auto grad_sampling_loc = at::zeros_like(sampling_loc);
  auto grad_attn_weight = at::zeros_like(attn_weight);

  const int batch_n = im2col_step_;
  // Per-batch-sample element counts for chunkwise pointer offsets.
  auto per_value_size = spatial_size * num_heads * channels;
  auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
  auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
  auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});

  for (int n = 0; n < batch/im2col_step_; ++n)
  {
    auto grad_output_g = grad_output_n.select(0, n);
    AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
      ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
        grad_output_g.data<scalar_t>(),
        value.data<scalar_t>() + n * im2col_step_ * per_value_size,
        spatial_shapes.data<int64_t>(),
        level_start_index.data<int64_t>(),
        sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
        attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
        batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
        grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
        grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
        grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);

    }));
  }

  return {
    grad_value, grad_sampling_loc, grad_attn_weight
  };
}
|
| 165 |
+
|
| 166 |
+
// Default number of CUDA threads per block used by the kernel launchers.
const int CUDA_NUM_THREADS = 1024;

// Ceiling division: the number of blocks of `num_threads` threads needed to
// cover N work items.
inline int GET_BLOCKS(const int N, const int num_threads)
{
  const int rounded_up = N + num_threads - 1;
  return rounded_up / num_threads;
}
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
// Bilinearly interpolate one channel of one head from a feature map at the
// (possibly fractional) location (h, w).
//
// bottom_data is indexed as [height][width][nheads][channels] (see the
// w_stride/h_stride computation). Corners that fall outside [0, height) x
// [0, width) contribute 0, which clamps the interpolation at the borders.
template <typename scalar_t>
__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
                                                   const int &height, const int &width, const int &nheads, const int &channels,
                                                   const scalar_t &h, const scalar_t &w, const int &m, const int &c)
{
  // Integer corners surrounding (h, w).
  const int h_low = floor(h);
  const int w_low = floor(w);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;

  // Fractional offsets (l* = distance from the low corner, h* = from the high).
  const scalar_t lh = h - h_low;
  const scalar_t lw = w - w_low;
  const scalar_t hh = 1 - lh, hw = 1 - lw;

  // Flattened strides for the [height][width][nheads][channels] layout.
  const int w_stride = nheads * channels;
  const int h_stride = width * w_stride;
  const int h_low_ptr_offset = h_low * h_stride;
  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
  const int w_low_ptr_offset = w_low * w_stride;
  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
  const int base_ptr = m * channels + c;

  // Fetch the four corner values, substituting 0 for out-of-bounds corners.
  scalar_t v1 = 0;
  if (h_low >= 0 && w_low >= 0)
  {
    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
    v1 = bottom_data[ptr1];
  }
  scalar_t v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
  {
    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
    v2 = bottom_data[ptr2];
  }
  scalar_t v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
  {
    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
    v3 = bottom_data[ptr3];
  }
  scalar_t v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
  {
    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
    v4 = bottom_data[ptr4];
  }

  // Standard bilinear weights for the four corners.
  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;

  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  return val;
}
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
// Backward counterpart of ms_deform_attn_im2col_bilinear for a single sampled
// point: given the upstream gradient `top_grad` and this point's attention
// weight, it
//   * atomically accumulates the value gradient into grad_value (shared
//     across threads/blocks), and
//   * writes (plain, non-atomic stores) the gradients w.r.t. the sampling
//     location and the attention weight to grad_sampling_loc /
//     grad_attn_weight — so the caller must hand each invocation an exclusive
//     slot (in this file the callers pass per-thread shared-memory slots).
//
// The location gradient is scaled by width/height because the sampling
// locations are normalized before being mapped to pixel coordinates (see the
// `loc_* * spatial_* - 0.5` computation in the kernels below).
template <typename scalar_t>
__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
                                               const int &height, const int &width, const int &nheads, const int &channels,
                                               const scalar_t &h, const scalar_t &w, const int &m, const int &c,
                                               const scalar_t &top_grad,
                                               const scalar_t &attn_weight,
                                               scalar_t* &grad_value,
                                               scalar_t* grad_sampling_loc,
                                               scalar_t* grad_attn_weight)
{
  // Same corner/stride setup as the forward bilinear sampler.
  const int h_low = floor(h);
  const int w_low = floor(w);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;

  const scalar_t lh = h - h_low;
  const scalar_t lw = w - w_low;
  const scalar_t hh = 1 - lh, hw = 1 - lw;

  const int w_stride = nheads * channels;
  const int h_stride = width * w_stride;
  const int h_low_ptr_offset = h_low * h_stride;
  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
  const int w_low_ptr_offset = w_low * w_stride;
  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
  const int base_ptr = m * channels + c;

  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
  // Chain rule: gradient flowing into the sampled value.
  const scalar_t top_grad_value = top_grad * attn_weight;
  // Partial derivatives of the interpolated value w.r.t. h and w, accumulated
  // corner by corner below.
  scalar_t grad_h_weight = 0, grad_w_weight = 0;

  scalar_t v1 = 0;
  if (h_low >= 0 && w_low >= 0)
  {
    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
    v1 = bottom_data[ptr1];
    grad_h_weight -= hw * v1;
    grad_w_weight -= hh * v1;
    atomicAdd(grad_value+ptr1, w1*top_grad_value);
  }
  scalar_t v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
  {
    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
    v2 = bottom_data[ptr2];
    grad_h_weight -= lw * v2;
    grad_w_weight += hh * v2;
    atomicAdd(grad_value+ptr2, w2*top_grad_value);
  }
  scalar_t v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
  {
    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
    v3 = bottom_data[ptr3];
    grad_h_weight += hw * v3;
    grad_w_weight -= lh * v3;
    atomicAdd(grad_value+ptr3, w3*top_grad_value);
  }
  scalar_t v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
  {
    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
    v4 = bottom_data[ptr4];
    grad_h_weight += lw * v4;
    grad_w_weight += lh * v4;
    atomicAdd(grad_value+ptr4, w4*top_grad_value);
  }

  // d(output)/d(attn_weight) is simply the interpolated value.
  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  *grad_attn_weight = top_grad * val;
  // Scale back to normalized-coordinate space; layout is (w, h) pairs.
  *grad_sampling_loc = width * grad_w_weight * top_grad_value;
  *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
}
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
// Variant of ms_deform_attn_col2im_bilinear that accumulates ALL gradients
// with atomicAdd — including grad_sampling_loc and grad_attn_weight — so the
// output slots may be shared between threads ("gm" presumably = global
// memory, i.e. no shared-memory reduction; confirm against the caller).
// The interpolation and per-corner gradient math are identical to
// ms_deform_attn_col2im_bilinear above.
template <typename scalar_t>
__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
                                                  const int &height, const int &width, const int &nheads, const int &channels,
                                                  const scalar_t &h, const scalar_t &w, const int &m, const int &c,
                                                  const scalar_t &top_grad,
                                                  const scalar_t &attn_weight,
                                                  scalar_t* &grad_value,
                                                  scalar_t* grad_sampling_loc,
                                                  scalar_t* grad_attn_weight)
{
  // Same corner/stride setup as the forward bilinear sampler.
  const int h_low = floor(h);
  const int w_low = floor(w);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;

  const scalar_t lh = h - h_low;
  const scalar_t lw = w - w_low;
  const scalar_t hh = 1 - lh, hw = 1 - lw;

  const int w_stride = nheads * channels;
  const int h_stride = width * w_stride;
  const int h_low_ptr_offset = h_low * h_stride;
  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
  const int w_low_ptr_offset = w_low * w_stride;
  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
  const int base_ptr = m * channels + c;

  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
  const scalar_t top_grad_value = top_grad * attn_weight;
  scalar_t grad_h_weight = 0, grad_w_weight = 0;

  scalar_t v1 = 0;
  if (h_low >= 0 && w_low >= 0)
  {
    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
    v1 = bottom_data[ptr1];
    grad_h_weight -= hw * v1;
    grad_w_weight -= hh * v1;
    atomicAdd(grad_value+ptr1, w1*top_grad_value);
  }
  scalar_t v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
  {
    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
    v2 = bottom_data[ptr2];
    grad_h_weight -= lw * v2;
    grad_w_weight += hh * v2;
    atomicAdd(grad_value+ptr2, w2*top_grad_value);
  }
  scalar_t v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
  {
    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
    v3 = bottom_data[ptr3];
    grad_h_weight += hw * v3;
    grad_w_weight -= lh * v3;
    atomicAdd(grad_value+ptr3, w3*top_grad_value);
  }
  scalar_t v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
  {
    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
    v4 = bottom_data[ptr4];
    grad_h_weight += lw * v4;
    grad_w_weight += lh * v4;
    atomicAdd(grad_value+ptr4, w4*top_grad_value);
  }

  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  // Unlike the non-gm variant, these accumulate atomically into shared slots.
  atomicAdd(grad_attn_weight, top_grad * val);
  atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
  atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
}
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
// Forward kernel: each thread computes one output element
// (batch b, query q, head m, channel c). The thread sums, over all levels and
// sampling points, the bilinearly interpolated value weighted by the
// corresponding attention weight, and writes the result to data_col.
// CUDA_KERNEL_LOOP is a grid-stride loop over n = total output elements.
template <typename scalar_t>
__global__ void ms_deformable_im2col_gpu_kernel(const int n,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *data_col)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Decode the flat index into (b_col, q_col, m_col, c_col); channel varies
    // fastest. sampling_index identifies the (b, q, m) triple shared by all
    // channels of this output element.
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    scalar_t *data_col_ptr = data_col + index;
    // Each (b, q, m) owns num_levels * num_point weights; locations come in
    // (w, h) pairs, hence the << 1.
    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
    scalar_t col = 0;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      // Per-level (h, w) extent and start offset into the flattened value map.
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Map normalized location to pixel space; -0.5 aligns to pixel centers.
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;

        // Points entirely outside the map contribute nothing.
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
        }

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
      }
    }
    *data_col_ptr = col;
  }
}
|
| 440 |
+
|
| 441 |
+
// Backward kernel, shared-memory variant v1 with compile-time block size:
// each thread handles one output element (b, q, m, c) and, per sampling
// point, writes its partial location/weight gradients into a per-thread
// shared-memory slot; thread 0 then serially sums the slots of all blockSize
// threads and stores the block's totals. Value gradients are accumulated
// directly into global memory via atomicAdd inside
// ms_deform_attn_col2im_bilinear.
// Presumably launched with blockDim.x == blockSize == channels so one block
// covers all channels of a single (b, q, m) — confirm against the launcher
// (the single non-atomic store per point by tid 0 only makes sense then).
template <typename scalar_t, unsigned int blockSize>
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Statically-sized caches: one (w, h) pair and one weight slot per thread.
    __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
    __shared__ scalar_t cache_grad_attn_weight[blockSize];
    unsigned int tid = threadIdx.x;
    // Decode flat index into (b_col, q_col, m_col, c_col); channel fastest.
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    // Advance the gradient output pointers to this (b, q, m)'s first point.
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Normalized -> pixel coordinates (center-aligned).
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Zero this thread's cache slots so out-of-range points contribute 0
        // to the reduction below.
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();
        // Serial reduction across the block's threads, done by thread 0 only.
        if (tid == 0)
        {
          scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
          int sid=2;
          // NOTE(review): this loop variable shadows the outer `tid`; it works
          // only because the enclosing branch is tid == 0.
          for (unsigned int tid = 1; tid < blockSize; ++tid)
          {
            _grad_w += cache_grad_sampling_loc[sid];
            _grad_h += cache_grad_sampling_loc[sid + 1];
            _grad_a += cache_grad_attn_weight[tid];
            sid += 2;
          }


          *grad_sampling_loc = _grad_w;
          *(grad_sampling_loc + 1) = _grad_h;
          *grad_attn_weight = _grad_a;
        }
        // Keep all threads in step before the caches are reused next point.
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
// Backward kernel, shared-memory variant v2 with compile-time block size:
// identical per-thread gradient computation to v1, but the block-wide sum of
// the location/weight gradients uses a parallel tree reduction (successive
// halving of the active thread range) instead of a serial loop by thread 0.
// The halving scheme only covers all slots when blockSize is a power of two —
// presumably guaranteed by the launcher; confirm.
template <typename scalar_t, unsigned int blockSize>
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Statically-sized caches: one (w, h) pair and one weight slot per thread.
    __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
    __shared__ scalar_t cache_grad_attn_weight[blockSize];
    unsigned int tid = threadIdx.x;
    // Decode flat index into (b_col, q_col, m_col, c_col); channel fastest.
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    // Advance the gradient output pointers to this (b, q, m)'s first point.
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Normalized -> pixel coordinates (center-aligned).
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Zero this thread's cache slots so out-of-range points contribute 0.
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();

        // Tree reduction: at each step the lower half of active threads folds
        // in the upper half's partial sums.
        for (unsigned int s=blockSize/2; s>0; s>>=1)
        {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
          }
          __syncthreads();
        }

        // After the reduction, slot 0 holds the block-wide totals.
        if (tid == 0)
        {
          *grad_sampling_loc = cache_grad_sampling_loc[0];
          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
          *grad_attn_weight = cache_grad_attn_weight[0];
        }
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
|
| 651 |
+
|
| 652 |
+
|
| 653 |
+
// Backward kernel, dynamic-shared-memory variant v1: same algorithm as the
// blocksize-aware v1 kernel (per-thread shared-memory slots, serial reduction
// by thread 0), but the caches live in dynamically allocated shared memory
// (`extern __shared__`) sized by the launcher for blockDim.x threads, so the
// block size need not be a compile-time constant.
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Carve the dynamic shared buffer into: 2*blockDim.x location slots,
    // then blockDim.x attention-weight slots.
    extern __shared__ int _s[];
    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;
    // Decode flat index into (b_col, q_col, m_col, c_col); channel fastest.
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    // Advance the gradient output pointers to this (b, q, m)'s first point.
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Normalized -> pixel coordinates (center-aligned).
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Zero this thread's cache slots so out-of-range points contribute 0.
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();
        // Serial reduction over all blockDim.x slots, done by thread 0 only.
        if (tid == 0)
        {
          scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
          int sid=2;
          // NOTE(review): the loop variable shadows the outer `tid`; harmless
          // here because only thread 0 executes this branch.
          for (unsigned int tid = 1; tid < blockDim.x; ++tid)
          {
            _grad_w += cache_grad_sampling_loc[sid];
            _grad_h += cache_grad_sampling_loc[sid + 1];
            _grad_a += cache_grad_attn_weight[tid];
            sid += 2;
          }


          *grad_sampling_loc = _grad_w;
          *(grad_sampling_loc + 1) = _grad_h;
          *grad_attn_weight = _grad_a;
        }
        // Keep threads in step before the caches are reused for the next point.
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
|
| 757 |
+
|
| 758 |
+
// Backward ("col2im") kernel for multi-scale deformable attention.
// Each thread owns one (batch, query, head, channel) output element; the
// per-thread location/weight gradients are accumulated in dynamic shared
// memory and combined with a tree reduction, after which thread 0 publishes
// the block-wide sums with plain stores (no atomics) -- this assumes each
// gradient slot is written by exactly one block; see the *_multi_blocks
// variant for the atomic counterpart.
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
                                                const scalar_t *grad_col,              // gradient w.r.t. attention output, one scalar per (b, q, m, c)
                                                const scalar_t *data_value,            // flattened multi-level value tensor
                                                const int64_t *data_spatial_shapes,    // (h, w) pair per level
                                                const int64_t *data_level_start_index, // start offset of each level in the flattened spatial axis
                                                const scalar_t *data_sampling_loc,     // normalized sampling locations, stored as (w, h) pairs
                                                const scalar_t *data_attn_weight,      // attention weight per sampling point
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,          // out: gradient w.r.t. data_value
                                                scalar_t *grad_sampling_loc,   // out: gradient w.r.t. sampling locations
                                                scalar_t *grad_attn_weight)    // out: gradient w.r.t. attention weights
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Dynamic shared memory layout (3 * blockDim.x scalars total):
    //   [0, 2*blockDim.x)            per-thread (w, h) location gradients
    //   [2*blockDim.x, 3*blockDim.x) per-thread attention-weight gradients
    extern __shared__ int _s[];
    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;
    // Decompose the flat index into (batch, query, head, channel).
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    // Read/write cursors for this (query, head); locations are (w, h) pairs,
    // hence the << 1 scaling on location pointers.
    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Map the normalized location into pixel coordinates.
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Zero this thread's shared slots first so an out-of-range sample
        // contributes nothing to the reduction below.
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          // Bilinear backward: scatters into grad_value and accumulates the
          // location / weight gradients into this thread's shared slots.
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();

        // Shared-memory tree reduction across the block. `spre` is the width
        // of the previous round; the `tid + (s << 1) < spre` branch folds in
        // the leftover element when that width was odd.
        for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
        {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
            if (tid + (s << 1) < spre)
            {
              cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
              cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
              cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
            }
          }
          __syncthreads();
        }

        // Thread 0 publishes the block-wide sums with plain stores.
        if (tid == 0)
        {
          *grad_sampling_loc = cache_grad_sampling_loc[0];
          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
          *grad_attn_weight = cache_grad_attn_weight[0];
        }
        __syncthreads();

        // Advance the cursors to the next sampling point.
        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
|
| 870 |
+
|
| 871 |
+
// Backward ("col2im") kernel for multi-scale deformable attention.
// Same shared-memory tree reduction as ms_deformable_col2im_gpu_kernel_shm_reduce_v2,
// but thread 0 publishes the block result with atomicAdd, so several blocks
// may safely accumulate into the same gradient slot (per the kernel name,
// intended for the case where one query/head spans multiple blocks).
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
                                                const scalar_t *grad_col,              // gradient w.r.t. attention output, one scalar per (b, q, m, c)
                                                const scalar_t *data_value,            // flattened multi-level value tensor
                                                const int64_t *data_spatial_shapes,    // (h, w) pair per level
                                                const int64_t *data_level_start_index, // start offset of each level in the flattened spatial axis
                                                const scalar_t *data_sampling_loc,     // normalized sampling locations, stored as (w, h) pairs
                                                const scalar_t *data_attn_weight,      // attention weight per sampling point
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,          // out: gradient w.r.t. data_value
                                                scalar_t *grad_sampling_loc,   // out: gradient w.r.t. sampling locations
                                                scalar_t *grad_attn_weight)    // out: gradient w.r.t. attention weights
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Dynamic shared memory layout (3 * blockDim.x scalars total):
    //   [0, 2*blockDim.x)            per-thread (w, h) location gradients
    //   [2*blockDim.x, 3*blockDim.x) per-thread attention-weight gradients
    extern __shared__ int _s[];
    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;
    // Decompose the flat index into (batch, query, head, channel).
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    // Read/write cursors for this (query, head); locations are (w, h) pairs.
    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Map the normalized location into pixel coordinates.
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Zero this thread's shared slots so out-of-range samples add nothing.
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          // Bilinear backward: scatters into grad_value and accumulates the
          // location / weight gradients into this thread's shared slots.
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();

        // Shared-memory tree reduction across the block. `spre` is the width
        // of the previous round; the `tid + (s << 1) < spre` branch folds in
        // the leftover element when that width was odd.
        for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
        {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
            if (tid + (s << 1) < spre)
            {
              cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
              cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
              cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
            }
          }
          __syncthreads();
        }

        // Thread 0 publishes the block-wide sums atomically: other blocks may
        // be accumulating into the same slots.
        if (tid == 0)
        {
          atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
          atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
          atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
        }
        __syncthreads();

        // Advance the cursors to the next sampling point.
        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
|
| 983 |
+
|
| 984 |
+
|
| 985 |
+
// Backward ("col2im") kernel for multi-scale deformable attention,
// global-memory variant: no shared memory and no block-level reduction.
// Each thread hands its gradient contributions straight to
// ms_deform_attn_col2im_bilinear_gm, which writes them to the global
// gradient buffers (presumably via atomics, since many threads share the
// same (query, head, point) slot -- confirm in the *_gm helper).
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
                                                const scalar_t *grad_col,              // gradient w.r.t. attention output, one scalar per (b, q, m, c)
                                                const scalar_t *data_value,            // flattened multi-level value tensor
                                                const int64_t *data_spatial_shapes,    // (h, w) pair per level
                                                const int64_t *data_level_start_index, // start offset of each level in the flattened spatial axis
                                                const scalar_t *data_sampling_loc,     // normalized sampling locations, stored as (w, h) pairs
                                                const scalar_t *data_attn_weight,      // attention weight per sampling point
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,          // out: gradient w.r.t. data_value
                                                scalar_t *grad_sampling_loc,   // out: gradient w.r.t. sampling locations
                                                scalar_t *grad_attn_weight)    // out: gradient w.r.t. attention weights
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Decompose the flat index into (batch, query, head, channel).
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    // Read/write cursors for this (query, head); locations are (w, h) pairs.
    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Map the normalized location into pixel coordinates; out-of-range
        // samples contribute no gradient.
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          // Bilinear backward writing directly to the global gradient buffers.
          ms_deform_attn_col2im_bilinear_gm(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            grad_sampling_loc, grad_attn_weight);
        }
        // Advance the cursors to the next sampling point.
        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
|
| 1061 |
+
|
| 1062 |
+
|
| 1063 |
+
// Host-side launcher for the forward ("im2col") pass of multi-scale
// deformable attention: one GPU thread per (batch, query, head, channel)
// output element. Any launch failure is reported via printf; no exception
// is raised.
template <typename scalar_t>
void ms_deformable_im2col_cuda(cudaStream_t stream,
                              const scalar_t* data_value,
                              const int64_t* data_spatial_shapes,
                              const int64_t* data_level_start_index,
                              const scalar_t* data_sampling_loc,
                              const scalar_t* data_attn_weight,
                              const int batch_size,
                              const int spatial_size,
                              const int num_heads,
                              const int channels,
                              const int num_levels,
                              const int num_query,
                              const int num_point,
                              scalar_t* data_col)
{
  // Total number of output elements = grid-wide work-item count.
  const int total = batch_size * num_query * num_heads * channels;
  const int threads_per_block = CUDA_NUM_THREADS;

  ms_deformable_im2col_gpu_kernel<scalar_t>
      <<<GET_BLOCKS(total, threads_per_block), threads_per_block,
         0, stream>>>(
          total, data_value, data_spatial_shapes, data_level_start_index,
          data_sampling_loc, data_attn_weight, batch_size, spatial_size,
          num_heads, channels, num_levels, num_query, num_point, data_col);

  // Surface asynchronous launch errors immediately.
  const cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
  }
}
|
| 1095 |
+
|
| 1096 |
+
// Host-side dispatcher for the backward ("col2im") pass of multi-scale
// deformable attention. Picks the fastest kernel variant based on `channels`
// (which also fixes the block size, capped at CUDA_NUM_THREADS):
//   * channels > 1024 and a multiple of 1024 -> shared-memory reduction with
//     atomic cross-block accumulation (shm_reduce_v2_multi_blocks);
//   * channels > 1024 otherwise               -> plain global-memory kernel;
//   * power-of-two channels 1..32             -> compile-time block-size
//     reduction v1;
//   * power-of-two channels 64..1024          -> compile-time block-size
//     reduction v2;
//   * other channels < 64                     -> runtime shared-memory
//     reduction v1; otherwise v2.
// Launch failures are reported via printf; no exception is raised.
template <typename scalar_t>
void ms_deformable_col2im_cuda(cudaStream_t stream,
                              const scalar_t* grad_col,
                              const scalar_t* data_value,
                              const int64_t * data_spatial_shapes,
                              const int64_t * data_level_start_index,
                              const scalar_t * data_sampling_loc,
                              const scalar_t * data_attn_weight,
                              const int batch_size,
                              const int spatial_size,
                              const int num_heads,
                              const int channels,
                              const int num_levels,
                              const int num_query,
                              const int num_point,
                              scalar_t* grad_value,
                              scalar_t* grad_sampling_loc,
                              scalar_t* grad_attn_weight)
{
  const int num_threads = (channels > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : channels;
  const int num_kernels = batch_size * num_query * num_heads * channels;
  const int num_actual_kernels = batch_size * num_query * num_heads * channels;

// Every backward kernel takes exactly the same argument list; name it once.
#define MSDA_COL2IM_ARGS                                                       \
  num_kernels, grad_col, data_value, data_spatial_shapes,                      \
      data_level_start_index, data_sampling_loc, data_attn_weight,             \
      batch_size, spatial_size, num_heads, channels, num_levels, num_query,    \
      num_point, grad_value, grad_sampling_loc, grad_attn_weight

// A switch case launching a compile-time-block-size kernel with no dynamic
// shared memory.
#define MSDA_COL2IM_CASE(N, KERNEL)                                            \
  case N:                                                                      \
    KERNEL<scalar_t, N>                                                        \
        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,           \
           0, stream>>>(MSDA_COL2IM_ARGS);                                     \
    break;

  if (channels > 1024)
  {
    if ((channels & 1023) == 0)
    {
      // channels is a multiple of 1024: several blocks cooperate on one
      // gradient slot, so use the atomicAdd-publishing variant.
      ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
             num_threads * 3 * sizeof(scalar_t), stream>>>(MSDA_COL2IM_ARGS);
    }
    else
    {
      // Irregular large channel count: fall back to the global-memory kernel.
      ms_deformable_col2im_gpu_kernel_gm<scalar_t>
          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
             0, stream>>>(MSDA_COL2IM_ARGS);
    }
  }
  else
  {
    switch (channels)
    {
      MSDA_COL2IM_CASE(1, ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1)
      MSDA_COL2IM_CASE(2, ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1)
      MSDA_COL2IM_CASE(4, ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1)
      MSDA_COL2IM_CASE(8, ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1)
      MSDA_COL2IM_CASE(16, ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1)
      MSDA_COL2IM_CASE(32, ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1)
      MSDA_COL2IM_CASE(64, ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2)
      MSDA_COL2IM_CASE(128, ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2)
      MSDA_COL2IM_CASE(256, ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2)
      MSDA_COL2IM_CASE(512, ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2)
      MSDA_COL2IM_CASE(1024, ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2)
      default:
        // Non-power-of-two channel count: runtime shared-memory reduction,
        // sized at 3 scalars per thread ((w, h) location + weight gradients).
        if (channels < 64)
        {
          ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
              <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
                 num_threads * 3 * sizeof(scalar_t), stream>>>(MSDA_COL2IM_ARGS);
        }
        else
        {
          ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
              <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
                 num_threads * 3 * sizeof(scalar_t), stream>>>(MSDA_COL2IM_ARGS);
        }
    }
  }

#undef MSDA_COL2IM_CASE
#undef MSDA_COL2IM_ARGS

  // Surface asynchronous launch errors immediately.
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
  }
}
|
.venv/lib/python3.11/site-packages/transformers/kernels/deta/cuda/ms_deform_attn_cuda.h
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*!
|
| 2 |
+
**************************************************************************************************
|
| 3 |
+
* Deformable DETR
|
| 4 |
+
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 5 |
+
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 6 |
+
**************************************************************************************************
|
| 7 |
+
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
| 8 |
+
**************************************************************************************************
|
| 9 |
+
*/
|
| 10 |
+
|
| 11 |
+
#pragma once
|
| 12 |
+
#include <torch/extension.h>
|
| 13 |
+
|
| 14 |
+
// Forward pass of multi-scale deformable attention on CUDA.
// Returns the attention output tensor; `im2col_step` batches the im2col
// computation over groups of samples.
at::Tensor ms_deform_attn_cuda_forward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const int im2col_step);

// Backward pass of multi-scale deformable attention on CUDA.
// Returns the gradients as a vector of tensors -- order/contents defined by
// the implementation in the corresponding .cu file (verify there).
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const at::Tensor &grad_output,
    const int im2col_step);
|
.venv/lib/python3.11/site-packages/transformers/kernels/deta/cuda/ms_deform_im2col_cuda.cuh
ADDED
|
@@ -0,0 +1,1327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*!
|
| 2 |
+
**************************************************************************
|
| 3 |
+
* Deformable DETR
|
| 4 |
+
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 5 |
+
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 6 |
+
**************************************************************************
|
| 7 |
+
* Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
|
| 8 |
+
* Copyright (c) 2018 Microsoft
|
| 9 |
+
**************************************************************************
|
| 10 |
+
*/
|
| 11 |
+
|
| 12 |
+
#include <cstdio>
|
| 13 |
+
#include <algorithm>
|
| 14 |
+
#include <cstring>
|
| 15 |
+
|
| 16 |
+
#include <ATen/ATen.h>
|
| 17 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 18 |
+
|
| 19 |
+
#include <THC/THCAtomics.cuh>
|
| 20 |
+
|
| 21 |
+
// Grid-stride loop: each thread starts at its global id and strides by the
// total number of launched threads, so any grid size covers all n elements.
#define CUDA_KERNEL_LOOP(i, n)                          \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x;   \
      i < (n);                                          \
      i += blockDim.x * gridDim.x)
|
| 25 |
+
|
| 26 |
+
// Default number of threads per CUDA block.
const int CUDA_NUM_THREADS = 1024;

// Smallest grid size whose blocks of `num_threads` threads cover all N
// elements (ceiling division).
inline int GET_BLOCKS(const int N, const int num_threads)
{
  const int rounded_up = N + num_threads - 1;
  return rounded_up / num_threads;
}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
// Bilinearly interpolate one value from a feature level laid out as
// (H, W, nheads, channels). (h, w) is the sampling position in pixel
// coordinates; (m, c) select head and channel. Corners that fall outside
// the feature map contribute zero (implicit zero padding).
template <typename scalar_t>
__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
                                                   const int &height, const int &width, const int &nheads, const int &channels,
                                                   const scalar_t &h, const scalar_t &w, const int &m, const int &c)
{
  // Integer corners surrounding the sampling point.
  const int y0 = floor(h);
  const int x0 = floor(w);
  const int y1 = y0 + 1;
  const int x1 = x0 + 1;

  // Fractional offsets and their complements.
  const scalar_t dy = h - y0;
  const scalar_t dx = w - x0;
  const scalar_t ry = 1 - dy, rx = 1 - dx;

  // Flattened strides of the (H, W, nheads, channels) layout.
  const int stride_x = nheads * channels;
  const int stride_y = width * stride_x;
  const int off_y0 = y0 * stride_y;
  const int off_y1 = off_y0 + stride_y;
  const int off_x0 = x0 * stride_x;
  const int off_x1 = off_x0 + stride_x;
  const int off_mc = m * channels + c;

  // Load each corner only when it lies inside the feature map.
  scalar_t v1 = 0;
  if (y0 >= 0 && x0 >= 0)
  {
    v1 = bottom_data[off_y0 + off_x0 + off_mc];
  }
  scalar_t v2 = 0;
  if (y0 >= 0 && x1 <= width - 1)
  {
    v2 = bottom_data[off_y0 + off_x1 + off_mc];
  }
  scalar_t v3 = 0;
  if (y1 <= height - 1 && x0 >= 0)
  {
    v3 = bottom_data[off_y1 + off_x0 + off_mc];
  }
  scalar_t v4 = 0;
  if (y1 <= height - 1 && x1 <= width - 1)
  {
    v4 = bottom_data[off_y1 + off_x1 + off_mc];
  }

  // Bilinear weights and the weighted sum of the four corners.
  const scalar_t w1 = ry * rx, w2 = ry * dx, w3 = dy * rx, w4 = dy * dx;
  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  return val;
}
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
// Backward counterpart of ms_deform_attn_im2col_bilinear for one sampling
// point. Scatters the upstream gradient into the value-gradient buffer via
// atomicAdd and writes the gradients w.r.t. the sampling location and the
// attention weight through the provided output pointers.
//
// Outputs:
//  * grad_value                 — accumulated with atomicAdd (many threads
//                                 may target the same value element);
//  * grad_sampling_loc (w, h)   — written with plain stores;
//  * grad_attn_weight           — written with a plain store.
// NOTE(review): the plain stores mean each output slot must be private to
// the calling thread — the shared-memory kernels pass per-thread cache
// slots here; confirm at any new call site.
template <typename scalar_t>
__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
                                               const int &height, const int &width, const int &nheads, const int &channels,
                                               const scalar_t &h, const scalar_t &w, const int &m, const int &c,
                                               const scalar_t &top_grad,
                                               const scalar_t &attn_weight,
                                               scalar_t* &grad_value,
                                               scalar_t* grad_sampling_loc,
                                               scalar_t* grad_attn_weight)
{
  // Integer corners around the sampling point and fractional offsets.
  const int h_low = floor(h);
  const int w_low = floor(w);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;

  const scalar_t lh = h - h_low;
  const scalar_t lw = w - w_low;
  const scalar_t hh = 1 - lh, hw = 1 - lw;

  // Strides of the (H, W, nheads, channels) value layout.
  const int w_stride = nheads * channels;
  const int h_stride = width * w_stride;
  const int h_low_ptr_offset = h_low * h_stride;
  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
  const int w_low_ptr_offset = w_low * w_stride;
  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
  const int base_ptr = m * channels + c;

  // Bilinear weights and the gradient flowing into the value tensor.
  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
  const scalar_t top_grad_value = top_grad * attn_weight;
  // Accumulators for d(val)/dh and d(val)/dw (unscaled).
  scalar_t grad_h_weight = 0, grad_w_weight = 0;

  // For each in-bounds corner: accumulate its contribution to the spatial
  // gradients and scatter the weighted gradient into grad_value.
  scalar_t v1 = 0;
  if (h_low >= 0 && w_low >= 0)
  {
    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
    v1 = bottom_data[ptr1];
    grad_h_weight -= hw * v1;
    grad_w_weight -= hh * v1;
    atomicAdd(grad_value+ptr1, w1*top_grad_value);
  }
  scalar_t v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
  {
    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
    v2 = bottom_data[ptr2];
    grad_h_weight -= lw * v2;
    grad_w_weight += hh * v2;
    atomicAdd(grad_value+ptr2, w2*top_grad_value);
  }
  scalar_t v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
  {
    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
    v3 = bottom_data[ptr3];
    grad_h_weight += hw * v3;
    grad_w_weight -= lh * v3;
    atomicAdd(grad_value+ptr3, w3*top_grad_value);
  }
  scalar_t v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
  {
    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
    v4 = bottom_data[ptr4];
    grad_h_weight += lw * v4;
    grad_w_weight += lh * v4;
    atomicAdd(grad_value+ptr4, w4*top_grad_value);
  }

  // Recomputed forward value; its derivative w.r.t. the attention weight is
  // the interpolated value itself.
  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  *grad_attn_weight = top_grad * val;
  // The width/height factors chain through the normalized->pixel mapping
  // used by the calling kernels (loc * size - 0.5).
  *grad_sampling_loc = width * grad_w_weight * top_grad_value;
  *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
}
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
// Same backward computation as ms_deform_attn_col2im_bilinear, but all
// three gradient outputs — grad_value, grad_sampling_loc and
// grad_attn_weight — are accumulated with atomicAdd, so multiple threads
// may safely target the same output slots. ("gm" presumably stands for
// the global-memory accumulation variant — confirm against the caller.)
template <typename scalar_t>
__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
                                                  const int &height, const int &width, const int &nheads, const int &channels,
                                                  const scalar_t &h, const scalar_t &w, const int &m, const int &c,
                                                  const scalar_t &top_grad,
                                                  const scalar_t &attn_weight,
                                                  scalar_t* &grad_value,
                                                  scalar_t* grad_sampling_loc,
                                                  scalar_t* grad_attn_weight)
{
  // Integer corners around the sampling point and fractional offsets.
  const int h_low = floor(h);
  const int w_low = floor(w);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;

  const scalar_t lh = h - h_low;
  const scalar_t lw = w - w_low;
  const scalar_t hh = 1 - lh, hw = 1 - lw;

  // Strides of the (H, W, nheads, channels) value layout.
  const int w_stride = nheads * channels;
  const int h_stride = width * w_stride;
  const int h_low_ptr_offset = h_low * h_stride;
  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
  const int w_low_ptr_offset = w_low * w_stride;
  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
  const int base_ptr = m * channels + c;

  // Bilinear weights and the gradient flowing into the value tensor.
  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
  const scalar_t top_grad_value = top_grad * attn_weight;
  // Accumulators for d(val)/dh and d(val)/dw (unscaled).
  scalar_t grad_h_weight = 0, grad_w_weight = 0;

  // For each in-bounds corner: accumulate its contribution to the spatial
  // gradients and scatter the weighted gradient into grad_value.
  scalar_t v1 = 0;
  if (h_low >= 0 && w_low >= 0)
  {
    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
    v1 = bottom_data[ptr1];
    grad_h_weight -= hw * v1;
    grad_w_weight -= hh * v1;
    atomicAdd(grad_value+ptr1, w1*top_grad_value);
  }
  scalar_t v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
  {
    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
    v2 = bottom_data[ptr2];
    grad_h_weight -= lw * v2;
    grad_w_weight += hh * v2;
    atomicAdd(grad_value+ptr2, w2*top_grad_value);
  }
  scalar_t v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
  {
    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
    v3 = bottom_data[ptr3];
    grad_h_weight += hw * v3;
    grad_w_weight -= lh * v3;
    atomicAdd(grad_value+ptr3, w3*top_grad_value);
  }
  scalar_t v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
  {
    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
    v4 = bottom_data[ptr4];
    grad_h_weight += lw * v4;
    grad_w_weight += lh * v4;
    atomicAdd(grad_value+ptr4, w4*top_grad_value);
  }

  // Recomputed forward value; atomically accumulate the location/weight
  // gradients (unlike the non-gm variant, which uses plain stores).
  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  atomicAdd(grad_attn_weight, top_grad * val);
  atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
  atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
}
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
// Forward kernel of multi-scale deformable attention (im2col form).
// Each thread computes one output element data_col[index], identified by
// (batch, query, head, channel): it sums bilinear samples over every
// feature level and sampling point, weighted by the attention weights.
template <typename scalar_t>
__global__ void ms_deformable_im2col_gpu_kernel(const int n,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *data_col)
{
  // Grid-stride loop over all n output elements.
  CUDA_KERNEL_LOOP(index, n)
  {
    // Decompose the flat index into (b_col, q_col, m_col, c_col).
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;  // flat (batch, query, head) index
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;  // (not used below)
    _temp /= num_query;
    const int b_col = _temp;

    scalar_t *data_col_ptr = data_col + index;
    // Bases into the per-(b,q,m) attention-weight and sampling-location
    // tensors; locations are stored as (w, h) pairs, hence the << 1.
    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
    scalar_t col = 0;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      // Each level occupies spatial_h * spatial_w rows of the flattened
      // value tensor, starting at data_level_start_index[l_col].
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Map normalized [0, 1] locations to pixel coordinates; the -0.5
        // presumably aligns samples with pixel centers — matches the
        // backward kernels' width/height gradient scaling.
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;

        // Samples entirely outside the feature map contribute nothing.
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
        }

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
      }
    }
    *data_col_ptr = col;
  }
}
|
| 300 |
+
|
| 301 |
+
// Backward kernel, static shared-memory variant with a compile-time block
// size (template parameter blockSize). One thread handles one grad_col
// element; per-thread location/weight gradients are staged in shared
// memory and reduced serially by thread 0 of the block.
//
// NOTE(review): thread 0 writes the reduced sums through its own
// grad_sampling_loc/grad_attn_weight pointers with plain stores, which is
// only correct if every thread of the block shares the same
// (batch, query, head) slot — i.e. the block spans exactly the channel
// dimension. Confirm against the launch configuration.
template <typename scalar_t, unsigned int blockSize>
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Static shared-memory scratch: one (w, h) gradient pair and one
    // attention-weight gradient per thread.
    __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
    __shared__ scalar_t cache_grad_attn_weight[blockSize];
    unsigned int tid = threadIdx.x;
    // Decompose the flat index into (b_col, q_col, m_col, c_col).
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;  // flat (batch, query, head) index
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;  // (not used below)
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    // Bases into the per-(b,q,m) weight/location tensors; locations are
    // (w, h) pairs, hence the << 1.
    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Normalized [0, 1] location -> pixel coordinates.
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Zero this thread's scratch slots before the conditional scatter,
        // so out-of-bounds samples contribute zero to the reduction.
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        // All scratch slots must be populated before thread 0 reduces them.
        __syncthreads();
        if (tid == 0)
        {
          scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
          int sid=2;
          // Serial reduction over the whole block. The inner `tid` shadows
          // the outer thread id — here it is only a loop counter.
          for (unsigned int tid = 1; tid < blockSize; ++tid)
          {
            _grad_w += cache_grad_sampling_loc[sid];
            _grad_h += cache_grad_sampling_loc[sid + 1];
            _grad_a += cache_grad_attn_weight[tid];
            sid += 2;
          }


          *grad_sampling_loc = _grad_w;
          *(grad_sampling_loc + 1) = _grad_h;
          *grad_attn_weight = _grad_a;
        }
        // Barrier before the scratch buffers are reused next iteration.
        __syncthreads();

        // Advance to the next sampling point.
        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
// Backward kernel, static shared-memory variant with a compile-time block
// size, using a parallel tree reduction instead of v1's serial reduction.
// NOTE(review): the halving loop (s = blockSize/2; s >>= 1) only sums all
// slots when blockSize is a power of two — confirm at instantiation sites.
// As in v1, thread 0's plain stores presume all threads of the block share
// one (batch, query, head) slot; confirm against the launch configuration.
template <typename scalar_t, unsigned int blockSize>
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Static shared-memory scratch: one (w, h) gradient pair and one
    // attention-weight gradient per thread.
    __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
    __shared__ scalar_t cache_grad_attn_weight[blockSize];
    unsigned int tid = threadIdx.x;
    // Decompose the flat index into (b_col, q_col, m_col, c_col).
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;  // flat (batch, query, head) index
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;  // (not used below)
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    // Bases into the per-(b,q,m) weight/location tensors; locations are
    // (w, h) pairs, hence the << 1.
    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Normalized [0, 1] location -> pixel coordinates.
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Zero this thread's scratch slots before the conditional scatter.
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        // All scratch slots must be populated before reducing.
        __syncthreads();

        // Tree reduction: each round folds the upper half of the active
        // slots into the lower half, halving the stride until slot 0 holds
        // the block-wide sums.
        for (unsigned int s=blockSize/2; s>0; s>>=1)
        {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
          }
          __syncthreads();
        }

        if (tid == 0)
        {
          *grad_sampling_loc = cache_grad_sampling_loc[0];
          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
          *grad_attn_weight = cache_grad_attn_weight[0];
        }
        // Barrier before the scratch buffers are reused next iteration.
        __syncthreads();

        // Advance to the next sampling point.
        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
// Backward kernel, dynamic shared-memory variant (block size known only at
// launch time; the launcher must pass 3 * blockDim.x * sizeof(scalar_t) of
// dynamic shared memory). Thread 0 reduces the per-thread gradients
// serially, as in the blocksize-aware v1 kernel.
// NOTE(review): as in the other shm variants, thread 0's plain stores
// presume all threads of the block share one (batch, query, head) slot;
// confirm against the launch configuration.
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Dynamic shared memory, partitioned as: 2 * blockDim.x entries for the
    // (w, h) location gradients, then blockDim.x for the weight gradients.
    extern __shared__ int _s[];
    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;
    // Decompose the flat index into (b_col, q_col, m_col, c_col).
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;  // flat (batch, query, head) index
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;  // (not used below)
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    // Bases into the per-(b,q,m) weight/location tensors; locations are
    // (w, h) pairs, hence the << 1.
    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Normalized [0, 1] location -> pixel coordinates.
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Zero this thread's scratch slots before the conditional scatter.
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        // All scratch slots must be populated before thread 0 reduces them.
        __syncthreads();
        if (tid == 0)
        {
          scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
          int sid=2;
          // Serial reduction over blockDim.x slots. The inner `tid` shadows
          // the outer thread id — here it is only a loop counter.
          for (unsigned int tid = 1; tid < blockDim.x; ++tid)
          {
            _grad_w += cache_grad_sampling_loc[sid];
            _grad_h += cache_grad_sampling_loc[sid + 1];
            _grad_a += cache_grad_attn_weight[tid];
            sid += 2;
          }


          *grad_sampling_loc = _grad_w;
          *(grad_sampling_loc + 1) = _grad_h;
          *grad_attn_weight = _grad_a;
        }
        // Barrier before the scratch buffers are reused next iteration.
        __syncthreads();

        // Advance to the next sampling point.
        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
|
| 617 |
+
|
| 618 |
+
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  // Backward (col2im) kernel for multi-scale deformable attention.
  // Per-thread partial gradients for the sampling locations and attention
  // weights are staged in dynamically-sized shared memory and combined with
  // a tree reduction; thread 0 writes the reduced block result.
  CUDA_KERNEL_LOOP(index, n)
  {
    extern __shared__ int _s[];
    // Shared layout: 2*blockDim.x (w,h) slots, then blockDim.x weight slots.
    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;

    // Decompose the flat index into (batch, query, head, channel).
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;  // sampling locs are (w, h) pairs
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Normalized location -> image coordinates (align_corners=False style).
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Zero this thread's shared-memory slots before accumulating.
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();

        // Tree reduction across the block; the `spre` test folds in the odd
        // leftover element when the active width is not a power of two.
        for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
        {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
            if (tid + (s << 1) < spre)
            {
              cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
              cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
              cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
            }
          }
          __syncthreads();
        }

        // One (query, head) owns this slot exclusively, so a plain store is safe.
        if (tid == 0)
        {
          *grad_sampling_loc = cache_grad_sampling_loc[0];
          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
          *grad_attn_weight = cache_grad_attn_weight[0];
        }
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
|
| 730 |
+
|
| 731 |
+
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  // Backward (col2im) kernel for the case where the channels of one
  // (query, head) span multiple thread blocks: each block reduces its own
  // partial gradients in shared memory, then thread 0 combines blocks via
  // atomicAdd into global memory.
  CUDA_KERNEL_LOOP(index, n)
  {
    extern __shared__ int _s[];
    // Shared layout: 2*blockDim.x (w,h) slots, then blockDim.x weight slots.
    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;

    // Decompose the flat index into (batch, query, head, channel).
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;  // sampling locs are (w, h) pairs
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Normalized location -> image coordinates (align_corners=False style).
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Zero this thread's shared-memory slots before accumulating.
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();

        // Tree reduction across the block; the `spre` test folds in the odd
        // leftover element when the active width is not a power of two.
        for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
        {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
            if (tid + (s << 1) < spre)
            {
              cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
              cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
              cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
            }
          }
          __syncthreads();
        }

        // Several blocks contribute to the same slot, so accumulate atomically.
        if (tid == 0)
        {
          atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
          atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
          atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
        }
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
|
| 843 |
+
|
| 844 |
+
|
| 845 |
+
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  // Backward (col2im) fallback kernel: no shared-memory staging, every
  // gradient contribution goes straight to global memory (the `_gm` helper
  // accumulates there). Used for irregular channel counts where the
  // shared-memory variants do not apply.
  CUDA_KERNEL_LOOP(index, n)
  {
    // Decompose the flat index into (batch, query, head, channel).
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;  // sampling locs are (w, h) pairs
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Normalized location -> image coordinates (align_corners=False style).
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear_gm(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            grad_sampling_loc, grad_attn_weight);
        }
        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
|
| 921 |
+
|
| 922 |
+
|
| 923 |
+
template <typename scalar_t>
void ms_deformable_im2col_cuda(cudaStream_t stream,
                              const scalar_t* data_value,
                              const int64_t* data_spatial_shapes,
                              const int64_t* data_level_start_index,
                              const scalar_t* data_sampling_loc,
                              const scalar_t* data_attn_weight,
                              const int batch_size,
                              const int spatial_size,
                              const int num_heads,
                              const int channels,
                              const int num_levels,
                              const int num_query,
                              const int num_point,
                              scalar_t* data_col)
{
  // Host-side launcher for the forward (im2col) pass of multi-scale
  // deformable attention: one thread per output element of data_col,
  // i.e. batch * query * head * channel threads in total.
  const int num_kernels = batch_size * num_query * num_heads * channels;
  const int num_actual_kernels = batch_size * num_query * num_heads * channels;
  const int num_threads = CUDA_NUM_THREADS;

  ms_deformable_im2col_gpu_kernel<scalar_t>
      <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
         0, stream>>>(
      num_kernels, data_value, data_spatial_shapes, data_level_start_index,
      data_sampling_loc, data_attn_weight,
      batch_size, spatial_size, num_heads, channels, num_levels, num_query,
      num_point, data_col);

  // Kernel launches are asynchronous; this only surfaces launch-time errors.
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
  }
}
|
| 955 |
+
|
| 956 |
+
template <typename scalar_t>
void ms_deformable_col2im_cuda(cudaStream_t stream,
                              const scalar_t* grad_col,
                              const scalar_t* data_value,
                              const int64_t * data_spatial_shapes,
                              const int64_t * data_level_start_index,
                              const scalar_t * data_sampling_loc,
                              const scalar_t * data_attn_weight,
                              const int batch_size,
                              const int spatial_size,
                              const int num_heads,
                              const int channels,
                              const int num_levels,
                              const int num_query,
                              const int num_point,
                              scalar_t* grad_value,
                              scalar_t* grad_sampling_loc,
                              scalar_t* grad_attn_weight)
{
  // Host-side dispatcher for the backward (col2im) pass. Picks a kernel
  // variant by channel count:
  //   - channels > 1024 and a multiple of 1024: multi-block shared-memory
  //     reduction combined across blocks with atomics;
  //   - channels > 1024 otherwise: pure global-memory fallback;
  //   - channels a power of two <= 32: fully unrolled blocksize-aware v1;
  //   - channels a power of two in [64, 1024]: blocksize-aware v2;
  //   - anything else: runtime shared-memory reduction (v1 below 64
  //     channels, v2 at or above).
  const int num_threads = (channels > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : channels;
  const int num_kernels = batch_size * num_query * num_heads * channels;
  const int num_actual_kernels = batch_size * num_query * num_heads * channels;
  // Dynamic shared memory used by the runtime-reduction variants:
  // (2 loc + 1 weight) scalars per thread.
  const size_t shm_size = num_threads * 3 * sizeof(scalar_t);

  // Every kernel below takes exactly the same 17 arguments; factor the list
  // out so each launch differs only in kernel choice and shared-memory size.
#define MSDA_COL2IM_ARGS                                                      \
  num_kernels, grad_col, data_value, data_spatial_shapes,                     \
      data_level_start_index, data_sampling_loc, data_attn_weight,            \
      batch_size, spatial_size, num_heads, channels, num_levels, num_query,   \
      num_point, grad_value, grad_sampling_loc, grad_attn_weight

  if (channels > 1024)
  {
    if ((channels & 1023) == 0)
    {
      ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
             shm_size, stream>>>(MSDA_COL2IM_ARGS);
    }
    else
    {
      ms_deformable_col2im_gpu_kernel_gm<scalar_t>
          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
             0, stream>>>(MSDA_COL2IM_ARGS);
    }
  }
  else
  {
    switch (channels)
    {
      case 1:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
               0, stream>>>(MSDA_COL2IM_ARGS);
        break;
      case 2:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
               0, stream>>>(MSDA_COL2IM_ARGS);
        break;
      case 4:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
               0, stream>>>(MSDA_COL2IM_ARGS);
        break;
      case 8:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
               0, stream>>>(MSDA_COL2IM_ARGS);
        break;
      case 16:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
               0, stream>>>(MSDA_COL2IM_ARGS);
        break;
      case 32:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
               0, stream>>>(MSDA_COL2IM_ARGS);
        break;
      case 64:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
               0, stream>>>(MSDA_COL2IM_ARGS);
        break;
      case 128:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
               0, stream>>>(MSDA_COL2IM_ARGS);
        break;
      case 256:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
               0, stream>>>(MSDA_COL2IM_ARGS);
        break;
      case 512:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
               0, stream>>>(MSDA_COL2IM_ARGS);
        break;
      case 1024:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
               0, stream>>>(MSDA_COL2IM_ARGS);
        break;
      default:
        if (channels < 64)
        {
          ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
              <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
                 shm_size, stream>>>(MSDA_COL2IM_ARGS);
        }
        else
        {
          ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
              <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
                 shm_size, stream>>>(MSDA_COL2IM_ARGS);
        }
    }
  }
#undef MSDA_COL2IM_ARGS

  // Kernel launches are asynchronous; this only surfaces launch-time errors.
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
  }
}
|
.venv/lib/python3.11/site-packages/transformers/kernels/deta/ms_deform_attn.h
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*!
|
| 2 |
+
**************************************************************************************************
|
| 3 |
+
* Deformable DETR
|
| 4 |
+
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 5 |
+
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 6 |
+
**************************************************************************************************
|
| 7 |
+
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
| 8 |
+
**************************************************************************************************
|
| 9 |
+
*/
|
| 10 |
+
|
| 11 |
+
#pragma once
|
| 12 |
+
|
| 13 |
+
#include "cpu/ms_deform_attn_cpu.h"
|
| 14 |
+
|
| 15 |
+
#ifdef WITH_CUDA
|
| 16 |
+
#include "cuda/ms_deform_attn_cuda.h"
|
| 17 |
+
#endif
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
at::Tensor
ms_deform_attn_forward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const int im2col_step)
{
    // Forward pass of multi-scale deformable attention.
    // Dispatches to the CUDA implementation when `value` lives on a CUDA
    // device; there is no CPU implementation, so anything else raises.
    //
    // Tensor::type() is deprecated in ATen; query the tensor directly.
    if (value.is_cuda())
    {
#ifdef WITH_CUDA
        return ms_deform_attn_cuda_forward(
            value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
#else
        AT_ERROR("Not compiled with GPU support");
#endif
    }
    AT_ERROR("Not implemented on the CPU");
}
|
| 40 |
+
|
| 41 |
+
std::vector<at::Tensor>
ms_deform_attn_backward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const at::Tensor &grad_output,
    const int im2col_step)
{
    // Backward pass of multi-scale deformable attention; returns the
    // gradients w.r.t. value, sampling_loc and attn_weight.
    // Dispatches to the CUDA implementation when `value` lives on a CUDA
    // device; there is no CPU implementation, so anything else raises.
    //
    // Tensor::type() is deprecated in ATen; query the tensor directly.
    if (value.is_cuda())
    {
#ifdef WITH_CUDA
        return ms_deform_attn_cuda_backward(
            value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
#else
        AT_ERROR("Not compiled with GPU support");
#endif
    }
    AT_ERROR("Not implemented on the CPU");
}
|
.venv/lib/python3.11/site-packages/transformers/kernels/deta/vision.cpp
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*!
|
| 2 |
+
**************************************************************************************************
|
| 3 |
+
* Deformable DETR
|
| 4 |
+
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 5 |
+
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 6 |
+
**************************************************************************************************
|
| 7 |
+
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
| 8 |
+
**************************************************************************************************
|
| 9 |
+
*/
|
| 10 |
+
|
| 11 |
+
#include "ms_deform_attn.h"
|
| 12 |
+
|
| 13 |
+
// Register the deformable-attention entry points as a torch extension
// module so they are callable from Python.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
  m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
}
|
.venv/lib/python3.11/site-packages/transformers/kernels/falcon_mamba/__pycache__/selective_scan_with_ln_interface.cpython-311.pyc
ADDED
|
Binary file (21.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/transformers/kernels/rwkv/wkv_cuda.cu
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <stdio.h>
|
| 2 |
+
#include <assert.h>
|
| 3 |
+
|
| 4 |
+
#define MIN_VALUE (-1e38)
|
| 5 |
+
|
| 6 |
+
template <typename F>
__global__ void kernel_forward(
    const int B, const int T, const int C, const F *__restrict__ const _w, const F *__restrict__ const _u,
    const F *__restrict__ const _k, const F *__restrict__ const _v, F *__restrict__ const _y
) {
    // RWKV WKV forward: one thread per (batch, channel) pair; each thread
    // scans its own length-T time series sequentially.
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;

    F u = _u[_c];
    F w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    F *__restrict__ const y = _y + _offset;

    // aa and bb are running sums divided by exp(pp) (to avoid overflow);
    // pp tracks the running log-scale of the state.
    F aa = 0, bb = 0, pp = MIN_VALUE;
    for (int i = 0; i < T; i++) {
        const int ii = i * C;  // time stride is C (channel-interleaved layout)
        const F kk = k[ii];
        const F vv = v[ii];

        // Output at step i: the current token enters with bonus u + k[i].
        F ww = u + kk;
        F p = max(pp, ww);
        F e1 = exp(pp - p);
        F e2 = exp(ww - p);
        y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);

        // State update: decay the existing state by w, fold in the token.
        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
}
|
| 44 |
+
|
| 45 |
+
// Same scan as kernel_forward, but the (aa, bb, pp) recurrence state is read
// from and written back to _s, so generation can resume across calls.
// State layout: 3 consecutive values per (batch, channel) pair.
template <typename F>
__global__ void kernel_forward_with_state(
    const int B, const int T, const int C, const F *__restrict__ const _w, const F *__restrict__ const _u,
    const F *__restrict__ const _k, const F *__restrict__ const _v, F *__restrict__ const _y, F *__restrict__ const _s
) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;                    // batch index
    const int _c = idx % C;                    // channel index
    const int _offset_s = _b * C * 3 + _c * 3; // this pair's 3-element state slot
    const int _offset = _b * T * C + _c;       // first element of this column

    F u = _u[_c];
    F w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    F *__restrict__ const y = _y + _offset;
    F *__restrict__ const s = _s + _offset_s;

    // aa and bb are running sums divided by exp(pp) (to avoid overflow)
    F aa = s[0], bb = s[1], pp = s[2];  // resume from the carried-in state
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        const F kk = k[ii];
        const F vv = v[ii];

        // Output at step i (max-subtraction keeps the exps in range).
        F ww = u + kk;
        F p = max(pp, ww);
        F e1 = exp(pp - p);
        F e2 = exp(ww - p);
        y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);

        // Decay the state by w and fold in the current key/value.
        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
    // Persist the state for the next chunk.
    s[0] = aa;
    s[1] = bb;
    s[2] = pp;
}
|
| 88 |
+
|
| 89 |
+
// WKV backward pass. Two sweeps per (batch, channel) pair:
//   1) forward in time: re-runs the forward recurrence, accumulates the
//      gradients of w and u, and caches per-step q[i] (gy scaled by the
//      forward denominator) and r[i] (the rescaled exponent) for sweep 2;
//   2) backward in time: produces the gradients of k and v.
// q and r are stack arrays of size Tmax, so T <= Tmax is required; Tmax is
// expected to be supplied at compile time (e.g. -DTmax=...) — not defined here.
template <typename F>
__global__ void kernel_backward(
    const int B, const int T, const int C, const F *__restrict__ const _w, const F *__restrict__ const _u,
    const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _y,
    const F *__restrict__ const _gy, F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk,
    F *__restrict__ const _gv
) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;               // batch index
    const int _c = idx % C;               // channel index
    const int _offset = _b * T * C + _c;  // first element of this column

    F u = _u[_c];
    F w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    const F *__restrict__ const y = _y + _offset;
    const F *__restrict__ const gy = _gy + _offset;
    F *__restrict__ const gk = _gk + _offset;
    F *__restrict__ const gv = _gv + _offset;

    F q[Tmax], r[Tmax];

    F gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        const F kk = k[ii];
        const F vv = v[ii];
        const F yy = y[ii];

        // Mirror of the forward step; (ga, gb) carry the w-sensitivities of
        // (aa, bb).
        F ww = u + kk;
        F p = max(pp, ww);
        F e1 = exp(pp - p);
        F e2 = exp(ww - p);
        const F qq = gy[ii] / (e1 * bb + e2);  // upstream grad over denominator
        gw += (ga - gb * yy) * e1 * qq;
        gu += (vv - yy) * e2 * qq;
        q[i] = qq;       // cached for the reverse sweep
        r[i] = ww - p;   // rescaled exponent of the current-step term

        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        ga = e1 * (aa + ga);
        gb = e1 * (bb + gb);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
    // gw/gu are reduced per (batch, channel); layout is B x C.
    const int _offsetBC = _b * C + _c;
    _gw[_offsetBC] = gw * _w[_c]; // multiply by w because of w -> -exp(w) in python forward()
    _gu[_offsetBC] = gu;

    // Reverse sweep: (aa, bb) now accumulate the future contributions to
    // gk/gv, decayed by w as we move backwards.
    aa = 0, bb = 0, pp = MIN_VALUE;
    for (int i = T - 1; i >= 0; i--) {
        const int ii = i * C;
        const F kk = k[ii];
        const F vv = v[ii];
        const F yy = y[ii];
        const F qq = q[i];
        const F rr = r[i];

        F e1 = qq * exp(rr);
        F e2 = exp(kk + pp);
        gk[ii] = e1 * (vv - yy) + e2 * (aa * vv + bb);
        gv[ii] = e1 + e2 * aa;

        const F ww = w + pp;
        const F www = rr - u - kk;
        const F p = max(ww, www);
        e1 = exp(ww - p);
        e2 = qq * exp(www - p);
        aa = e1 * aa + e2;
        bb = e1 * bb - e2 * yy;
        pp = p;
    }
}
|
| 167 |
+
|
| 168 |
+
// Launch kernel_forward with one thread per (batch, channel) pair.
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
    const int threads = min(C, 32); // requires --maxrregcount 60 for optimal performance
    assert(B * C % threads == 0);   // B*C must tile evenly into blocks
    const dim3 block_dim(threads);
    const dim3 grid_dim(B * C / threads);
    kernel_forward<<<grid_dim, block_dim>>>(B, T, C, w, u, k, v, y);
}
|
| 174 |
+
|
| 175 |
+
// Launch kernel_forward_with_state; s carries the 3-value recurrence state
// per (batch, channel) pair across calls.
void cuda_forward_with_state(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *s) {
    const int threads = min(C, 32); // requires --maxrregcount 60 for optimal performance
    assert(B * C % threads == 0);   // B*C must tile evenly into blocks
    const dim3 block_dim(threads);
    const dim3 grid_dim(B * C / threads);
    kernel_forward_with_state<<<grid_dim, block_dim>>>(B, T, C, w, u, k, v, y, s);
}
|
| 181 |
+
|
| 182 |
+
// Launch kernel_backward with one thread per (batch, channel) pair.
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv) {
    const int threads = min(C, 32); // requires --maxrregcount 60 for optimal performance
    assert(B * C % threads == 0);   // B*C must tile evenly into blocks
    const dim3 block_dim(threads);
    const dim3 grid_dim(B * C / threads);
    kernel_backward<<<grid_dim, block_dim>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
}
|
.venv/lib/python3.11/site-packages/transformers/kernels/rwkv/wkv_cuda_bf16.cu
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <stdio.h>
|
| 2 |
+
#include <assert.h>
|
| 3 |
+
#include "ATen/ATen.h"
|
| 4 |
+
#define MIN_VALUE (-1e38)
|
| 5 |
+
typedef at::BFloat16 bf16;
|
| 6 |
+
|
| 7 |
+
// bf16 variant of the WKV forward scan: inputs k/v/u are bf16 and are widened
// to float on load; all arithmetic runs in float; only the final output is
// rounded back to bf16. w stays float throughout.
__global__ void kernel_forward_bf16(
    const int B, const int T, const int C, const float *__restrict__ const _w, const bf16 *__restrict__ const _u,
    const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, bf16 *__restrict__ const _y
) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;               // batch index
    const int _c = idx % C;               // channel index
    const int _offset = _b * T * C + _c;  // first element of this column

    float u = float(_u[_c]);
    float w = _w[_c];
    const bf16 *__restrict__ const k = _k + _offset;
    const bf16 *__restrict__ const v = _v + _offset;
    bf16 *__restrict__ const y = _y + _offset;

    // aa and bb are running sums divided by exp(pp) (to avoid overflow)
    float aa = 0, bb = 0, pp = MIN_VALUE;
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        const float kk = float(k[ii]);
        const float vv = float(v[ii]);

        // Output at step i; divide in float, round once to bf16 on store.
        float ww = u + kk;
        float p = max(pp, ww);
        float e1 = exp(pp - p);
        float e2 = exp(ww - p);
        y[ii] = bf16((e1 * aa + e2 * vv) / (e1 * bb + e2));

        // Decay the state by w and fold in the current key/value.
        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
}
|
| 44 |
+
|
| 45 |
+
// bf16 WKV forward scan with a carried recurrence state (float, 3 values per
// (batch, channel) pair). Arithmetic runs in float; only y is stored as bf16.
__global__ void kernel_forward_with_state_bf16(
    const int B, const int T, const int C, const float *__restrict__ const _w, const bf16 *__restrict__ const _u,
    const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, bf16 *__restrict__ const _y,
    float *__restrict__ const _s
) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;                    // batch index
    const int _c = idx % C;                    // channel index
    const int _offset_s = _b * C * 3 + _c * 3; // this pair's 3-element state slot
    const int _offset = _b * T * C + _c;       // first element of this column

    float u = float(_u[_c]);
    float w = _w[_c];
    const bf16 *__restrict__ const k = _k + _offset;
    const bf16 *__restrict__ const v = _v + _offset;
    bf16 *__restrict__ const y = _y + _offset;
    float *__restrict__ const s = _s + _offset_s;

    // aa and bb are running sums divided by exp(pp) (to avoid overflow)
    float aa = s[0], bb = s[1], pp = s[2];  // resume from the carried-in state
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        const float kk = float(k[ii]);
        const float vv = float(v[ii]);

        float ww = u + kk;
        float p = max(pp, ww);
        float e1 = exp(pp - p);
        float e2 = exp(ww - p);
        // BUGFIX: the cast previously wrapped only the numerator
        // (bf16(e1*aa + e2*vv) / (...)), rounding it to bf16 before the
        // division. Compute the full quotient in float and round once, the
        // same as kernel_forward_bf16 and both fp32 kernels.
        y[ii] = bf16((e1 * aa + e2 * vv) / (e1 * bb + e2));

        // Decay the state by w and fold in the current key/value.
        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
    // Persist the state for the next chunk.
    s[0] = aa;
    s[1] = bb;
    s[2] = pp;
}
|
| 88 |
+
|
| 89 |
+
// bf16 WKV backward pass. Same two-sweep structure as the fp32 kernel:
// forward sweep accumulates grads of w and u and caches q[i]/r[i]; reverse
// sweep emits grads of k and v. All arithmetic is float; bf16 only at the
// load/store boundary. q/r are stack arrays of size Tmax, so T <= Tmax is
// required; Tmax is expected to be supplied at compile time — not defined here.
__global__ void kernel_backward_bf16(
    const int B, const int T, const int C, const float *__restrict__ const _w, const bf16 *__restrict__ const _u,
    const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, const bf16 *__restrict__ const _y,
    const bf16 *__restrict__ const _gy, bf16 *__restrict__ const _gw, bf16 *__restrict__ const _gu,
    bf16 *__restrict__ const _gk, bf16 *__restrict__ const _gv
) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;               // batch index
    const int _c = idx % C;               // channel index
    const int _offset = _b * T * C + _c;  // first element of this column

    float u = float(_u[_c]);
    float w = _w[_c];
    const bf16 *__restrict__ const k = _k + _offset;
    const bf16 *__restrict__ const v = _v + _offset;
    const bf16 *__restrict__ const y = _y + _offset;
    const bf16 *__restrict__ const gy = _gy + _offset;
    bf16 *__restrict__ const gk = _gk + _offset;
    bf16 *__restrict__ const gv = _gv + _offset;

    float q[Tmax], r[Tmax];

    float gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        const float kk = float(k[ii]);
        const float vv = float(v[ii]);
        const float yy = float(y[ii]);

        // Mirror of the forward step; (ga, gb) carry the w-sensitivities of
        // (aa, bb).
        float ww = u + kk;
        float p = max(pp, ww);
        float e1 = exp(pp - p);
        float e2 = exp(ww - p);
        const float qq = float(gy[ii]) / (e1 * bb + e2);  // grad over denominator
        gw += (ga - gb * yy) * e1 * qq;
        gu += (vv - yy) * e2 * qq;
        q[i] = qq;       // cached for the reverse sweep
        r[i] = ww - p;   // rescaled exponent of the current-step term

        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        ga = e1 * (aa + ga);
        gb = e1 * (bb + gb);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
    // gw/gu are reduced per (batch, channel); layout is B x C.
    const int _offsetBC = _b * C + _c;
    _gw[_offsetBC] = bf16(gw * _w[_c]); // multiply by w because of w -> -exp(w) in python forward()
    _gu[_offsetBC] = bf16(gu);

    // Reverse sweep: (aa, bb) accumulate future contributions to gk/gv.
    aa = 0, bb = 0, pp = MIN_VALUE;
    for (int i = T - 1; i >= 0; i--) {
        const int ii = i * C;
        const float kk = float(k[ii]);
        const float vv = float(v[ii]);
        const float yy = float(y[ii]);
        const float qq = q[i];
        const float rr = r[i];

        float e1 = qq * exp(rr);
        float e2 = exp(kk + pp);
        gk[ii] = bf16(e1 * (vv - yy) + e2 * (aa * vv + bb));
        gv[ii] = bf16(e1 + e2 * aa);

        const float ww = w + pp;
        const float www = rr - u - kk;
        const float p = max(ww, www);
        e1 = exp(ww - p);
        e2 = qq * exp(www - p);
        aa = e1 * aa + e2;
        bb = e1 * bb - e2 * yy;
        pp = p;
    }
}
|
| 166 |
+
|
| 167 |
+
// Launch kernel_forward_bf16 with one thread per (batch, channel) pair.
void cuda_forward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y) {
    const int threads = min(C, 32); // requires --maxrregcount 60 for optimal performance
    assert(B * C % threads == 0);   // B*C must tile evenly into blocks
    const dim3 block_dim(threads);
    const dim3 grid_dim(B * C / threads);
    kernel_forward_bf16<<<grid_dim, block_dim>>>(B, T, C, w, u, k, v, y);
}
|
| 173 |
+
|
| 174 |
+
// Launch kernel_forward_with_state_bf16; s carries the float recurrence state.
void cuda_forward_with_state_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, float *s) {
    const int threads = min(C, 32); // requires --maxrregcount 60 for optimal performance
    assert(B * C % threads == 0);   // B*C must tile evenly into blocks
    const dim3 block_dim(threads);
    const dim3 grid_dim(B * C / threads);
    kernel_forward_with_state_bf16<<<grid_dim, block_dim>>>(B, T, C, w, u, k, v, y, s);
}
|
| 180 |
+
|
| 181 |
+
// Launch kernel_backward_bf16 with one thread per (batch, channel) pair.
void cuda_backward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv) {
    const int threads = min(C, 32); // requires --maxrregcount 60 for optimal performance
    assert(B * C % threads == 0);   // B*C must tile evenly into blocks
    const dim3 block_dim(threads);
    const dim3 grid_dim(B * C / threads);
    kernel_backward_bf16<<<grid_dim, block_dim>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
}
|
.venv/lib/python3.11/site-packages/transformers/kernels/rwkv/wkv_op.cpp
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <torch/extension.h>
|
| 2 |
+
#include "ATen/ATen.h"
|
| 3 |
+
typedef at::BFloat16 bf16;
|
| 4 |
+
|
| 5 |
+
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
|
| 6 |
+
void cuda_forward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y);
|
| 7 |
+
void cuda_forward_with_state(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *s);
|
| 8 |
+
void cuda_forward_with_state_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, float *s);
|
| 9 |
+
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv);
|
| 10 |
+
void cuda_backward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv);
|
| 11 |
+
|
| 12 |
+
// Dispatch the fp32 WKV forward kernel; B/T/C are read from k's shape.
void forward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
    const int batch = k.size(0);
    const int seq_len = k.size(1);
    const int channels = k.size(2);
    cuda_forward(batch, seq_len, channels,
                 w.data_ptr<float>(), u.data_ptr<float>(),
                 k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
}
|
| 18 |
+
// Dispatch the bf16 WKV forward kernel; w stays float, the rest are bf16.
void forward_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
    const int batch = k.size(0);
    const int seq_len = k.size(1);
    const int channels = k.size(2);
    cuda_forward_bf16(batch, seq_len, channels,
                      w.data_ptr<float>(), u.data_ptr<bf16>(),
                      k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>());
}
|
| 24 |
+
// Dispatch the fp32 WKV forward kernel that carries a recurrence state in s.
void forward_with_state(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &s) {
    const int batch = k.size(0);
    const int seq_len = k.size(1);
    const int channels = k.size(2);
    cuda_forward_with_state(batch, seq_len, channels,
                            w.data_ptr<float>(), u.data_ptr<float>(),
                            k.data_ptr<float>(), v.data_ptr<float>(),
                            y.data_ptr<float>(), s.data_ptr<float>());
}
|
| 30 |
+
// Dispatch the bf16 WKV forward kernel with carried state (state is float).
void forward_with_state_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &s) {
    const int batch = k.size(0);
    const int seq_len = k.size(1);
    const int channels = k.size(2);
    cuda_forward_with_state_bf16(batch, seq_len, channels,
                                 w.data_ptr<float>(), u.data_ptr<bf16>(),
                                 k.data_ptr<bf16>(), v.data_ptr<bf16>(),
                                 y.data_ptr<bf16>(), s.data_ptr<float>());
}
|
| 36 |
+
// Dispatch the fp32 WKV backward kernel; gradients are written into gw/gu/gk/gv.
void backward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
    const int batch = k.size(0);
    const int seq_len = k.size(1);
    const int channels = k.size(2);
    cuda_backward(batch, seq_len, channels,
                  w.data_ptr<float>(), u.data_ptr<float>(),
                  k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(),
                  gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(),
                  gk.data_ptr<float>(), gv.data_ptr<float>());
}
|
| 42 |
+
// Dispatch the bf16 WKV backward kernel; w stays float, everything else bf16.
void backward_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
    const int batch = k.size(0);
    const int seq_len = k.size(1);
    const int channels = k.size(2);
    cuda_backward_bf16(batch, seq_len, channels,
                       w.data_ptr<float>(), u.data_ptr<bf16>(),
                       k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>(),
                       gy.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(),
                       gk.data_ptr<bf16>(), gv.data_ptr<bf16>());
}
|
| 49 |
+
|
| 50 |
+
// Python bindings: expose the six wrappers to Python via pybind11.
// Names here are the ones called from the Python RWKV wrapper code.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "wkv forward");
    m.def("forward_bf16", &forward_bf16, "wkv forward bf16");
    m.def("forward_with_state", &forward_with_state, "wkv forward with state");
    m.def("forward_with_state_bf16", &forward_with_state_bf16, "wkv forward with state bf16");
    m.def("backward", &backward, "wkv backward");
    m.def("backward_bf16", &backward_bf16, "wkv backward bf16");
}
|
| 58 |
+
|
| 59 |
+
// TorchScript registration: the same six ops under the "wkv" namespace,
// mirroring the pybind11 exports above so scripted models can call them.
TORCH_LIBRARY(wkv, m) {
    m.def("forward", forward);
    m.def("forward_bf16", forward_bf16);
    m.def("forward_with_state", forward_with_state);
    m.def("forward_with_state_bf16", forward_with_state_bf16);
    m.def("backward", backward);
    m.def("backward_bf16", backward_bf16);
}
|
.venv/lib/python3.11/site-packages/transformers/kernels/yoso/common.h
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
// Shared helper macros and numeric constants for the YOSO kernels.
// NOTE: min/max/select evaluate their arguments more than once — do not pass
// expressions with side effects.
#define min(a, b) ((a)<(b)?(a):(b))
#define max(a, b) ((a)>(b)?(a):(b))
// Integer ceiling division: ceil(a / b).
#define ceil_divide(a, b) ((a)/(b)+((a)%(b)!=0))
#define select(cond, a, b) ((cond)?(a):(b))
#define PI 3.141592
#define EPSILON 1e-8
// Saturation bounds used as "infinity" stand-ins.
#define MAX_VAL 1e12
#define MIN_VAL -1e12
// Marker for an unoccupied hash-table slot.
#define EMPTY_VALUE -1
|
.venv/lib/python3.11/site-packages/transformers/kernels/yoso/fast_lsh_cumulation.cu
ADDED
|
@@ -0,0 +1,588 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// File from https://github.com/mlpen/YOSO/blob/main/encoders/backbones/efficient_attentions/yoso/yoso_v1/cuda/fast_lsh_cumulation.cu
|
| 2 |
+
|
| 3 |
+
#include <torch/extension.h>
|
| 4 |
+
#include <ATen/ATen.h>
|
| 5 |
+
#include "fast_lsh_cumulation.h"
|
| 6 |
+
#include "fast_lsh_cumulation_cuda.h"
|
| 7 |
+
#include "common_cuda.h"
|
| 8 |
+
#include "common.h"
|
| 9 |
+
#include <vector>
|
| 10 |
+
//////////////////////////////////////////////////////////////////////////////////////////////////
|
| 11 |
+
//////////////////////////////////////////////////////////////////////////////////////////////////
|
| 12 |
+
|
| 13 |
+
// Compute LSH hash codes for queries and keys.
// Returns {query_hash_code, key_hash_code}, each [batch, n, num_hash_f] and
// allocated with the mask tensors' options (so they share the masks' integer
// dtype). Dmat holds random +/-1 sign flips shared by both hash launches.
// NOTE(review): if use_cuda is false the hash codes are returned as zeros —
// no CPU fallback exists here; presumably the caller guarantees use_cuda.
std::vector<at::Tensor> fast_hash_ver1_kernel(
    at::Tensor query_mask,
    at::Tensor query_vector,
    at::Tensor key_mask,
    at::Tensor key_vector,
    int num_hash_f,
    int hash_code_len,
    bool use_cuda
) {
    int batch_size = query_vector.size(0);
    int num_query = query_vector.size(1);
    int num_key = key_vector.size(1);
    int vector_dim = query_vector.size(2);

    // Each "part" produces up to vector_dim / hash_code_len hash functions.
    int num_hash_per_part = vector_dim / hash_code_len;
    int num_part = max(1, ceil_divide(num_hash_f, num_hash_per_part));

    // Random signs in {-1, +1}: 2 * randint(0, 2) - 1.
    at::Tensor Dmat = 2 * at::randint(0, 2, {batch_size, 3, num_part, vector_dim}, query_mask.options()) - 1;
    at::Tensor query_hash_code = at::zeros({batch_size, num_query, num_hash_f}, query_mask.options());
    at::Tensor key_hash_code = at::zeros({batch_size, num_key, num_hash_f}, key_mask.options());

    // Raw pointers handed to the CUDA kernels; masks/codes are int32,
    // vectors are float32 (enforced by data_ptr<T>'s dtype check).
    int *query_mask_ptr = query_mask.data_ptr<int>();
    float *query_vector_ptr = query_vector.data_ptr<float>();
    int *key_mask_ptr = key_mask.data_ptr<int>();
    float *key_vector_ptr = key_vector.data_ptr<float>();

    int *Dmat_ptr = Dmat.data_ptr<int>();

    int *query_hash_code_ptr = query_hash_code.data_ptr<int>();
    int *key_hash_code_ptr = key_hash_code.data_ptr<int>();

    if (use_cuda) {
        {
            // Hash the queries: one block per (part, query, batch) triple,
            // vector_dim threads, one float per thread of shared memory.
            dim3 threads(vector_dim);
            dim3 blocks(num_part, num_query, batch_size);
            int shared_mem = vector_dim * sizeof(float);
            fast_hash_ver1_cuda_kernel<<<blocks, threads, shared_mem>>>(
                query_mask_ptr,
                query_vector_ptr,
                Dmat_ptr,
                query_hash_code_ptr,
                batch_size,
                num_query,
                vector_dim,
                num_part,
                num_hash_f,
                hash_code_len
            );
        }
        {
            // Hash the keys with the same Dmat (same launch shape, num_key rows).
            dim3 threads(vector_dim);
            dim3 blocks(num_part, num_key, batch_size);
            int shared_mem = vector_dim * sizeof(float);
            fast_hash_ver1_cuda_kernel<<<blocks, threads, shared_mem>>>(
                key_mask_ptr,
                key_vector_ptr,
                Dmat_ptr,
                key_hash_code_ptr,
                batch_size,
                num_key,
                vector_dim,
                num_part,
                num_hash_f,
                hash_code_len
            );
        }
    }

    return {query_hash_code, key_hash_code};

}
|
| 85 |
+
|
| 86 |
+
// Unweighted LSH cumulation: scatter key values into per-hash-function hash
// tables (step 1), then gather them back per query (step 2). The value
// dimension is processed in WARP_SIZE-wide slices, re-zeroing the hash table
// between slices with cudaMemset.
// Returns cumulation_value of shape [batch, num_query, value_dim] (zeros when
// use_cuda is false — no CPU fallback here).
at::Tensor lsh_cumulation_ver1_kernel(
    at::Tensor query_mask,
    at::Tensor query_hash_code,
    at::Tensor key_mask,
    at::Tensor key_hash_code,
    at::Tensor value,
    int hashtable_capacity,
    bool use_cuda
) {
    int batch_size = query_hash_code.size(0);
    int num_hash_f = query_hash_code.size(2);

    int num_query = query_hash_code.size(1);
    int num_key = key_hash_code.size(1);
    int value_dim = value.size(2);

    // Scratch hash table: one WARP_SIZE-wide value slice per bucket.
    // at::empty is fine because each slice pass memsets it before use.
    at::Tensor hashtable_value = at::empty({batch_size, num_hash_f, hashtable_capacity, WARP_SIZE}, value.options());
    at::Tensor cumulation_value = at::zeros({batch_size, num_query, value_dim}, value.options());

    if (use_cuda) {
        // One warp per item row; threads_y rows per block.
        int threads_x = WARP_SIZE;
        int threads_y = OPTIMAL_THREADS_PER_BLOCK / WARP_SIZE;
        // NOTE(review): integer division — assumes num_key and num_query are
        // multiples of threads_y; confirm against the Python-side padding.
        int block_x_step1 = num_key / threads_y;
        int block_x_step2 = num_query / threads_y;
        int block_y = batch_size;

        dim3 threads(threads_x, threads_y);
        dim3 blocks_step1(block_x_step1, block_y);
        dim3 blocks_step2(block_x_step2, block_y);

        int *query_mask_ptr = query_mask.data_ptr<int>();
        int *query_hash_code_ptr = query_hash_code.data_ptr<int>();
        int *key_mask_ptr = key_mask.data_ptr<int>();
        int *key_hash_code_ptr = key_hash_code.data_ptr<int>();
        float *value_ptr = value.data_ptr<float>();
        float *hashtable_value_ptr = hashtable_value.data_ptr<float>();
        float *cumulation_value_ptr = cumulation_value.data_ptr<float>();

        // Process the value dimension one WARP_SIZE slice at a time.
        for (int value_offset = 0; value_offset < value_dim; value_offset = value_offset + WARP_SIZE) {

            // Clear the scratch table before accumulating this slice.
            cudaMemset(hashtable_value_ptr, 0, (batch_size * num_hash_f * hashtable_capacity * WARP_SIZE) * sizeof(float));

            // Step 1: scatter key values into hash-table buckets.
            lsh_cumulation_ver1_step1_cuda_kernel<<<blocks_step1, threads>>>(
                key_mask_ptr,
                key_hash_code_ptr,
                value_ptr,
                hashtable_value_ptr,
                batch_size,
                num_hash_f,
                hashtable_capacity,
                num_key,
                value_dim,
                value_offset
            );

            // Step 2: gather bucket contents for each query.
            lsh_cumulation_ver1_step2_cuda_kernel<<<blocks_step2, threads>>>(
                query_mask_ptr,
                query_hash_code_ptr,
                hashtable_value_ptr,
                cumulation_value_ptr,
                batch_size,
                num_hash_f,
                hashtable_capacity,
                num_query,
                value_dim,
                value_offset
            );
        }

    }

    return cumulation_value;

}
|
| 161 |
+
|
| 162 |
+
// Host-side launcher: weighted LSH cumulation, version 1 (hashtable based).
// For every (value_offset, weight_idx) pair it clears a per-bucket accumulator
// table, scatters weighted key values into it (step 1), then gathers the
// per-query sums out of it (step 2).
// NOTE(review): when use_cuda is false the kernels are skipped entirely and the
// zero-initialized cumulation_value is returned — presumably a CPU fallback is
// not supported here; confirm against callers.
at::Tensor lsh_weighted_cumulation_ver1_kernel(
  at::Tensor query_mask,       // int  [batch_size, num_query]
  at::Tensor query_hash_code,  // int  [batch_size, num_query, num_hash_f]
  at::Tensor query_weight,     // float [batch_size, num_query, weight_dim]
  at::Tensor key_mask,         // int  [batch_size, num_key]
  at::Tensor key_hash_code,    // int  [batch_size, num_key, num_hash_f]
  at::Tensor key_weight,       // float [batch_size, num_key, weight_dim]
  at::Tensor value,            // float [batch_size, num_key, value_dim]
  int hashtable_capacity,      // number of buckets per hash function
  bool use_cuda
) {

  // Problem sizes are read off the tensor shapes (see layouts above).
  int batch_size = query_hash_code.size(0);
  int num_hash_f = query_hash_code.size(2);

  int num_query = query_hash_code.size(1);
  int num_key = key_hash_code.size(1);
  int value_dim = value.size(2);
  int weight_dim = query_weight.size(2);

  // Accumulator table: one WARP_SIZE-wide slab of the value vector per bucket.
  at::Tensor hashtable_value = at::zeros({batch_size, num_hash_f, hashtable_capacity, WARP_SIZE}, value.options());
  at::Tensor cumulation_value = at::zeros({batch_size, num_query, value_dim}, value.options());

  if (use_cuda) {
    // One warp (threads_x) per key/query; threads_y keys/queries per block.
    int threads_x = WARP_SIZE;
    int threads_y = OPTIMAL_THREADS_PER_BLOCK / WARP_SIZE;
    // NOTE(review): integer division — assumes num_key and num_query are
    // multiples of threads_y; confirm callers pad their inputs.
    int block_x_step1 = num_key / threads_y;
    int block_x_step2 = num_query / threads_y;
    int block_y = batch_size;

    dim3 threads(threads_x, threads_y);
    dim3 blocks_step1(block_x_step1, block_y);
    dim3 blocks_step2(block_x_step2, block_y);

    int *query_mask_ptr = query_mask.data_ptr<int>();
    int *query_hash_code_ptr = query_hash_code.data_ptr<int>();
    float *query_weight_ptr = query_weight.data_ptr<float>();
    int *key_mask_ptr = key_mask.data_ptr<int>();
    int *key_hash_code_ptr = key_hash_code.data_ptr<int>();
    float *key_weight_ptr = key_weight.data_ptr<float>();
    float *value_ptr = value.data_ptr<float>();
    float *hashtable_value_ptr = hashtable_value.data_ptr<float>();
    float *cumulation_value_ptr = cumulation_value.data_ptr<float>();

    // The table only holds WARP_SIZE value lanes and one weight column at a
    // time, so the full result is built up over value_dim/WARP_SIZE x
    // weight_dim passes.
    for (int value_offset = 0; value_offset < value_dim; value_offset = value_offset + WARP_SIZE) {
      for (int weight_idx = 0; weight_idx < weight_dim; weight_idx++) {

        // Reset the accumulator table before each pass.
        cudaMemset(hashtable_value_ptr, 0, (batch_size * num_hash_f * hashtable_capacity * WARP_SIZE) * sizeof(float));

        lsh_weighted_cumulation_ver1_step1_cuda_kernel<<<blocks_step1, threads>>>(
          key_mask_ptr,
          key_hash_code_ptr,
          key_weight_ptr,
          value_ptr,
          hashtable_value_ptr,
          batch_size,
          num_hash_f,
          hashtable_capacity,
          num_key,
          value_dim,
          weight_dim,
          value_offset,
          weight_idx
        );

        lsh_weighted_cumulation_ver1_step2_cuda_kernel<<<blocks_step2, threads>>>(
          query_mask_ptr,
          query_hash_code_ptr,
          query_weight_ptr,
          hashtable_value_ptr,
          cumulation_value_ptr,
          batch_size,
          num_hash_f,
          hashtable_capacity,
          num_query,
          value_dim,
          weight_dim,
          value_offset,
          weight_idx
        );
      }
    }

  }

  return cumulation_value;

}
|
| 250 |
+
|
| 251 |
+
// Host-side launcher: weighted LSH cumulation, version 2 (count-sort based).
// Instead of a bucket accumulator table it (1) count-sorts key indices by hash
// bucket, (2) records per-query bucket ranges, then (3) walks matching keys
// directly per query in a single fused kernel — one query block per
// (query, hash function, batch) triple.
at::Tensor lsh_weighted_cumulation_ver2_kernel(
  at::Tensor query_mask,       // int  [batch_size, num_query]
  at::Tensor query_hash_code,  // int  [batch_size, num_query, num_hash_f]
  at::Tensor query_weight,     // float [batch_size, num_query, weight_dim]
  at::Tensor key_mask,         // int  [batch_size, num_key]
  at::Tensor key_hash_code,    // int  [batch_size, num_key, num_hash_f]
  at::Tensor key_weight,       // float [batch_size, num_key, weight_dim]
  at::Tensor value,            // float [batch_size, num_key, value_dim]
  int hashtable_capacity,      // number of buckets per hash function
  bool use_cuda
) {

  int batch_size = query_hash_code.size(0);
  int num_hash_f = query_hash_code.size(2);

  int num_query = query_hash_code.size(1);
  int num_key = key_hash_code.size(1);
  int value_dim = value.size(2);
  int weight_dim = query_weight.size(2);

  // Scratch tensors for the count sort: bucket histogram/offsets, keys sorted
  // by bucket, and per-query (offset, count) info per hash function.
  at::Tensor count_sort_table = at::zeros({batch_size, num_hash_f, hashtable_capacity}, query_hash_code.options());
  at::Tensor key_sorted_idxes = at::zeros({batch_size, num_hash_f, num_key}, query_hash_code.options());
  at::Tensor query_info = at::zeros({batch_size, num_query, 2, num_hash_f}, query_hash_code.options());
  at::Tensor cumulation_value = at::zeros({batch_size, num_query, value_dim}, value.options());

  if (use_cuda) {

    int *query_mask_ptr = query_mask.data_ptr<int>();
    int *query_hash_code_ptr = query_hash_code.data_ptr<int>();
    float *query_weight_ptr = query_weight.data_ptr<float>();
    int *key_mask_ptr = key_mask.data_ptr<int>();
    int *key_hash_code_ptr = key_hash_code.data_ptr<int>();
    float *key_weight_ptr = key_weight.data_ptr<float>();
    float *value_ptr = value.data_ptr<float>();

    int *count_sort_table_ptr = count_sort_table.data_ptr<int>();
    int *key_sorted_idxes_ptr = key_sorted_idxes.data_ptr<int>();
    int *query_info_ptr = query_info.data_ptr<int>();

    float *cumulation_value_ptr = cumulation_value.data_ptr<float>();

    {
      // Count sort of keys by hash bucket: histogram (step1), prefix offsets
      // (step2, shared-memory scan), scatter sorted indices (step3).
      dim3 threads_step13(num_hash_f, max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f));
      dim3 blocks_step13(num_key / max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f), batch_size);
      dim3 threads_step2(min(hashtable_capacity, OPTIMAL_THREADS_PER_BLOCK));
      dim3 blocks_step2(num_hash_f, batch_size);
      int shared_mem = hashtable_capacity * sizeof(float);
      count_sort_step1_cuda_kernel<<<blocks_step13, threads_step13>>>(
        key_mask_ptr,
        key_hash_code_ptr,
        count_sort_table_ptr,
        batch_size,
        num_hash_f,
        hashtable_capacity,
        num_key
      );
      count_sort_step2_cuda_kernel<<<blocks_step2, threads_step2, shared_mem>>>(
        count_sort_table_ptr,
        batch_size,
        num_hash_f,
        hashtable_capacity
      );
      count_sort_step3_cuda_kernel<<<blocks_step13, threads_step13>>>(
        key_mask_ptr,
        key_hash_code_ptr,
        count_sort_table_ptr,
        key_sorted_idxes_ptr,
        batch_size,
        num_hash_f,
        hashtable_capacity,
        num_key
      );
    }
    {
      // For each query, record where its hash buckets start and how many keys
      // they contain, using the offsets computed above.
      dim3 threads(num_hash_f, max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f));
      dim3 blocks(num_query / max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f), batch_size);
      extract_query_info_cuda_kernel<<<blocks, threads>>>(
        query_mask_ptr,
        query_hash_code_ptr,
        count_sort_table_ptr,
        query_info_ptr,
        batch_size,
        num_hash_f,
        hashtable_capacity,
        num_query
      );
    }
    {
      // Fused gather: one block per (query, hash function, batch); shared
      // memory buffers the query weight vector plus a WARP_SIZE scratch area.
      dim3 threads(WARP_SIZE, OPTIMAL_THREADS_PER_BLOCK / WARP_SIZE);
      dim3 blocks(num_query, num_hash_f, batch_size);
      int shared_mem = (weight_dim + WARP_SIZE) * sizeof(float);
      lsh_weighted_cumulation_ver2_step2_cuda_kernel<<<blocks, threads, shared_mem>>>(
        query_mask_ptr,
        query_info_ptr,
        key_sorted_idxes_ptr,
        query_weight_ptr,
        key_weight_ptr,
        value_ptr,
        cumulation_value_ptr,
        batch_size,
        num_hash_f,
        num_query,
        num_key,
        value_dim,
        weight_dim
      );
    }
  }

  return cumulation_value;

}
|
| 363 |
+
|
| 364 |
+
// Host-side launcher: weighted LSH cumulation, version 3 (key-centric).
// Mirror image of version 2: queries are count-sorted by hash bucket and the
// final kernel iterates over KEYS — one block per (key, hash function, batch)
// — scattering each key's weighted value into the matching queries' outputs.
at::Tensor lsh_weighted_cumulation_ver3_kernel(
  at::Tensor query_mask,       // int  [batch_size, num_query]
  at::Tensor query_hash_code,  // int  [batch_size, num_query, num_hash_f]
  at::Tensor query_weight,     // float [batch_size, num_query, weight_dim]
  at::Tensor key_mask,         // int  [batch_size, num_key]
  at::Tensor key_hash_code,    // int  [batch_size, num_key, num_hash_f]
  at::Tensor key_weight,       // float [batch_size, num_key, weight_dim]
  at::Tensor value,            // float [batch_size, num_key, value_dim]
  int hashtable_capacity,      // number of buckets per hash function
  bool use_cuda
) {

  int batch_size = query_hash_code.size(0);
  int num_hash_f = query_hash_code.size(2);

  int num_query = query_hash_code.size(1);
  int num_key = key_hash_code.size(1);
  int value_dim = value.size(2);
  int weight_dim = query_weight.size(2);

  // Scratch: bucket offsets, queries sorted by bucket, and per-key
  // (offset, count) info per hash function.
  at::Tensor count_sort_table = at::zeros({batch_size, num_hash_f, hashtable_capacity}, query_hash_code.options());
  at::Tensor query_sorted_idxes = at::zeros({batch_size, num_hash_f, num_query}, query_hash_code.options());
  at::Tensor key_info = at::zeros({batch_size, num_key, 2, num_hash_f}, query_hash_code.options());
  at::Tensor cumulation_value = at::zeros({batch_size, num_query, value_dim}, value.options());

  if (use_cuda) {

    int *query_mask_ptr = query_mask.data_ptr<int>();
    int *query_hash_code_ptr = query_hash_code.data_ptr<int>();
    float *query_weight_ptr = query_weight.data_ptr<float>();
    int *key_mask_ptr = key_mask.data_ptr<int>();
    int *key_hash_code_ptr = key_hash_code.data_ptr<int>();
    float *key_weight_ptr = key_weight.data_ptr<float>();
    float *value_ptr = value.data_ptr<float>();

    int *count_sort_table_ptr = count_sort_table.data_ptr<int>();
    int *query_sorted_idxes_ptr = query_sorted_idxes.data_ptr<int>();
    int *key_info_ptr = key_info.data_ptr<int>();

    float *cumulation_value_ptr = cumulation_value.data_ptr<float>();

    {
      // Count sort of QUERIES by hash bucket (contrast ver2, which sorts keys).
      dim3 threads_step13(num_hash_f, max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f));
      dim3 blocks_step13(num_query / max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f), batch_size);
      dim3 threads_step2(min(hashtable_capacity, OPTIMAL_THREADS_PER_BLOCK));
      dim3 blocks_step2(num_hash_f, batch_size);
      int shared_mem = hashtable_capacity * sizeof(float);
      count_sort_step1_cuda_kernel<<<blocks_step13, threads_step13>>>(
        query_mask_ptr,
        query_hash_code_ptr,
        count_sort_table_ptr,
        batch_size,
        num_hash_f,
        hashtable_capacity,
        num_query
      );
      count_sort_step2_cuda_kernel<<<blocks_step2, threads_step2, shared_mem>>>(
        count_sort_table_ptr,
        batch_size,
        num_hash_f,
        hashtable_capacity
      );
      count_sort_step3_cuda_kernel<<<blocks_step13, threads_step13>>>(
        query_mask_ptr,
        query_hash_code_ptr,
        count_sort_table_ptr,
        query_sorted_idxes_ptr,
        batch_size,
        num_hash_f,
        hashtable_capacity,
        num_query
      );
    }
    {
      // Reused helper: despite its name it extracts per-KEY bucket info here.
      dim3 threads(num_hash_f, max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f));
      dim3 blocks(num_key / max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f), batch_size);
      extract_query_info_cuda_kernel<<<blocks, threads>>>(
        key_mask_ptr,
        key_hash_code_ptr,
        count_sort_table_ptr,
        key_info_ptr,
        batch_size,
        num_hash_f,
        hashtable_capacity,
        num_key
      );
    }
    {
      // Fused scatter: shared memory buffers the key weight vector, the key's
      // value vector, and a WARP_SIZE scratch area.
      dim3 threads(WARP_SIZE, OPTIMAL_THREADS_PER_BLOCK / WARP_SIZE);
      dim3 blocks(num_key, num_hash_f, batch_size);
      int shared_mem = (weight_dim + value_dim + WARP_SIZE) * sizeof(float);
      lsh_weighted_cumulation_ver3_step2_cuda_kernel<<<blocks, threads, shared_mem>>>(
        query_sorted_idxes_ptr,
        key_mask_ptr,
        key_info_ptr,
        query_weight_ptr,
        key_weight_ptr,
        value_ptr,
        cumulation_value_ptr,
        batch_size,
        num_hash_f,
        num_query,
        num_key,
        value_dim,
        weight_dim
      );
    }
  }

  return cumulation_value;

}
|
| 476 |
+
|
| 477 |
+
// Host-side launcher: weighted LSH cumulation, version 4.
// Same pipeline as version 3 (count-sort queries, per-key scatter) but the
// final kernel launches one block per (key, batch) pair and handles all hash
// functions inside the block, buffering 2*num_hash_f ints of per-key bucket
// info in shared memory instead of using a num_hash_f grid dimension.
at::Tensor lsh_weighted_cumulation_ver4_kernel(
  at::Tensor query_mask,       // int  [batch_size, num_query]
  at::Tensor query_hash_code,  // int  [batch_size, num_query, num_hash_f]
  at::Tensor query_weight,     // float [batch_size, num_query, weight_dim]
  at::Tensor key_mask,         // int  [batch_size, num_key]
  at::Tensor key_hash_code,    // int  [batch_size, num_key, num_hash_f]
  at::Tensor key_weight,       // float [batch_size, num_key, weight_dim]
  at::Tensor value,            // float [batch_size, num_key, value_dim]
  int hashtable_capacity,      // number of buckets per hash function
  bool use_cuda
) {

  int batch_size = query_hash_code.size(0);
  int num_hash_f = query_hash_code.size(2);

  int num_query = query_hash_code.size(1);
  int num_key = key_hash_code.size(1);
  int value_dim = value.size(2);
  int weight_dim = query_weight.size(2);

  // Scratch: bucket offsets, queries sorted by bucket, per-key bucket info.
  at::Tensor count_sort_table = at::zeros({batch_size, num_hash_f, hashtable_capacity}, query_hash_code.options());
  at::Tensor query_sorted_idxes = at::zeros({batch_size, num_hash_f, num_query}, query_hash_code.options());
  at::Tensor key_info = at::zeros({batch_size, num_key, 2, num_hash_f}, query_hash_code.options());
  at::Tensor cumulation_value = at::zeros({batch_size, num_query, value_dim}, value.options());

  if (use_cuda) {

    int *query_mask_ptr = query_mask.data_ptr<int>();
    int *query_hash_code_ptr = query_hash_code.data_ptr<int>();
    float *query_weight_ptr = query_weight.data_ptr<float>();
    int *key_mask_ptr = key_mask.data_ptr<int>();
    int *key_hash_code_ptr = key_hash_code.data_ptr<int>();
    float *key_weight_ptr = key_weight.data_ptr<float>();
    float *value_ptr = value.data_ptr<float>();

    int *count_sort_table_ptr = count_sort_table.data_ptr<int>();
    int *query_sorted_idxes_ptr = query_sorted_idxes.data_ptr<int>();
    int *key_info_ptr = key_info.data_ptr<int>();

    float *cumulation_value_ptr = cumulation_value.data_ptr<float>();

    {
      // Count sort of queries by hash bucket (identical to version 3).
      dim3 threads_step13(num_hash_f, max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f));
      dim3 blocks_step13(num_query / max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f), batch_size);
      dim3 threads_step2(min(hashtable_capacity, OPTIMAL_THREADS_PER_BLOCK));
      dim3 blocks_step2(num_hash_f, batch_size);
      int shared_mem = hashtable_capacity * sizeof(float);
      count_sort_step1_cuda_kernel<<<blocks_step13, threads_step13>>>(
        query_mask_ptr,
        query_hash_code_ptr,
        count_sort_table_ptr,
        batch_size,
        num_hash_f,
        hashtable_capacity,
        num_query
      );
      count_sort_step2_cuda_kernel<<<blocks_step2, threads_step2, shared_mem>>>(
        count_sort_table_ptr,
        batch_size,
        num_hash_f,
        hashtable_capacity
      );
      count_sort_step3_cuda_kernel<<<blocks_step13, threads_step13>>>(
        query_mask_ptr,
        query_hash_code_ptr,
        count_sort_table_ptr,
        query_sorted_idxes_ptr,
        batch_size,
        num_hash_f,
        hashtable_capacity,
        num_query
      );
    }
    {
      // Reused helper: extracts per-KEY bucket (offset, count) info here.
      dim3 threads(num_hash_f, max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f));
      dim3 blocks(num_key / max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f), batch_size);
      extract_query_info_cuda_kernel<<<blocks, threads>>>(
        key_mask_ptr,
        key_hash_code_ptr,
        count_sort_table_ptr,
        key_info_ptr,
        batch_size,
        num_hash_f,
        hashtable_capacity,
        num_key
      );
    }
    {
      // One block per (key, batch); shared memory additionally holds the
      // key's 2*num_hash_f bucket-info ints alongside weight and value rows.
      dim3 threads(WARP_SIZE, OPTIMAL_THREADS_PER_BLOCK / WARP_SIZE);
      dim3 blocks(num_key, batch_size);
      int shared_mem = (weight_dim + value_dim + 2 * num_hash_f) * sizeof(float);
      lsh_weighted_cumulation_ver4_step2_cuda_kernel<<<blocks, threads, shared_mem>>>(
        query_sorted_idxes_ptr,
        key_mask_ptr,
        key_info_ptr,
        query_weight_ptr,
        key_weight_ptr,
        value_ptr,
        cumulation_value_ptr,
        batch_size,
        num_hash_f,
        num_query,
        num_key,
        value_dim,
        weight_dim
      );
    }
  }

  return cumulation_value;

}
|
.venv/lib/python3.11/site-packages/transformers/kernels/yoso/fast_lsh_cumulation.h
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <vector>

// Public host-side entry points for the YOSO fast LSH cumulation extension.
// All functions take int mask / hash-code tensors and float weight / value
// tensors; implementations live in fast_lsh_cumulation.cu.

// Computes LSH hash codes for query and key vectors; returns
// {query_hash_code, key_hash_code}.
std::vector<at::Tensor> fast_hash_ver1_kernel(
  at::Tensor query_mask,
  at::Tensor query_vector,
  at::Tensor key_mask,
  at::Tensor key_vector,
  int num_hash_f,
  int hash_code_len,
  bool use_cuda
);

// Unweighted LSH cumulation via a per-bucket accumulator hashtable.
at::Tensor lsh_cumulation_ver1_kernel(
  at::Tensor query_mask,
  at::Tensor query_hash_code,
  at::Tensor key_mask,
  at::Tensor key_hash_code,
  at::Tensor value,
  int hashtable_capacity,
  bool use_cuda
);

// Weighted LSH cumulation, version 1: hashtable accumulation, one pass per
// (value slab, weight column) pair.
at::Tensor lsh_weighted_cumulation_ver1_kernel(
  at::Tensor query_mask,
  at::Tensor query_hash_code,
  at::Tensor query_weight,
  at::Tensor key_mask,
  at::Tensor key_hash_code,
  at::Tensor key_weight,
  at::Tensor value,
  int hashtable_capacity,
  bool use_cuda
);

// Weighted LSH cumulation, version 2: count-sort keys by bucket, then gather
// per query.
at::Tensor lsh_weighted_cumulation_ver2_kernel(
  at::Tensor query_mask,
  at::Tensor query_hash_code,
  at::Tensor query_weight,
  at::Tensor key_mask,
  at::Tensor key_hash_code,
  at::Tensor key_weight,
  at::Tensor value,
  int hashtable_capacity,
  bool use_cuda
);

// Weighted LSH cumulation, version 3: count-sort queries by bucket, then
// scatter per key (one block per key and hash function).
at::Tensor lsh_weighted_cumulation_ver3_kernel(
  at::Tensor query_mask,
  at::Tensor query_hash_code,
  at::Tensor query_weight,
  at::Tensor key_mask,
  at::Tensor key_hash_code,
  at::Tensor key_weight,
  at::Tensor value,
  int hashtable_capacity,
  bool use_cuda
);

// Weighted LSH cumulation, version 4: like version 3 but all hash functions
// are handled inside a single block per key.
at::Tensor lsh_weighted_cumulation_ver4_kernel(
  at::Tensor query_mask,
  at::Tensor query_hash_code,
  at::Tensor query_weight,
  at::Tensor key_mask,
  at::Tensor key_hash_code,
  at::Tensor key_weight,
  at::Tensor value,
  int hashtable_capacity,
  bool use_cuda
);
|
.venv/lib/python3.11/site-packages/transformers/kernels/yoso/fast_lsh_cumulation_cuda.cu
ADDED
|
@@ -0,0 +1,825 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// File from https://github.com/mlpen/YOSO/blob/main/encoders/backbones/efficient_attentions/yoso/yoso_v1/cuda/fast_lsh_cumulation_cuda.cu
|
| 2 |
+
|
| 3 |
+
#include "fast_lsh_cumulation_cuda.h"
|
| 4 |
+
#include "common_cuda_device.h"
|
| 5 |
+
#include "common_cuda.h"
|
| 6 |
+
#include "common.h"
|
| 7 |
+
#include <stdio.h>
|
| 8 |
+
//////////////////////////////////////////////////////////////////////////////////////////////////
|
| 9 |
+
//////////////////////////////////////////////////////////////////////////////////////////////////
|
| 10 |
+
|
| 11 |
+
// In-place fast Walsh-Hadamard transform over a shared-memory buffer of
// length vector_dim; each thread owns element dim_idx. Large strides
// (above half a warp) go through shared memory with __syncthreads()
// separating the read and write phases of each butterfly; strides that fit
// inside one warp switch to register shuffles (__shfl_xor_sync), which need
// no block-level synchronization.
inline __device__ void fast_hadamard_transform(float *vector_buffer, int vector_dim, int dim_idx) {
  int stride = vector_dim / 2;
  while (stride > (WARP_SIZE / 2)) {
    __syncthreads();
    // sign is +1 for the lower half of each butterfly pair, -1 for the upper.
    int sign = 1 - ((dim_idx / stride) % 2) * 2;
    float val1 = vector_buffer[dim_idx];
    float val2 = vector_buffer[dim_idx + sign * stride];
    // Second barrier: all reads must finish before any thread overwrites.
    __syncthreads();
    vector_buffer[dim_idx] = float(sign) * val1 + val2;
    stride = stride / 2;
  }

  // Remaining butterflies are intra-warp: exchange partners via XOR shuffle.
  float val = vector_buffer[dim_idx];
  #pragma unroll
  for (stride = (WARP_SIZE / 2); stride > 0; stride = stride / 2) {
    int sign = 1 - ((dim_idx / stride) % 2) * 2;
    val = float(sign) * val + __shfl_xor_sync(FULL_MASK, val, stride);
  }
  vector_buffer[dim_idx] = val;
}
|
| 31 |
+
|
| 32 |
+
// Computes LSH hash codes with the HD3-style scheme: three rounds of
// (random-sign diagonal Dmat) x (fast Hadamard transform), then the signs of
// the transformed coordinates are packed into hash_code_len-bit codes.
// Grid: (part_idx, vector_idx, batch_idx); one thread per vector dimension.
// Requires vector_dim floats of dynamic shared memory.
__global__ void fast_hash_ver1_cuda_kernel(
  int *mask,        // [batch_size, num_vector]
  float *vector,    // [batch_size, num_vector, vector_dim]
  int *Dmat,        // [batch_size, 3, num_part, vector_dim]
  int *hash_code,   // [batch_size, num_vector, num_hash_f]
  int batch_size,
  int num_vector,
  int vector_dim,
  int num_part,
  int num_hash_f,
  int hash_code_len
) {

  int batch_idx = blockIdx.z;
  int vector_idx = blockIdx.y;
  int part_idx = blockIdx.x;

  int dim_idx = threadIdx.x;

  int batch_idx__vector_idx = batch_idx * num_vector + vector_idx;
  // Masked-out (padding) vectors are skipped entirely.
  if (mask[batch_idx__vector_idx] == 0) {
    return;
  }

  extern __shared__ float buffer[];
  float *vector_buffer = buffer;

  // Load this vector into shared memory, one element per thread.
  vector_buffer[dim_idx] = vector[batch_idx__vector_idx * vector_dim + dim_idx];

  // Three rounds of sign-flip (Dmat holds +/-1 entries per round/part —
  // TODO confirm against the Dmat generator) followed by a Hadamard mix.
  vector_buffer[dim_idx] = vector_buffer[dim_idx] * (float)Dmat[((batch_idx * 3 + 0) * num_part + part_idx) * vector_dim + dim_idx];
  fast_hadamard_transform(vector_buffer, vector_dim, dim_idx);
  vector_buffer[dim_idx] = vector_buffer[dim_idx] * (float)Dmat[((batch_idx * 3 + 1) * num_part + part_idx) * vector_dim + dim_idx];
  fast_hadamard_transform(vector_buffer, vector_dim, dim_idx);
  vector_buffer[dim_idx] = vector_buffer[dim_idx] * (float)Dmat[((batch_idx * 3 + 2) * num_part + part_idx) * vector_dim + dim_idx];
  fast_hadamard_transform(vector_buffer, vector_dim, dim_idx);

  // Each part yields vector_dim / hash_code_len hash codes.
  int num_hash_per_part = vector_dim / hash_code_len;
  if (hash_code_len == 8 || hash_code_len == 16) {
    // Code length divides the warp size: OR-reduce sign bits across the
    // hash_code_len-thread group via XOR shuffles, no shared memory needed.
    // `select` is a project helper (ternary choose) from the common headers.
    int code = select(vector_buffer[dim_idx] > 0, 1 << (dim_idx % hash_code_len), 0);
    for (int offset = 1; offset < hash_code_len; offset = offset * 2) {
      code += __shfl_xor_sync(FULL_MASK, code, offset);
    }
    // Only the first lane of each group writes the assembled code.
    if (dim_idx % hash_code_len == 0) {
      int hash_f_idx = part_idx * num_hash_per_part + dim_idx / hash_code_len;
      if (hash_f_idx < num_hash_f) {
        hash_code[batch_idx__vector_idx * num_hash_f + hash_f_idx] = code;
      }
    }
  } else {
    // General path: stage per-dimension bits in shared memory, then one
    // thread per code sums its hash_code_len bits sequentially.
    vector_buffer[dim_idx] = select(vector_buffer[dim_idx] > 0, 1 << (dim_idx % hash_code_len), 0);
    __syncthreads();
    if (dim_idx < num_hash_per_part) {
      int code = 0;
      for (int i = 0; i < hash_code_len; i++) {
        code += vector_buffer[dim_idx * hash_code_len + i];
      }
      int hash_f_idx = part_idx * num_hash_per_part + dim_idx;
      if (hash_f_idx < num_hash_f) {
        hash_code[batch_idx__vector_idx * num_hash_f + hash_f_idx] = code;
      }
    }
  }
}
|
| 95 |
+
|
| 96 |
+
// Scatter phase of unweighted LSH cumulation: each key atomically adds a
// WARP_SIZE-wide slab of its value vector (columns offset_warp ..
// offset_warp+WARP_SIZE-1) into the hashtable bucket selected by each hash
// function. One warp per key (threadIdx.x = lane, threadIdx.y = key within
// block, blockIdx.y = batch).
__global__ void lsh_cumulation_ver1_step1_cuda_kernel(
  int *key_mask,           // [batch_size, num_key]
  int *key_hash_code,      // [batch_size, num_key, num_hash_f]
  float *value,            // [batch_size, num_key, value_dim]
  float *hashtable_value,  // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_key,
  int value_dim,
  int offset_warp          // starting column of the value slab handled this pass
) {

  int warp_thread_idx = threadIdx.x;

  int batch_idx = blockIdx.y;
  int key_idx = blockIdx.x * blockDim.y + threadIdx.y;

  int batch_idx__key_idx = batch_idx * num_key + key_idx;
  // Skip padding keys.
  if (key_mask[batch_idx__key_idx] == 0) {
    return;
  }

  if (num_hash_f > WARP_SIZE) {
    // Each lane holds one value element; hash codes are read WARP_SIZE at a
    // time and broadcast lane-by-lane with __shfl_sync so every lane adds
    // into the same bucket row.
    float warp_value = value[batch_idx__key_idx * value_dim + offset_warp + warp_thread_idx];
    for (int hash_f_start = 0; hash_f_start < num_hash_f; hash_f_start = hash_f_start + WARP_SIZE) {
      int warp_hashcode = key_hash_code[batch_idx__key_idx * num_hash_f + hash_f_start + warp_thread_idx];
      #pragma unroll
      for (int hash_f_offset = 0; hash_f_offset < WARP_SIZE; hash_f_offset++) {
        int current_hashcode = warp_hashcode;
        current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_offset);
        int hashtable_idx = (batch_idx * num_hash_f + (hash_f_start + hash_f_offset)) * hashtable_capacity + current_hashcode;
        atomicAdd(&hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx], warp_value);
      }
    }
  } else {
    // Few hash functions: load them once (lanes >= num_hash_f hold a dummy 0
    // that is never broadcast) and loop over them directly.
    float warp_value = value[batch_idx__key_idx * value_dim + offset_warp + warp_thread_idx];
    int warp_hashcode = 0;
    if (warp_thread_idx < num_hash_f) {
      warp_hashcode = key_hash_code[batch_idx__key_idx * num_hash_f + warp_thread_idx];
    }
    for (int hash_f_idx = 0; hash_f_idx < num_hash_f; hash_f_idx++) {
      int current_hashcode = warp_hashcode;
      current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_idx);
      int hashtable_idx = (batch_idx * num_hash_f + hash_f_idx) * hashtable_capacity + current_hashcode;
      atomicAdd(&hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx], warp_value);
    }
  }

}
|
| 146 |
+
|
| 147 |
+
// Gather phase of unweighted LSH cumulation: each query sums, over all hash
// functions, the WARP_SIZE-wide bucket slab its hash codes select, then
// writes the mean (sum / num_hash_f) into cumulation_value at columns
// offset_warp .. offset_warp+WARP_SIZE-1. One warp per query; thread/block
// layout mirrors step 1.
__global__ void lsh_cumulation_ver1_step2_cuda_kernel(
  int *query_mask,         // [batch_size, num_query]
  int *query_hash_code,    // [batch_size, num_query, num_hash_f]
  float *hashtable_value,  // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE]
  float *cumulation_value, // [batch_size, num_query, value_dim]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_query,
  int value_dim,
  int offset_warp          // starting column of the value slab handled this pass
) {

  int warp_thread_idx = threadIdx.x;

  int batch_idx = blockIdx.y;
  int query_idx = blockIdx.x * blockDim.y + threadIdx.y;

  int batch_idx__query_idx = batch_idx * num_query + query_idx;
  // Skip padding queries.
  if (query_mask[batch_idx__query_idx] == 0) {
    return;
  }

  if (num_hash_f > WARP_SIZE) {
    // Hash codes are read WARP_SIZE at a time and broadcast lane-by-lane with
    // __shfl_sync, mirroring the scatter in step 1 but accumulating reads.
    float warp_value = 0;
    for (int hash_f_start = 0; hash_f_start < num_hash_f; hash_f_start = hash_f_start + WARP_SIZE) {
      int warp_hashcode = query_hash_code[batch_idx__query_idx * num_hash_f + hash_f_start + warp_thread_idx];
      #pragma unroll
      for (int hash_f_offset = 0; hash_f_offset < WARP_SIZE; hash_f_offset++) {
        int current_hashcode = warp_hashcode;
        current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_offset);
        int hashtable_idx = (batch_idx * num_hash_f + (hash_f_start + hash_f_offset)) * hashtable_capacity + current_hashcode;
        warp_value = warp_value + hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx];
      }
    }
    // Average over hash functions.
    cumulation_value[batch_idx__query_idx * value_dim + offset_warp + warp_thread_idx] = warp_value / float(num_hash_f);
  } else {
    // Few hash functions: load them once (extra lanes hold a dummy 0 that is
    // never broadcast) and loop over them directly.
    float warp_value = 0;
    int warp_hashcode = 0;
    if (warp_thread_idx < num_hash_f) {
      warp_hashcode = query_hash_code[batch_idx__query_idx * num_hash_f + warp_thread_idx];
    }
    for (int hash_f_idx = 0; hash_f_idx < num_hash_f; hash_f_idx++) {
      int current_hashcode = warp_hashcode;
      current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_idx);
      int hashtable_idx = (batch_idx * num_hash_f + hash_f_idx) * hashtable_capacity + current_hashcode;
      warp_value = warp_value + hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx];
    }
    cumulation_value[batch_idx__query_idx * value_dim + offset_warp + warp_thread_idx] = warp_value / float(num_hash_f);
  }

}
|
| 199 |
+
|
| 200 |
+
// Key-side scatter for weighted LSH cumulation (ver1, step 1).
// One warp handles one (batch, key) pair: it scales a WARP_SIZE-wide slice of the
// key's value vector by key_weight[..., weight_idx] and atomically adds it into the
// hashtable row selected by each of the key's hash codes.
// Grid mapping: blockIdx.y = batch, (blockIdx.x, threadIdx.y) = key, threadIdx.x = lane.
__global__ void lsh_weighted_cumulation_ver1_step1_cuda_kernel(
  int *key_mask,           // [batch_size, num_key]
  int *key_hash_code,      // [batch_size, num_key, num_hash_f]
  float *key_weight,       // [batch_size, num_key, weight_dim]
  float *value,            // [batch_size, num_key, value_dim]
  float *hashtable_value,  // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_key,
  int value_dim,
  int weight_dim,
  int offset_warp,
  int weight_idx
) {

  int warp_thread_idx = threadIdx.x;

  int batch_idx = blockIdx.y;
  int key_idx = blockIdx.x * blockDim.y + threadIdx.y;

  int batch_idx__key_idx = batch_idx * num_key + key_idx;
  // Masked-out (padding) keys contribute nothing.
  if (key_mask[batch_idx__key_idx] == 0) {
    return;
  }

  if (num_hash_f > WARP_SIZE) {
    // Each lane's contribution: weight scalar times its value-dim element.
    float warp_value = key_weight[batch_idx__key_idx * weight_dim + weight_idx] * value[batch_idx__key_idx * value_dim + offset_warp + warp_thread_idx];
    for (int hash_f_start = 0; hash_f_start < num_hash_f; hash_f_start = hash_f_start + WARP_SIZE) {
      int warp_hashcode = key_hash_code[batch_idx__key_idx * num_hash_f + hash_f_start + warp_thread_idx];
      #pragma unroll
      for (int hash_f_offset = 0; hash_f_offset < WARP_SIZE; hash_f_offset++) {
        int current_hashcode = warp_hashcode;
        // Broadcast lane hash_f_offset's code to the whole warp.
        current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_offset);
        int hashtable_idx = (batch_idx * num_hash_f + (hash_f_start + hash_f_offset)) * hashtable_capacity + current_hashcode;
        // Atomic: multiple keys may hash into the same bucket concurrently.
        atomicAdd(&hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx], warp_value);
      }
    }
  } else {
    float warp_value = key_weight[batch_idx__key_idx * weight_dim + weight_idx] * value[batch_idx__key_idx * value_dim + offset_warp + warp_thread_idx];
    int warp_hashcode = 0;
    if (warp_thread_idx < num_hash_f) {
      warp_hashcode = key_hash_code[batch_idx__key_idx * num_hash_f + warp_thread_idx];
    }
    for (int hash_f_idx = 0; hash_f_idx < num_hash_f; hash_f_idx++) {
      int current_hashcode = warp_hashcode;
      current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_idx);
      int hashtable_idx = (batch_idx * num_hash_f + hash_f_idx) * hashtable_capacity + current_hashcode;
      atomicAdd(&hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx], warp_value);
    }
  }

}
|
| 253 |
+
|
| 254 |
+
// Query-side gather for weighted LSH cumulation (ver1, step 2).
// One warp handles one (batch, query) pair: it sums the hashtable rows selected by
// the query's hash codes, scales the average by query_weight[..., weight_idx], and
// accumulates (+=) into a WARP_SIZE-wide slice of cumulation_value. Note the +=:
// this kernel is invoked once per weight_idx and results accumulate across calls.
// Grid mapping: blockIdx.y = batch, (blockIdx.x, threadIdx.y) = query, threadIdx.x = lane.
__global__ void lsh_weighted_cumulation_ver1_step2_cuda_kernel(
  int *query_mask,         // [batch_size, num_query]
  int *query_hash_code,    // [batch_size, num_query, num_hash_f]
  float *query_weight,     // [batch_size, num_query, weight_dim]
  float *hashtable_value,  // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE]
  float *cumulation_value, // [batch_size, num_query, value_dim]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_query,
  int value_dim,
  int weight_dim,
  int offset_warp,
  int weight_idx
) {

  int warp_thread_idx = threadIdx.x;

  int batch_idx = blockIdx.y;
  int query_idx = blockIdx.x * blockDim.y + threadIdx.y;

  int batch_idx__query_idx = batch_idx * num_query + query_idx;
  // Masked-out (padding) queries contribute nothing.
  if (query_mask[batch_idx__query_idx] == 0) {
    return;
  }

  if (num_hash_f > WARP_SIZE) {
    // Process hash functions WARP_SIZE at a time via warp broadcast (see ver1 step2).
    float warp_value = 0;
    for (int hash_f_start = 0; hash_f_start < num_hash_f; hash_f_start = hash_f_start + WARP_SIZE) {
      int warp_hashcode = query_hash_code[batch_idx__query_idx * num_hash_f + hash_f_start + warp_thread_idx];
      #pragma unroll
      for (int hash_f_offset = 0; hash_f_offset < WARP_SIZE; hash_f_offset++) {
        int current_hashcode = warp_hashcode;
        current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_offset);
        int hashtable_idx = (batch_idx * num_hash_f + (hash_f_start + hash_f_offset)) * hashtable_capacity + current_hashcode;
        warp_value = warp_value + hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx];
      }
    }
    float warp_weight = query_weight[batch_idx__query_idx * weight_dim + weight_idx];
    cumulation_value[batch_idx__query_idx * value_dim + offset_warp + warp_thread_idx] += warp_weight * warp_value / float(num_hash_f);
  } else {
    // num_hash_f <= WARP_SIZE: first num_hash_f lanes hold one hash code each.
    float warp_value = 0;
    int warp_hashcode = 0;
    if (warp_thread_idx < num_hash_f) {
      warp_hashcode = query_hash_code[batch_idx__query_idx * num_hash_f + warp_thread_idx];
    }
    for (int hash_f_idx = 0; hash_f_idx < num_hash_f; hash_f_idx++) {
      int current_hashcode = warp_hashcode;
      current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_idx);
      int hashtable_idx = (batch_idx * num_hash_f + hash_f_idx) * hashtable_capacity + current_hashcode;
      warp_value = warp_value + hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx];
    }
    float warp_weight = query_weight[batch_idx__query_idx * weight_dim + weight_idx];
    cumulation_value[batch_idx__query_idx * value_dim + offset_warp + warp_thread_idx] += warp_weight * warp_value / float(num_hash_f);
  }

}
|
| 311 |
+
|
| 312 |
+
// Counting-sort step 1: histogram. For each unmasked key and each hash function,
// atomically increment the bucket counter for that key's hash code.
// Grid mapping: blockIdx.y = batch, (blockIdx.x, threadIdx.y) = key,
// threadIdx.x = hash function index (so blockDim.x must cover num_hash_f —
// presumably enforced by the launcher; verify against caller).
__global__ void count_sort_step1_cuda_kernel(
  int *key_mask,         // [batch_size, num_key]
  int *key_hash_code,    // [batch_size, num_key, num_hash_f]
  int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_key
) {

  int batch_idx = blockIdx.y;
  int key_idx = blockIdx.x * blockDim.y + threadIdx.y;
  int hash_f_idx = threadIdx.x;

  int batch_idx__key_idx = batch_idx * num_key + key_idx;
  // Masked-out (padding) keys are not counted.
  if (key_mask[batch_idx__key_idx] == 0) {
    return;
  }

  int hash_code = key_hash_code[batch_idx__key_idx * num_hash_f + hash_f_idx];
  // Atomic: many keys can share a bucket.
  atomicAdd(&count_sort_table[(batch_idx * num_hash_f + hash_f_idx) * hashtable_capacity + hash_code], 1);

}
|
| 335 |
+
|
| 336 |
+
// Counting-sort step 2: turn the per-bucket histogram into an exclusive prefix sum,
// in place, one (batch, hash function) table per block.
// Method: shift counts right by one slot into shared memory (slot 0 := 0), run a
// warp-level inclusive scan per num_threads-sized segment, then serially add each
// preceding segment's last value to close the scan across segments.
// NOTE(review): the cross-segment fix-up loop relies on the segment stride equaling
// WARP_SIZE; it appears to assume num_threads == WARP_SIZE (i.e. one warp per
// segment) — confirm against the launch configuration.
__global__ void count_sort_step2_cuda_kernel(
  int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity
) {

  int batch_idx = blockIdx.y;
  int hash_f_idx = blockIdx.x;

  int num_threads = blockDim.x;
  int thread_id = threadIdx.x;

  int batch_idx__hash_f_idx = batch_idx * num_hash_f + hash_f_idx;

  // Dynamic shared memory, reinterpreted as an int scan buffer.
  extern __shared__ float buffer[];
  int *table_buffer = (int*)buffer;

  // Exclusive scan: seed slot 0 with 0 and copy counts shifted right by one.
  if (thread_id == 0) {
    table_buffer[0] = 0;
  }
  copy_data<int>(&count_sort_table[batch_idx__hash_f_idx * hashtable_capacity], &table_buffer[1], hashtable_capacity - 1, num_threads, thread_id);

  // Warp-level inclusive scan within each num_threads-wide segment (Hillis-Steele
  // via __shfl_up_sync doubling offsets).
  for (int table_idx_start = 0; table_idx_start < hashtable_capacity; table_idx_start = table_idx_start + num_threads) {
    int thread_value = table_buffer[table_idx_start + thread_id];
    int next_thread_value = 0;
    for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) {
      next_thread_value = __shfl_up_sync(FULL_MASK, thread_value, offset);
      if (thread_id % WARP_SIZE >= offset) {
        thread_value = thread_value + next_thread_value;
      }
    }
    table_buffer[table_idx_start + thread_id] = thread_value;
  }
  __syncthreads();

  // Propagate each segment's running total into the following segments.
  if (hashtable_capacity > WARP_SIZE) {
    if (thread_id < WARP_SIZE) {
      for (int table_idx_start = WARP_SIZE; table_idx_start < hashtable_capacity; table_idx_start = table_idx_start + WARP_SIZE) {
        table_buffer[table_idx_start + thread_id] += table_buffer[table_idx_start - 1];
      }
    }
  }

  // Write the finished prefix sums back over the histogram.
  copy_data<int>(table_buffer, &count_sort_table[batch_idx__hash_f_idx * hashtable_capacity], hashtable_capacity, num_threads, thread_id);

}
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
// Counting-sort step 3: scatter. Each unmasked key claims its slot in the sorted
// index array via atomicAdd on the prefix-sum table (which doubles as a running
// write cursor per bucket), grouping key indices by hash code.
// After this kernel, count_sort_table holds inclusive end offsets per bucket.
// Grid mapping: blockIdx.y = batch, (blockIdx.x, threadIdx.y) = key,
// threadIdx.x = hash function index.
__global__ void count_sort_step3_cuda_kernel(
  int *key_mask,         // [batch_size, num_key]
  int *key_hash_code,    // [batch_size, num_key, num_hash_f]
  int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity]
  int *key_sorted_idxes, // [batch_size, num_hash_f, num_key]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_key
) {

  int batch_idx = blockIdx.y;
  int key_idx = blockIdx.x * blockDim.y + threadIdx.y;
  int hash_f_idx = threadIdx.x;

  int batch_idx__key_idx = batch_idx * num_key + key_idx;
  // Masked-out (padding) keys are not placed in the sorted order.
  if (key_mask[batch_idx__key_idx] == 0) {
    return;
  }

  int batch_idx__hash_f_idx = batch_idx * num_hash_f + hash_f_idx;

  int hash_code = key_hash_code[batch_idx__key_idx * num_hash_f + hash_f_idx];
  // atomicAdd returns the pre-increment value: this key's unique slot in its bucket.
  int sort_idx = atomicAdd(&count_sort_table[batch_idx__hash_f_idx * hashtable_capacity + hash_code], 1);
  key_sorted_idxes[batch_idx__hash_f_idx * num_key + sort_idx] = key_idx;

}
|
| 412 |
+
|
| 413 |
+
// For each unmasked query and hash function, look up the (offset, count) range of
// keys sharing the query's hash bucket in the sorted key index array, and store it
// in query_info. Relies on count_sort_table holding, after count_sort_step3, the
// inclusive end offset of each bucket (bucket start = previous bucket's end, or 0
// for bucket 0 — hence the select on hash_code == 0).
// Grid mapping: blockIdx.y = batch, (blockIdx.x, threadIdx.y) = query,
// threadIdx.x = hash function index.
__global__ void extract_query_info_cuda_kernel(
  int *query_mask,       // [batch_size, num_query]
  int *query_hash_code,  // [batch_size, num_query, num_hash_f]
  int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity]
  int *query_info,       // [batch_size, num_query, 2, num_hash_f]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_query
) {

  int batch_idx = blockIdx.y;
  int query_idx = blockIdx.x * blockDim.y + threadIdx.y;
  int hash_f_idx = threadIdx.x;

  int batch_idx__query_idx = batch_idx * num_query + query_idx;
  // Masked-out (padding) queries get no key range.
  if (query_mask[batch_idx__query_idx] == 0) {
    return;
  }

  int hash_code = query_hash_code[batch_idx__query_idx * num_hash_f + hash_f_idx];
  int batch_idx__hash_f_idx__hash_code = (batch_idx * num_hash_f + hash_f_idx) * hashtable_capacity + hash_code;

  // Bucket start: 0 for the first bucket, otherwise the previous bucket's end.
  int key_offset = select(hash_code == 0, 0, count_sort_table[batch_idx__hash_f_idx__hash_code - 1]);
  int key_count = count_sort_table[batch_idx__hash_f_idx__hash_code] - key_offset;

  // Layout: [..., 0, hash_f_idx] = offset, [..., 1, hash_f_idx] = count.
  query_info[batch_idx__query_idx * 2 * num_hash_f + hash_f_idx] = key_offset;
  query_info[(batch_idx__query_idx * 2 + 1) * num_hash_f + hash_f_idx] = key_count;

}
|
| 443 |
+
|
| 444 |
+
// Weighted LSH cumulation, ver2 step 2 (query-centric sparse form).
// One block handles one (batch, hash function, query) triple: it walks the keys that
// collide with the query in this hash function (range from query_info, indices from
// key_sorted_idxes), computes weight = <query_weight, key_weight> / num_hash_f per
// key via a warp XOR-shuffle reduction, and atomically accumulates weight * value
// into the query's output row. Keys are processed WARP_SIZE at a time; within a
// chunk, warps (threadIdx.y) split the keys and lanes (threadIdx.x) split the dims.
// Dynamic shared memory layout (key_count > 1 path):
//   [0, weight_dim)            float  query weight vector
//   [weight_dim, ...)          int    staged key indices for the current chunk
__global__ void lsh_weighted_cumulation_ver2_step2_cuda_kernel(
  int *query_mask,         // [batch_size, num_query]
  int *query_info,         // [batch_size, num_query, 2, num_hash_f]
  int *key_sorted_idxes,   // [batch_size, num_hash_f, num_key]
  float *query_weight,     // [batch_size, num_query, weight_dim]
  float *key_weight,       // [batch_size, num_key, weight_dim]
  float *value,            // [batch_size, num_key, value_dim]
  float *cumulation_value, // [batch_size, num_query, value_dim]
  int batch_size,
  int num_hash_f,
  int num_query,
  int num_key,
  int value_dim,
  int weight_dim
) {

  int batch_idx = blockIdx.z;
  int hash_f_idx = blockIdx.y;
  int query_idx = blockIdx.x;

  int num_threads = blockDim.y * blockDim.x;
  int thread_id = threadIdx.y * blockDim.x + threadIdx.x;

  int num_warps = blockDim.y;
  int warp_idx = threadIdx.y;
  int warp_thread_idx = threadIdx.x;

  int batch_idx__query_idx = batch_idx * num_query + query_idx;
  if (query_mask[batch_idx__query_idx] == 0) {
    return;
  }

  // Range of colliding keys for this (query, hash function), from extract_query_info.
  int key_offset = query_info[batch_idx__query_idx * 2 * num_hash_f + hash_f_idx];
  int key_count = query_info[(batch_idx__query_idx * 2 + 1) * num_hash_f + hash_f_idx];

  if (key_count == 0) {
    return;
  }

  extern __shared__ float buffer[];

  if (key_count == 1) {
    // Single collision: warp 0 does the whole dot product and update directly,
    // skipping the shared-memory staging.
    if (warp_idx == 0) {
      int key_idx = key_sorted_idxes[(batch_idx * num_hash_f + hash_f_idx) * num_key + key_offset];
      int batch_idx__key_idx = batch_idx * num_key + key_idx;
      float weight = 0;
      for (int weight_offset = 0; weight_offset < weight_dim; weight_offset = weight_offset + WARP_SIZE) {
        int weight_dim_idx = weight_offset + warp_thread_idx;
        float val = query_weight[batch_idx__query_idx * weight_dim + weight_dim_idx] * key_weight[batch_idx__key_idx * weight_dim + weight_dim_idx];
        // Butterfly (XOR) reduction: every lane ends with the full warp sum.
        #pragma unroll
        for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) {
          val += __shfl_xor_sync(FULL_MASK, val, offset);
        }
        weight = weight + val;
      }
      weight = weight / float(num_hash_f);
      for (int value_offset = 0; value_offset < value_dim; value_offset = value_offset + WARP_SIZE) {
        int value_dim_idx = value_offset + warp_thread_idx;
        float val = value[batch_idx__key_idx * value_dim + value_dim_idx];
        atomicAdd(&cumulation_value[batch_idx__query_idx * value_dim + value_dim_idx], weight * val);
      }
    }
  } else {
    float *weight_buffer = buffer;
    int *key_idxes_buffer = (int*)&buffer[weight_dim];

    // Stage the query's weight vector once; it is reused for every colliding key.
    copy_data_nonblocking<float>(&query_weight[batch_idx__query_idx * weight_dim], weight_buffer, weight_dim, num_threads, thread_id);

    while (key_count > 0) {
      int work_size = min(WARP_SIZE, key_count);
      copy_data_nonblocking<int>(&key_sorted_idxes[(batch_idx * num_hash_f + hash_f_idx) * num_key + key_offset], key_idxes_buffer, work_size, num_threads, thread_id);
      __syncthreads();
      // Warps stride over the up-to-WARP_SIZE keys in this chunk.
      for (int work_offset = 0; work_offset < WARP_SIZE; work_offset = work_offset + num_warps) {
        int work_idx = work_offset + warp_idx;
        if (work_idx < key_count) {
          int key_idx = key_idxes_buffer[work_idx];
          int batch_idx__key_idx = batch_idx * num_key + key_idx;
          float weight = 0;
          for (int weight_offset = 0; weight_offset < weight_dim; weight_offset = weight_offset + WARP_SIZE) {
            int weight_dim_idx = weight_offset + warp_thread_idx;
            float val = weight_buffer[weight_dim_idx] * key_weight[batch_idx__key_idx * weight_dim + weight_dim_idx];
            #pragma unroll
            for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) {
              val += __shfl_xor_sync(FULL_MASK, val, offset);
            }
            weight = weight + val;
          }
          weight = weight / float(num_hash_f);
          for (int value_offset = 0; value_offset < value_dim; value_offset = value_offset + WARP_SIZE) {
            int value_dim_idx = value_offset + warp_thread_idx;
            float val = value[batch_idx__key_idx * value_dim + value_dim_idx];
            atomicAdd(&cumulation_value[batch_idx__query_idx * value_dim + value_dim_idx], weight * val);
          }
        }
      }
      key_count = key_count - work_size;
      key_offset = key_offset + work_size;
    }
  }

}
|
| 545 |
+
|
| 546 |
+
// Weighted LSH cumulation, ver3 step 2 (key-centric sparse form).
// Mirror image of ver2: one block handles one (batch, hash function, key) triple and
// walks the queries colliding with the key (range from key_info, indices from
// query_sorted_idxes). Per query it computes <key_weight, query_weight> / num_hash_f
// via a warp XOR-shuffle reduction, then atomically adds weight * value[key] into
// that query's output row. The key's weight AND value vectors are staged in shared
// memory since they are reused for every colliding query.
// Dynamic shared memory layout (query_count > 1 path):
//   [0, weight_dim)                      float  key weight vector
//   [weight_dim, weight_dim+value_dim)   float  key value vector
//   [weight_dim+value_dim, ...)          int    staged query indices per chunk
__global__ void lsh_weighted_cumulation_ver3_step2_cuda_kernel(
  int *query_sorted_idxes, // [batch_size, num_hash_f, num_query]
  int *key_mask,           // [batch_size, num_key]
  int *key_info,           // [batch_size, num_key, 2, num_hash_f]
  float *query_weight,     // [batch_size, num_query, weight_dim]
  float *key_weight,       // [batch_size, num_key, weight_dim]
  float *value,            // [batch_size, num_key, value_dim]
  float *cumulation_value, // [batch_size, num_query, value_dim]
  int batch_size,
  int num_hash_f,
  int num_query,
  int num_key,
  int value_dim,
  int weight_dim
) {

  int batch_idx = blockIdx.z;
  int hash_f_idx = blockIdx.y;
  int key_idx = blockIdx.x;

  int num_threads = blockDim.y * blockDim.x;
  int thread_id = threadIdx.y * blockDim.x + threadIdx.x;

  int num_warps = blockDim.y;
  int warp_idx = threadIdx.y;
  int warp_thread_idx = threadIdx.x;

  int batch_idx__key_idx = batch_idx * num_key + key_idx;
  if (key_mask[batch_idx__key_idx] == 0) {
    return;
  }

  // Range of colliding queries for this (key, hash function).
  int query_offset = key_info[batch_idx__key_idx * 2 * num_hash_f + hash_f_idx];
  int query_count = key_info[(batch_idx__key_idx * 2 + 1) * num_hash_f + hash_f_idx];

  if (query_count == 0) {
    return;
  }

  extern __shared__ float buffer[];

  if (query_count == 1) {
    // Single collision: warp 0 handles it directly without shared-memory staging.
    if (warp_idx == 0) {
      int query_idx = query_sorted_idxes[(batch_idx * num_hash_f + hash_f_idx) * num_query + query_offset];
      int batch_idx__query_idx = batch_idx * num_query + query_idx;
      float weight = 0;
      for (int weight_offset = 0; weight_offset < weight_dim; weight_offset = weight_offset + WARP_SIZE) {
        int weight_dim_idx = weight_offset + warp_thread_idx;
        float val = key_weight[batch_idx__key_idx * weight_dim + weight_dim_idx] * query_weight[batch_idx__query_idx * weight_dim + weight_dim_idx];
        // Butterfly (XOR) reduction: every lane ends with the full warp sum.
        #pragma unroll
        for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) {
          val += __shfl_xor_sync(FULL_MASK, val, offset);
        }
        weight = weight + val;
      }
      weight = weight / float(num_hash_f);
      for (int value_offset = 0; value_offset < value_dim; value_offset = value_offset + WARP_SIZE) {
        int value_dim_idx = value_offset + warp_thread_idx;
        float val = value[batch_idx__key_idx * value_dim + value_dim_idx];
        atomicAdd(&cumulation_value[batch_idx__query_idx * value_dim + value_dim_idx], weight * val);
      }
    }
  } else {
    float *weight_buffer = buffer;
    float *value_buffer = &buffer[weight_dim];
    int *query_idxes_buffer = (int*)&buffer[weight_dim + value_dim];

    // Stage the key's weight and value vectors once for all colliding queries.
    copy_data_nonblocking<float>(&key_weight[batch_idx__key_idx * weight_dim], weight_buffer, weight_dim, num_threads, thread_id);
    copy_data_nonblocking<float>(&value[batch_idx__key_idx * value_dim], value_buffer, value_dim, num_threads, thread_id);

    while (query_count > 0) {
      int work_size = min(WARP_SIZE, query_count);
      copy_data_nonblocking<int>(&query_sorted_idxes[(batch_idx * num_hash_f + hash_f_idx) * num_query + query_offset], query_idxes_buffer, work_size, num_threads, thread_id);
      __syncthreads();
      // Warps stride over the up-to-WARP_SIZE queries in this chunk.
      for (int work_offset = 0; work_offset < WARP_SIZE; work_offset = work_offset + num_warps) {
        int work_idx = work_offset + warp_idx;
        if (work_idx < query_count) {
          int query_idx = query_idxes_buffer[work_idx];
          int batch_idx__query_idx = batch_idx * num_query + query_idx;
          float weight = 0;
          for (int weight_offset = 0; weight_offset < weight_dim; weight_offset = weight_offset + WARP_SIZE) {
            int weight_dim_idx = weight_offset + warp_thread_idx;
            float val = weight_buffer[weight_dim_idx] * query_weight[batch_idx__query_idx * weight_dim + weight_dim_idx];
            #pragma unroll
            for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) {
              val += __shfl_xor_sync(FULL_MASK, val, offset);
            }
            weight = weight + val;
          }
          weight = weight / float(num_hash_f);
          for (int value_offset = 0; value_offset < value_dim; value_offset = value_offset + WARP_SIZE) {
            int value_dim_idx = value_offset + warp_thread_idx;
            float val = value_buffer[value_dim_idx];
            atomicAdd(&cumulation_value[batch_idx__query_idx * value_dim + value_dim_idx], weight * val);
          }
        }
      }
      query_count = query_count - work_size;
      query_offset = query_offset + work_size;
    }
  }

}
|
| 649 |
+
|
| 650 |
+
// Weighted LSH cumulation, ver4 step 2 (key-centric with on-chip query dedup).
// One block handles one (batch, key) pair across ALL hash functions. Unlike ver3,
// a query colliding with the key under several hash functions is processed ONCE:
// colliding query ids are inserted into a shared-memory open-addressed set
// (hashtable_query via set_insert), the number of times each distinct query was
// seen is tallied in hashtable_count, and in the accumulation phase the dot-product
// weight is multiplied by that duplicate_count before the atomic update. The outer
// while(true) drains queries in rounds: each round fills the set until the
// inserted-query list is nearly full, flushes accumulations, resets the set, and
// resumes; it terminates when a round discovers no new query.
// Dynamic shared memory layout:
//   [0, weight_dim)                       float  key weight vector
//   [weight_dim, weight_dim+value_dim)    float  key value vector
//   [weight_dim+value_dim, ...)           int    key_info copy: num_hash_f offsets
//                                                then num_hash_f counts (mutated in
//                                                place as work is consumed)
__global__ void lsh_weighted_cumulation_ver4_step2_cuda_kernel(
  int *query_sorted_idxes, // [batch_size, num_hash_f, num_query]
  int *key_mask,           // [batch_size, num_key]
  int *key_info,           // [batch_size, num_key, 2, num_hash_f]
  float *query_weight,     // [batch_size, num_query, weight_dim]
  float *key_weight,       // [batch_size, num_key, weight_dim]
  float *value,            // [batch_size, num_key, value_dim]
  float *cumulation_value, // [batch_size, num_query, value_dim]
  int batch_size,
  int num_hash_f,
  int num_query,
  int num_key,
  int value_dim,
  int weight_dim
) {

  int batch_idx = blockIdx.y;
  int key_idx = blockIdx.x;

  int num_threads = blockDim.y * blockDim.x;
  int thread_id = threadIdx.y * blockDim.x + threadIdx.x;

  int num_warps = blockDim.y;
  int warp_idx = threadIdx.y;
  int warp_thread_idx = threadIdx.x;

  int batch_idx__key_idx = batch_idx * num_key + key_idx;
  if (key_mask[batch_idx__key_idx] == 0) {
    return;
  }

  extern __shared__ float buffer[];
  float *weight_buffer = buffer;
  float *value_buffer = &buffer[weight_dim];
  int *key_info_buffer = (int*)&buffer[weight_dim + value_dim];

  // Stage the key's weight/value vectors and its per-hash-function (offset, count)
  // ranges; the ranges are consumed destructively below.
  copy_data_nonblocking<float>(&key_weight[batch_idx__key_idx * weight_dim], weight_buffer, weight_dim, num_threads, thread_id);
  copy_data_nonblocking<float>(&value[batch_idx__key_idx * value_dim], value_buffer, value_dim, num_threads, thread_id);
  copy_data_nonblocking<int>(&key_info[batch_idx__key_idx * 2 * num_hash_f], key_info_buffer, 2 * num_hash_f, num_threads, thread_id);

  int *query_offset_buffer = key_info_buffer;
  int *query_count_buffer = &key_info_buffer[num_hash_f];

  // Static shared memory for the dedup set; sized so a full round of inserts
  // (one block's worth of lanes) cannot overflow after the near-full check.
  const int hashtable_size = 1024 + OPTIMAL_THREADS_PER_BLOCK;
  __shared__ int hashtable_query[hashtable_size];
  __shared__ int hashtable_count[hashtable_size];
  __shared__ int inserted_query[hashtable_size];
  __shared__ int query_counter[1];

  int hash_f_idx_base = 0;

  while (true) {

    // Reset the dedup set and the distinct-query list for this round.
    init_buffer_nonblocking<int>(EMPTY_VALUE, hashtable_query, hashtable_size, num_threads, thread_id);
    init_buffer_nonblocking<int>(0, hashtable_count, hashtable_size, num_threads, thread_id);
    init_buffer_nonblocking<int>(EMPTY_VALUE, inserted_query, hashtable_size, num_threads, thread_id);
    init_buffer_nonblocking<int>(0, query_counter, 1, num_threads, thread_id);
    __syncthreads();

    // Phase 1: each warp drains one hash function's query list at a time,
    // inserting query ids into the dedup set.
    while (hash_f_idx_base < num_hash_f) {

      int hash_f_idx = hash_f_idx_base + warp_idx;
      int batch_idx__hash_f_idx = batch_idx * num_hash_f + hash_f_idx;

      int stop_flag = 0;

      int query_offset = query_offset_buffer[hash_f_idx];
      int query_count = query_count_buffer[hash_f_idx];

      while (query_count > 0) {

        int work_size = min(query_count, WARP_SIZE);

        // try inserting query to set and check whether the query is new
        int found_new_query = 0;
        int query_idx = -1;
        if (warp_thread_idx < work_size) {
          query_idx = query_sorted_idxes[batch_idx__hash_f_idx * num_query + query_offset + warp_thread_idx];
          int slot = set_insert<int>(hashtable_query, hashtable_size, query_idx);
          if (slot >= 0) {
            // First increment (returns 0) marks this lane as the query's discoverer.
            found_new_query = atomicAdd(&hashtable_count[slot], 1) == 0;
          }
        }

        // compute cumulative offset
        // (warp-level inclusive scan of the new-query flags via __shfl_up_sync)
        int position_offset = found_new_query;
        int next_position_offset = 0;
        #pragma unroll
        for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) {
          next_position_offset = __shfl_up_sync(FULL_MASK, position_offset, offset);
          if (thread_id % WARP_SIZE >= offset) {
            position_offset = position_offset + next_position_offset;
          }
        }

        // get the inserted query list end index
        // (last lane holds the warp total; reserve that many list slots at once)
        int inserted_query_base = 0;
        if (thread_id % WARP_SIZE == WARP_SIZE - 1) {
          inserted_query_base = atomicAdd(query_counter, position_offset);
        }
        inserted_query_base = __shfl_sync(FULL_MASK, inserted_query_base, WARP_SIZE - 1);

        // insert new queries to list
        int insert_idx = inserted_query_base + position_offset - 1;
        if (found_new_query) {
          inserted_query[insert_idx] = query_idx;
        }

        // remove inserted queries from list
        query_offset_buffer[hash_f_idx] += work_size;
        query_count_buffer[hash_f_idx] -= work_size;
        query_offset += work_size;
        query_count -= work_size;

        // if list is almost full, stop inserting
        if (inserted_query_base + OPTIMAL_THREADS_PER_BLOCK > hashtable_size) {
          stop_flag = 1;
          break;
        }

      }

      if (stop_flag) {
        break;
      }

      hash_f_idx_base = hash_f_idx_base + num_warps;

    }

    __syncthreads();

    int num_distint_query = query_counter[0];

    if (num_distint_query > 0) {
      // Phase 2: warps stride over the distinct queries collected this round;
      // each computes the dot-product weight once and scales it by how many hash
      // functions matched (duplicate_count).
      for (int idx_base = 0; idx_base < num_distint_query; idx_base = idx_base + num_warps) {
        int idx = idx_base + warp_idx;
        if (idx < num_distint_query) {
          int query_idx = inserted_query[idx];
          int batch_idx__query_idx = batch_idx * num_query + query_idx;

          int slot = set_lookup<int>(hashtable_query, hashtable_size, query_idx);
          int duplicate_count = hashtable_count[slot];

          float weight = 0;
          for (int weight_idx_base = 0; weight_idx_base < weight_dim; weight_idx_base = weight_idx_base + WARP_SIZE) {
            int weight_dim_idx = weight_idx_base + warp_thread_idx;
            float val = weight_buffer[weight_dim_idx] * query_weight[batch_idx__query_idx * weight_dim + weight_dim_idx];
            // Butterfly (XOR) reduction: every lane ends with the full warp sum.
            #pragma unroll
            for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) {
              val += __shfl_xor_sync(FULL_MASK, val, offset);
            }
            weight = weight + val;
          }

          weight = (float)duplicate_count * weight / float(num_hash_f);

          for (int value_idx_base = 0; value_idx_base < value_dim; value_idx_base = value_idx_base + WARP_SIZE) {
            int value_dim_idx = value_idx_base + warp_thread_idx;
            float val = value_buffer[value_dim_idx];
            atomicAdd(&cumulation_value[batch_idx__query_idx * value_dim + value_dim_idx], weight * val);
          }
        }
      }
    } else {

      // all computation is completed if num_distint_query == 0
      break;

    }

    __syncthreads();

  }

}
|
.venv/lib/python3.11/site-packages/transformers/kernels/yoso/fast_lsh_cumulation_cuda.h
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Declarations for the YOSO fast-LSH cumulation CUDA kernels (definitions live
// in fast_lsh_cumulation_cuda.cu). Tensor layouts are documented per argument.
// NOTE(review): as __global__ kernels, pointer arguments are presumably device
// buffers — confirm against the launch sites in the .cu file.

// Computes LSH hash codes for each unmasked vector.
__global__ void fast_hash_ver1_cuda_kernel(
  int *mask,            // [batch_size, num_vector]
  float *vector,        // [batch_size, num_vector, vector_dim]
  int *Dmat,            // [3, num_part, vector_dim]
  int *hash_code,       // [batch_size, num_vector, num_hash_f]
  int batch_size,
  int num_vector,
  int vector_dim,
  int num_part,
  int num_hash_f,
  int hash_code_len
);

// Unweighted cumulation, pass 1: scatter key values into the hash tables.
__global__ void lsh_cumulation_ver1_step1_cuda_kernel(
  int *key_mask,           // [batch_size, num_key]
  int *key_hash_code,      // [batch_size, num_key, num_hash_f]
  float *value,            // [batch_size, num_key, value_dim]
  float *hashtable_value,  // [batch_size, num_hash_f, hashtable_capacity, value_dim]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_key,
  int value_dim,
  int offset_warp
);

// Unweighted cumulation, pass 2: gather hash-table values back per query.
__global__ void lsh_cumulation_ver1_step2_cuda_kernel(
  int *query_mask,         // [batch_size, num_query]
  int *query_hash_code,    // [batch_size, num_query, num_hash_f]
  float *hashtable_value,  // [batch_size, num_hash_f, hashtable_capacity, value_dim]
  float *cumulation_value, // [batch_size, num_query, value_dim]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_query,
  int value_dim,
  int offset_warp
);

// Weighted cumulation (ver1), pass 1: scatter weighted key values.
__global__ void lsh_weighted_cumulation_ver1_step1_cuda_kernel(
  int *key_mask,           // [batch_size, num_key]
  int *key_hash_code,      // [batch_size, num_key, num_hash_f]
  float *key_weight,       // [batch_size, num_key, weight_dim]
  float *value,            // [batch_size, num_key, value_dim]
  float *hashtable_value,  // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_key,
  int value_dim,
  int weight_dim,
  int offset_warp,
  int weight_idx
);

// Weighted cumulation (ver1), pass 2: gather per query with query weights.
__global__ void lsh_weighted_cumulation_ver1_step2_cuda_kernel(
  int *query_mask,         // [batch_size, num_query]
  int *query_hash_code,    // [batch_size, num_query, num_hash_f]
  float *query_weight,     // [batch_size, num_query, weight_dim]
  float *hashtable_value,  // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE]
  float *cumulation_value, // [batch_size, num_query, value_dim]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_query,
  int value_dim,
  int weight_dim,
  int offset_warp,
  int weight_idx
);

// Counting-sort pipeline used by ver2/ver3/ver4 — step 1: histogram hash codes.
__global__ void count_sort_step1_cuda_kernel(
  int *key_mask,         // [batch_size, num_key]
  int *key_hash_code,    // [batch_size, num_key, num_hash_f]
  int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_key
);

// Counting-sort step 2: turn the histogram into offsets (prefix sums).
__global__ void count_sort_step2_cuda_kernel(
  int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity
);

// Counting-sort step 3: emit key indexes sorted by hash code.
__global__ void count_sort_step3_cuda_kernel(
  int *key_mask,          // [batch_size, num_key]
  int *key_hash_code,     // [batch_size, num_key, num_hash_f]
  int *count_sort_table,  // [batch_size, num_hash_f, hashtable_capacity]
  int *key_sorted_idxes,  // [batch_size, num_hash_f, num_key]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_key
);

// Collects, per query and hash function, bucket offset/size info for matching keys.
__global__ void extract_query_info_cuda_kernel(
  int *query_mask,        // [batch_size, num_query]
  int *query_hash_code,   // [batch_size, num_query, num_hash_f]
  int *count_sort_table,  // [batch_size, num_hash_f, hashtable_capacity]
  int *query_info,        // [batch_size, num_query, 2, num_hash_f]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_query
);

// Weighted cumulation ver2, pass 2: query-driven accumulation over sorted keys.
__global__ void lsh_weighted_cumulation_ver2_step2_cuda_kernel(
  int *query_mask,         // [batch_size, num_query]
  int *query_info,         // [batch_size, num_query, 2, num_hash_f]
  int *key_sorted_idxes,   // [batch_size, num_hash_f, num_key]
  float *query_weight,     // [batch_size, num_query, weight_dim]
  float *key_weight,       // [batch_size, num_key, weight_dim]
  float *value,            // [batch_size, num_key, value_dim]
  float *cumulation_value, // [batch_size, num_query, value_dim]
  int batch_size,
  int num_hash_f,
  int num_query,
  int num_key,
  int value_dim,
  int weight_dim
);

// Weighted cumulation ver3, pass 2: key-driven variant (queries sorted instead).
__global__ void lsh_weighted_cumulation_ver3_step2_cuda_kernel(
  int *query_sorted_idxes, // [batch_size, num_hash_f, num_query]
  int *key_mask,           // [batch_size, num_key]
  int *key_info,           // [batch_size, num_key, 2, num_hash_f]
  float *query_weight,     // [batch_size, num_query, weight_dim]
  float *key_weight,       // [batch_size, num_key, weight_dim]
  float *value,            // [batch_size, num_key, value_dim]
  float *cumulation_value, // [batch_size, num_query, value_dim]
  int batch_size,
  int num_hash_f,
  int num_query,
  int num_key,
  int value_dim,
  int weight_dim
);

// Weighted cumulation ver4, pass 2: same interface as ver3 (alternate scheduling).
__global__ void lsh_weighted_cumulation_ver4_step2_cuda_kernel(
  int *query_sorted_idxes, // [batch_size, num_hash_f, num_query]
  int *key_mask,           // [batch_size, num_key]
  int *key_info,           // [batch_size, num_key, 2, num_hash_f]
  float *query_weight,     // [batch_size, num_query, weight_dim]
  float *key_weight,       // [batch_size, num_key, weight_dim]
  float *value,            // [batch_size, num_key, value_dim]
  float *cumulation_value, // [batch_size, num_query, value_dim]
  int batch_size,
  int num_hash_f,
  int num_query,
  int num_key,
  int value_dim,
  int weight_dim
);
|
.venv/lib/python3.11/site-packages/transformers/kernels/yoso/fast_lsh_cumulation_torch.cpp
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <torch/extension.h>
|
| 2 |
+
#include <ATen/ATen.h>
|
| 3 |
+
#include "fast_lsh_cumulation.h"
|
| 4 |
+
#include "common_cuda.h"
|
| 5 |
+
#include <vector>
|
| 6 |
+
|
| 7 |
+
// Hash query and key vectors into LSH codes.
// Only one hashing implementation exists today, so `version` is accepted for
// interface symmetry with the cumulation entry points but is not consulted.
std::vector<at::Tensor> fast_hash(
  at::Tensor query_mask,
  at::Tensor query_vector,
  at::Tensor key_mask,
  at::Tensor key_vector,
  int num_hash_f,
  int hash_code_len,
  bool use_cuda,
  int version
) {
  return fast_hash_ver1_kernel(
      query_mask, query_vector, key_mask, key_vector, num_hash_f, hash_code_len, use_cuda);
}
|
| 27 |
+
|
| 28 |
+
// Unweighted LSH cumulation of `value` from keys onto queries.
// `version` is ignored: ver1 is the only unweighted cumulation kernel.
at::Tensor lsh_cumulation(
  at::Tensor query_mask,       // [batch_size, num_query]
  at::Tensor query_hash_code,  // [batch_size, num_query, num_hash_f]
  at::Tensor key_mask,         // [batch_size, num_key]
  at::Tensor key_hash_code,    // [batch_size, num_key, num_hash_f]
  at::Tensor value,            // [batch_size, num_key, value_dim]
  int hashtable_capacity,
  bool use_cuda,
  int version
) {
  return lsh_cumulation_ver1_kernel(
      query_mask, query_hash_code, key_mask, key_hash_code, value, hashtable_capacity, use_cuda);
}
|
| 48 |
+
|
| 49 |
+
// Weighted LSH cumulation: dispatches to one of four kernel implementations.
// Versions 1-4 trade memory layout / scheduling strategies; any unrecognized
// `version` falls back to ver3, exactly as the original if/else chain did.
at::Tensor lsh_weighted_cumulation(
  at::Tensor query_mask,       // [batch_size, num_query]
  at::Tensor query_hash_code,  // [batch_size, num_query, num_hash_f]
  at::Tensor query_weight,     // [batch_size, num_query, weight_dim]
  at::Tensor key_mask,         // [batch_size, num_key]
  at::Tensor key_hash_code,    // [batch_size, num_key, num_hash_f]
  at::Tensor key_weight,       // [batch_size, num_key, weight_dim]
  at::Tensor value,            // [batch_size, num_key, value_dim]
  int hashtable_capacity,
  bool use_cuda,
  int version
) {
  switch (version) {
    case 1:
      return lsh_weighted_cumulation_ver1_kernel(
          query_mask, query_hash_code, query_weight,
          key_mask, key_hash_code, key_weight,
          value, hashtable_capacity, use_cuda);
    case 2:
      return lsh_weighted_cumulation_ver2_kernel(
          query_mask, query_hash_code, query_weight,
          key_mask, key_hash_code, key_weight,
          value, hashtable_capacity, use_cuda);
    case 4:
      return lsh_weighted_cumulation_ver4_kernel(
          query_mask, query_hash_code, query_weight,
          key_mask, key_hash_code, key_weight,
          value, hashtable_capacity, use_cuda);
    case 3:
    default:
      // ver3 doubles as the fallback for unknown version numbers.
      return lsh_weighted_cumulation_ver3_kernel(
          query_mask, query_hash_code, query_weight,
          key_mask, key_hash_code, key_weight,
          value, hashtable_capacity, use_cuda);
  }
}
|
| 123 |
+
|
| 124 |
+
// Python bindings: exposes the three entry points above as a Torch extension
// module (the module name is injected via TORCH_EXTENSION_NAME at build time).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("fast_hash", &fast_hash, "Fast Hash (CUDA)");
  m.def("lsh_cumulation", &lsh_cumulation, "LSH Cumulation (CUDA)");
  m.def("lsh_weighted_cumulation", &lsh_weighted_cumulation, "LSH Weighted Cumulation (CUDA)");
}
|
.venv/lib/python3.11/site-packages/transformers/modelcard.py
ADDED
|
@@ -0,0 +1,908 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Configuration base class and utilities."""
|
| 16 |
+
|
| 17 |
+
import copy
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
import warnings
|
| 21 |
+
from dataclasses import dataclass
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from typing import Any, Dict, List, Optional, Union
|
| 24 |
+
|
| 25 |
+
import requests
|
| 26 |
+
import yaml
|
| 27 |
+
from huggingface_hub import model_info
|
| 28 |
+
from huggingface_hub.utils import HFValidationError
|
| 29 |
+
|
| 30 |
+
from . import __version__
|
| 31 |
+
from .models.auto.modeling_auto import (
|
| 32 |
+
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
|
| 33 |
+
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
|
| 34 |
+
MODEL_FOR_CTC_MAPPING_NAMES,
|
| 35 |
+
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
| 36 |
+
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES,
|
| 37 |
+
MODEL_FOR_MASKED_LM_MAPPING_NAMES,
|
| 38 |
+
MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES,
|
| 39 |
+
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
|
| 40 |
+
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
|
| 41 |
+
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
|
| 42 |
+
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
|
| 43 |
+
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
|
| 44 |
+
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
|
| 45 |
+
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
| 46 |
+
)
|
| 47 |
+
from .training_args import ParallelMode
|
| 48 |
+
from .utils import (
|
| 49 |
+
MODEL_CARD_NAME,
|
| 50 |
+
cached_file,
|
| 51 |
+
is_datasets_available,
|
| 52 |
+
is_offline_mode,
|
| 53 |
+
is_tf_available,
|
| 54 |
+
is_tokenizers_available,
|
| 55 |
+
is_torch_available,
|
| 56 |
+
logging,
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# Maps a pipeline task tag to the auto-model mapping(s) listing the
# architectures that can serve that task; used when inferring a model's task
# from its architecture name while building a model card.
TASK_MAPPING = {
    "text-generation": MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
    "image-segmentation": MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES,
    "fill-mask": MODEL_FOR_MASKED_LM_MAPPING_NAMES,
    "object-detection": MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES,
    "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
    "text2text-generation": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
    "text-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
    "table-question-answering": MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
    "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
    "audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
    # Speech recognition is served by both CTC and seq2seq architectures, so
    # the two mappings are merged under a single tag.
    "automatic-speech-recognition": {**MODEL_FOR_CTC_MAPPING_NAMES, **MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES},
    "zero-shot-image-classification": MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
}

logger = logging.get_logger(__name__)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class ModelCard:
    r"""
    Structured Model Card class. Store model card as well as methods for loading/downloading/saving model cards.

    Please read the following paper for details and explanation on the sections: "Model Cards for Model Reporting" by
    Margaret Mitchell, Simone Wu, Andrew Zaldivar, Parker Barnes, Lucy Vasserman, Ben Hutchinson, Elena Spitzer,
    Inioluwa Deborah Raji and Timnit Gebru for the proposal behind model cards. Link: https://arxiv.org/abs/1810.03993

    Note: A model card can be loaded and saved to disk.
    """

    def __init__(self, **kwargs):
        """Build a model card from keyword arguments.

        The recommended section names from the model-cards paper are popped into
        dedicated attributes (each defaulting to an empty dict); any remaining
        keyword arguments become plain attributes on the instance.
        """
        warnings.warn(
            "The class `ModelCard` is deprecated and will be removed in version 5 of Transformers", FutureWarning
        )
        # Recommended attributes from https://arxiv.org/abs/1810.03993 (see papers)
        self.model_details = kwargs.pop("model_details", {})
        self.intended_use = kwargs.pop("intended_use", {})
        self.factors = kwargs.pop("factors", {})
        self.metrics = kwargs.pop("metrics", {})
        self.evaluation_data = kwargs.pop("evaluation_data", {})
        self.training_data = kwargs.pop("training_data", {})
        self.quantitative_analyses = kwargs.pop("quantitative_analyses", {})
        self.ethical_considerations = kwargs.pop("ethical_considerations", {})
        self.caveats_and_recommendations = kwargs.pop("caveats_and_recommendations", {})

        # Open additional attributes
        for key, value in kwargs.items():
            try:
                setattr(self, key, value)
            except AttributeError as err:
                logger.error(f"Can't set {key} with value {value} for {self}")
                raise err

    def save_pretrained(self, save_directory_or_file):
        """Save a model card object to the directory or file `save_directory_or_file`."""
        if os.path.isdir(save_directory_or_file):
            # If we save using the predefined names, we can load using `from_pretrained`
            output_model_card_file = os.path.join(save_directory_or_file, MODEL_CARD_NAME)
        else:
            output_model_card_file = save_directory_or_file

        self.to_json_file(output_model_card_file)
        logger.info(f"Model card saved in {output_model_card_file}")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        r"""
        Instantiate a [`ModelCard`] from a pre-trained model model card.

        Parameters:
            pretrained_model_name_or_path: either:

                - a string, the *model id* of a pretrained model card hosted inside a model repo on huggingface.co.
                - a path to a *directory* containing a model card file saved using the [`~ModelCard.save_pretrained`]
                  method, e.g.: `./my_model_directory/`.
                - a path or url to a saved model card JSON *file*, e.g.: `./my_model_directory/modelcard.json`.

            cache_dir: (*optional*) string:
                Path to a directory in which a downloaded pre-trained model card should be cached if the standard cache
                should not be used.

            kwargs: (*optional*) dict: key/value pairs with which to update the ModelCard object after loading.

                - The values in kwargs of any keys which are model card attributes will be used to override the loaded
                  values.
                - Behavior concerning key/value pairs whose keys are *not* model card attributes is controlled by the
                  *return_unused_kwargs* keyword parameter.

            proxies: (*optional*) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.

            return_unused_kwargs: (*optional*) bool:

                - If False, then this function returns just the final model card object.
                - If True, then this functions returns a tuple *(model card, unused_kwargs)* where *unused_kwargs* is a
                  dictionary consisting of the key/value pairs whose keys are not model card attributes: ie the part of
                  kwargs which has not been used to update *ModelCard* and is otherwise ignored.

        Examples:

        ```python
        # Download model card from huggingface.co and cache.
        modelcard = ModelCard.from_pretrained("google-bert/bert-base-uncased")
        # Model card was saved using *save_pretrained('./test/saved_model/')*
        modelcard = ModelCard.from_pretrained("./test/saved_model/")
        modelcard = ModelCard.from_pretrained("./test/saved_model/modelcard.json")
        modelcard = ModelCard.from_pretrained("google-bert/bert-base-uncased", output_attentions=True, foo=False)
        ```"""
        cache_dir = kwargs.pop("cache_dir", None)
        proxies = kwargs.pop("proxies", None)
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
        from_pipeline = kwargs.pop("_from_pipeline", None)

        user_agent = {"file_type": "model_card"}
        if from_pipeline is not None:
            user_agent["using_pipeline"] = from_pipeline

        is_local = os.path.isdir(pretrained_model_name_or_path)
        if os.path.isfile(pretrained_model_name_or_path):
            resolved_model_card_file = pretrained_model_name_or_path
            is_local = True
        else:
            try:
                # Load from URL or cache if already cached
                resolved_model_card_file = cached_file(
                    pretrained_model_name_or_path,
                    filename=MODEL_CARD_NAME,
                    cache_dir=cache_dir,
                    proxies=proxies,
                    user_agent=user_agent,
                )
                if is_local:
                    logger.info(f"loading model card file {resolved_model_card_file}")
                else:
                    logger.info(f"loading model card file {MODEL_CARD_NAME} from cache at {resolved_model_card_file}")
                # Load model card
                modelcard = cls.from_json_file(resolved_model_card_file)

            except (EnvironmentError, json.JSONDecodeError):
                # We fall back on creating an empty model card
                modelcard = cls()

        # Update model card with kwargs if needed
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(modelcard, key):
                setattr(modelcard, key, value)
                to_remove.append(key)
        for key in to_remove:
            kwargs.pop(key, None)

        logger.info(f"Model card: {modelcard}")
        if return_unused_kwargs:
            return modelcard, kwargs
        else:
            return modelcard

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a `ModelCard` from a Python dictionary of parameters."""
        return cls(**json_object)

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `ModelCard` from a json file of parameters."""
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        dict_obj = json.loads(text)
        return cls(**dict_obj)

    def __eq__(self, other):
        # Fix: comparing against a non-ModelCard (e.g. `card == 3`) used to
        # raise AttributeError when `other` had no `__dict__`. Returning
        # NotImplemented lets Python fall back to the reflected comparison
        # and ultimately to identity, yielding False.
        if not isinstance(other, ModelCard):
            return NotImplemented
        return self.__dict__ == other.__dict__

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path):
        """Save this instance to a json file."""
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string())
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
# HTML comment prepended to model cards auto-generated by the PyTorch Trainer.
AUTOGENERATED_TRAINER_COMMENT = """
<!-- This model card has been generated automatically according to the information the Trainer had access to. You
should probably proofread and complete it, then remove this comment. -->
"""

# HTML comment prepended to model cards auto-generated from Keras training.
AUTOGENERATED_KERAS_COMMENT = """
<!-- This model card has been generated automatically according to the information Keras had access to. You should
probably proofread and complete it, then remove this comment. -->
"""


# Human-readable display names for task tags, used in generated card headings.
TASK_TAG_TO_NAME_MAPPING = {
    "fill-mask": "Masked Language Modeling",
    "image-classification": "Image Classification",
    "image-segmentation": "Image Segmentation",
    "multiple-choice": "Multiple Choice",
    "object-detection": "Object Detection",
    "question-answering": "Question Answering",
    "summarization": "Summarization",
    "table-question-answering": "Table Question Answering",
    "text-classification": "Text Classification",
    "text-generation": "Causal Language Modeling",
    "text2text-generation": "Sequence-to-sequence Language Modeling",
    "token-classification": "Token Classification",
    "translation": "Translation",
    "zero-shot-classification": "Zero Shot Classification",
    "automatic-speech-recognition": "Automatic Speech Recognition",
    "audio-classification": "Audio Classification",
}


# Metric tags recognized when inferring card metadata from evaluation results
# (see `infer_metric_tags_from_eval_results`).
METRIC_TAGS = [
    "accuracy",
    "bleu",
    "f1",
    "matthews_correlation",
    "pearsonr",
    "precision",
    "recall",
    "rouge",
    "sacrebleu",
    "spearmanr",
    "wer",
]
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def _listify(obj):
|
| 299 |
+
if obj is None:
|
| 300 |
+
return []
|
| 301 |
+
elif isinstance(obj, str):
|
| 302 |
+
return [obj]
|
| 303 |
+
else:
|
| 304 |
+
return obj
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def _insert_values_as_list(metadata, name, values):
|
| 308 |
+
if values is None:
|
| 309 |
+
return metadata
|
| 310 |
+
if isinstance(values, str):
|
| 311 |
+
values = [values]
|
| 312 |
+
values = [v for v in values if v is not None]
|
| 313 |
+
if len(values) == 0:
|
| 314 |
+
return metadata
|
| 315 |
+
metadata[name] = values
|
| 316 |
+
return metadata
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def infer_metric_tags_from_eval_results(eval_results):
|
| 320 |
+
if eval_results is None:
|
| 321 |
+
return {}
|
| 322 |
+
result = {}
|
| 323 |
+
for key in eval_results.keys():
|
| 324 |
+
if key.lower().replace(" ", "_") in METRIC_TAGS:
|
| 325 |
+
result[key.lower().replace(" ", "_")] = key
|
| 326 |
+
elif key.lower() == "rouge1":
|
| 327 |
+
result["rouge"] = key
|
| 328 |
+
return result
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def _insert_value(metadata, name, value):
|
| 332 |
+
if value is None:
|
| 333 |
+
return metadata
|
| 334 |
+
metadata[name] = value
|
| 335 |
+
return metadata
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def is_hf_dataset(dataset):
|
| 339 |
+
if not is_datasets_available():
|
| 340 |
+
return False
|
| 341 |
+
|
| 342 |
+
from datasets import Dataset, IterableDataset
|
| 343 |
+
|
| 344 |
+
return isinstance(dataset, (Dataset, IterableDataset))
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
def _get_mapping_values(mapping):
|
| 348 |
+
result = []
|
| 349 |
+
for v in mapping.values():
|
| 350 |
+
if isinstance(v, (tuple, list)):
|
| 351 |
+
result += list(v)
|
| 352 |
+
else:
|
| 353 |
+
result.append(v)
|
| 354 |
+
return result
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
@dataclass
|
| 358 |
+
class TrainingSummary:
|
| 359 |
+
model_name: str
|
| 360 |
+
language: Optional[Union[str, List[str]]] = None
|
| 361 |
+
license: Optional[str] = None
|
| 362 |
+
tags: Optional[Union[str, List[str]]] = None
|
| 363 |
+
finetuned_from: Optional[str] = None
|
| 364 |
+
tasks: Optional[Union[str, List[str]]] = None
|
| 365 |
+
dataset: Optional[Union[str, List[str]]] = None
|
| 366 |
+
dataset_tags: Optional[Union[str, List[str]]] = None
|
| 367 |
+
dataset_args: Optional[Union[str, List[str]]] = None
|
| 368 |
+
dataset_metadata: Optional[Dict[str, Any]] = None
|
| 369 |
+
eval_results: Optional[Dict[str, float]] = None
|
| 370 |
+
eval_lines: Optional[List[str]] = None
|
| 371 |
+
hyperparameters: Optional[Dict[str, Any]] = None
|
| 372 |
+
source: Optional[str] = "trainer"
|
| 373 |
+
|
| 374 |
+
def __post_init__(self):
|
| 375 |
+
# Infer default license from the checkpoint used, if possible.
|
| 376 |
+
if (
|
| 377 |
+
self.license is None
|
| 378 |
+
and not is_offline_mode()
|
| 379 |
+
and self.finetuned_from is not None
|
| 380 |
+
and len(self.finetuned_from) > 0
|
| 381 |
+
):
|
| 382 |
+
try:
|
| 383 |
+
info = model_info(self.finetuned_from)
|
| 384 |
+
for tag in info.tags:
|
| 385 |
+
if tag.startswith("license:"):
|
| 386 |
+
self.license = tag[8:]
|
| 387 |
+
except (requests.exceptions.HTTPError, requests.exceptions.ConnectionError, HFValidationError):
|
| 388 |
+
pass
|
| 389 |
+
|
| 390 |
+
def create_model_index(self, metric_mapping):
|
| 391 |
+
model_index = {"name": self.model_name}
|
| 392 |
+
|
| 393 |
+
# Dataset mapping tag -> name
|
| 394 |
+
dataset_names = _listify(self.dataset)
|
| 395 |
+
dataset_tags = _listify(self.dataset_tags)
|
| 396 |
+
dataset_args = _listify(self.dataset_args)
|
| 397 |
+
dataset_metadata = _listify(self.dataset_metadata)
|
| 398 |
+
if len(dataset_args) < len(dataset_tags):
|
| 399 |
+
dataset_args = dataset_args + [None] * (len(dataset_tags) - len(dataset_args))
|
| 400 |
+
dataset_mapping = dict(zip(dataset_tags, dataset_names))
|
| 401 |
+
dataset_arg_mapping = dict(zip(dataset_tags, dataset_args))
|
| 402 |
+
dataset_metadata_mapping = dict(zip(dataset_tags, dataset_metadata))
|
| 403 |
+
|
| 404 |
+
task_mapping = {
|
| 405 |
+
task: TASK_TAG_TO_NAME_MAPPING[task] for task in _listify(self.tasks) if task in TASK_TAG_TO_NAME_MAPPING
|
| 406 |
+
}
|
| 407 |
+
|
| 408 |
+
model_index["results"] = []
|
| 409 |
+
|
| 410 |
+
if len(task_mapping) == 0 and len(dataset_mapping) == 0:
|
| 411 |
+
return [model_index]
|
| 412 |
+
if len(task_mapping) == 0:
|
| 413 |
+
task_mapping = {None: None}
|
| 414 |
+
if len(dataset_mapping) == 0:
|
| 415 |
+
dataset_mapping = {None: None}
|
| 416 |
+
|
| 417 |
+
# One entry per dataset and per task
|
| 418 |
+
all_possibilities = [(task_tag, ds_tag) for task_tag in task_mapping for ds_tag in dataset_mapping]
|
| 419 |
+
for task_tag, ds_tag in all_possibilities:
|
| 420 |
+
result = {}
|
| 421 |
+
if task_tag is not None:
|
| 422 |
+
result["task"] = {"name": task_mapping[task_tag], "type": task_tag}
|
| 423 |
+
|
| 424 |
+
if ds_tag is not None:
|
| 425 |
+
metadata = dataset_metadata_mapping.get(ds_tag, {})
|
| 426 |
+
result["dataset"] = {
|
| 427 |
+
"name": dataset_mapping[ds_tag],
|
| 428 |
+
"type": ds_tag,
|
| 429 |
+
**metadata,
|
| 430 |
+
}
|
| 431 |
+
if dataset_arg_mapping[ds_tag] is not None:
|
| 432 |
+
result["dataset"]["args"] = dataset_arg_mapping[ds_tag]
|
| 433 |
+
|
| 434 |
+
if len(metric_mapping) > 0:
|
| 435 |
+
result["metrics"] = []
|
| 436 |
+
for metric_tag, metric_name in metric_mapping.items():
|
| 437 |
+
result["metrics"].append(
|
| 438 |
+
{
|
| 439 |
+
"name": metric_name,
|
| 440 |
+
"type": metric_tag,
|
| 441 |
+
"value": self.eval_results[metric_name],
|
| 442 |
+
}
|
| 443 |
+
)
|
| 444 |
+
|
| 445 |
+
# Remove partial results to avoid the model card being rejected.
|
| 446 |
+
if "task" in result and "dataset" in result and "metrics" in result:
|
| 447 |
+
model_index["results"].append(result)
|
| 448 |
+
else:
|
| 449 |
+
logger.info(f"Dropping the following result as it does not have all the necessary fields:\n{result}")
|
| 450 |
+
|
| 451 |
+
return [model_index]
|
| 452 |
+
|
| 453 |
+
def create_metadata(self):
|
| 454 |
+
metric_mapping = infer_metric_tags_from_eval_results(self.eval_results)
|
| 455 |
+
|
| 456 |
+
metadata = {}
|
| 457 |
+
metadata = _insert_value(metadata, "library_name", "transformers")
|
| 458 |
+
metadata = _insert_values_as_list(metadata, "language", self.language)
|
| 459 |
+
metadata = _insert_value(metadata, "license", self.license)
|
| 460 |
+
if self.finetuned_from is not None and isinstance(self.finetuned_from, str) and len(self.finetuned_from) > 0:
|
| 461 |
+
metadata = _insert_value(metadata, "base_model", self.finetuned_from)
|
| 462 |
+
metadata = _insert_values_as_list(metadata, "tags", self.tags)
|
| 463 |
+
metadata = _insert_values_as_list(metadata, "datasets", self.dataset_tags)
|
| 464 |
+
metadata = _insert_values_as_list(metadata, "metrics", list(metric_mapping.keys()))
|
| 465 |
+
metadata["model-index"] = self.create_model_index(metric_mapping)
|
| 466 |
+
|
| 467 |
+
return metadata
|
| 468 |
+
|
| 469 |
+
def to_model_card(self):
|
| 470 |
+
model_card = ""
|
| 471 |
+
|
| 472 |
+
metadata = yaml.dump(self.create_metadata(), sort_keys=False)
|
| 473 |
+
if len(metadata) > 0:
|
| 474 |
+
model_card = f"---\n{metadata}---\n"
|
| 475 |
+
|
| 476 |
+
# Now the model card for realsies.
|
| 477 |
+
if self.source == "trainer":
|
| 478 |
+
model_card += AUTOGENERATED_TRAINER_COMMENT
|
| 479 |
+
else:
|
| 480 |
+
model_card += AUTOGENERATED_KERAS_COMMENT
|
| 481 |
+
|
| 482 |
+
model_card += f"\n# {self.model_name}\n\n"
|
| 483 |
+
|
| 484 |
+
if self.finetuned_from is None:
|
| 485 |
+
model_card += "This model was trained from scratch on "
|
| 486 |
+
else:
|
| 487 |
+
model_card += (
|
| 488 |
+
"This model is a fine-tuned version of"
|
| 489 |
+
f" [{self.finetuned_from}](https://huggingface.co/{self.finetuned_from}) on "
|
| 490 |
+
)
|
| 491 |
+
|
| 492 |
+
if self.dataset is None:
|
| 493 |
+
model_card += "an unknown dataset."
|
| 494 |
+
else:
|
| 495 |
+
if isinstance(self.dataset, str):
|
| 496 |
+
model_card += f"the {self.dataset} dataset."
|
| 497 |
+
elif isinstance(self.dataset, (tuple, list)) and len(self.dataset) == 1:
|
| 498 |
+
model_card += f"the {self.dataset[0]} dataset."
|
| 499 |
+
else:
|
| 500 |
+
model_card += (
|
| 501 |
+
", ".join([f"the {ds}" for ds in self.dataset[:-1]]) + f" and the {self.dataset[-1]} datasets."
|
| 502 |
+
)
|
| 503 |
+
|
| 504 |
+
if self.eval_results is not None:
|
| 505 |
+
model_card += "\nIt achieves the following results on the evaluation set:\n"
|
| 506 |
+
model_card += "\n".join([f"- {name}: {_maybe_round(value)}" for name, value in self.eval_results.items()])
|
| 507 |
+
model_card += "\n"
|
| 508 |
+
|
| 509 |
+
model_card += "\n## Model description\n\nMore information needed\n"
|
| 510 |
+
model_card += "\n## Intended uses & limitations\n\nMore information needed\n"
|
| 511 |
+
model_card += "\n## Training and evaluation data\n\nMore information needed\n"
|
| 512 |
+
|
| 513 |
+
model_card += "\n## Training procedure\n"
|
| 514 |
+
model_card += "\n### Training hyperparameters\n"
|
| 515 |
+
if self.hyperparameters is not None:
|
| 516 |
+
model_card += "\nThe following hyperparameters were used during training:\n"
|
| 517 |
+
model_card += "\n".join([f"- {name}: {value}" for name, value in self.hyperparameters.items()])
|
| 518 |
+
model_card += "\n"
|
| 519 |
+
else:
|
| 520 |
+
model_card += "\nMore information needed\n"
|
| 521 |
+
|
| 522 |
+
if self.eval_lines is not None:
|
| 523 |
+
model_card += "\n### Training results\n\n"
|
| 524 |
+
model_card += make_markdown_table(self.eval_lines)
|
| 525 |
+
model_card += "\n"
|
| 526 |
+
|
| 527 |
+
model_card += "\n### Framework versions\n\n"
|
| 528 |
+
model_card += f"- Transformers {__version__}\n"
|
| 529 |
+
|
| 530 |
+
if self.source == "trainer" and is_torch_available():
|
| 531 |
+
import torch
|
| 532 |
+
|
| 533 |
+
model_card += f"- Pytorch {torch.__version__}\n"
|
| 534 |
+
elif self.source == "keras" and is_tf_available():
|
| 535 |
+
import tensorflow as tf
|
| 536 |
+
|
| 537 |
+
model_card += f"- TensorFlow {tf.__version__}\n"
|
| 538 |
+
if is_datasets_available():
|
| 539 |
+
import datasets
|
| 540 |
+
|
| 541 |
+
model_card += f"- Datasets {datasets.__version__}\n"
|
| 542 |
+
if is_tokenizers_available():
|
| 543 |
+
import tokenizers
|
| 544 |
+
|
| 545 |
+
model_card += f"- Tokenizers {tokenizers.__version__}\n"
|
| 546 |
+
|
| 547 |
+
return model_card
|
| 548 |
+
|
| 549 |
+
@classmethod
|
| 550 |
+
def from_trainer(
|
| 551 |
+
cls,
|
| 552 |
+
trainer,
|
| 553 |
+
language=None,
|
| 554 |
+
license=None,
|
| 555 |
+
tags=None,
|
| 556 |
+
model_name=None,
|
| 557 |
+
finetuned_from=None,
|
| 558 |
+
tasks=None,
|
| 559 |
+
dataset_tags=None,
|
| 560 |
+
dataset_metadata=None,
|
| 561 |
+
dataset=None,
|
| 562 |
+
dataset_args=None,
|
| 563 |
+
):
|
| 564 |
+
# Infer default from dataset
|
| 565 |
+
one_dataset = trainer.eval_dataset if trainer.eval_dataset is not None else trainer.train_dataset
|
| 566 |
+
if is_hf_dataset(one_dataset) and (dataset_tags is None or dataset_args is None or dataset_metadata is None):
|
| 567 |
+
default_tag = one_dataset.builder_name
|
| 568 |
+
# Those are not real datasets from the Hub so we exclude them.
|
| 569 |
+
if default_tag not in ["csv", "json", "pandas", "parquet", "text"]:
|
| 570 |
+
if dataset_metadata is None:
|
| 571 |
+
dataset_metadata = [{"config": one_dataset.config_name, "split": str(one_dataset.split)}]
|
| 572 |
+
if dataset_tags is None:
|
| 573 |
+
dataset_tags = [default_tag]
|
| 574 |
+
if dataset_args is None:
|
| 575 |
+
dataset_args = [one_dataset.config_name]
|
| 576 |
+
|
| 577 |
+
if dataset is None and dataset_tags is not None:
|
| 578 |
+
dataset = dataset_tags
|
| 579 |
+
|
| 580 |
+
# Infer default finetuned_from
|
| 581 |
+
if (
|
| 582 |
+
finetuned_from is None
|
| 583 |
+
and hasattr(trainer.model.config, "_name_or_path")
|
| 584 |
+
and not os.path.isdir(trainer.model.config._name_or_path)
|
| 585 |
+
):
|
| 586 |
+
finetuned_from = trainer.model.config._name_or_path
|
| 587 |
+
|
| 588 |
+
# Infer default task tag:
|
| 589 |
+
if tasks is None:
|
| 590 |
+
model_class_name = trainer.model.__class__.__name__
|
| 591 |
+
for task, mapping in TASK_MAPPING.items():
|
| 592 |
+
if model_class_name in _get_mapping_values(mapping):
|
| 593 |
+
tasks = task
|
| 594 |
+
|
| 595 |
+
if model_name is None:
|
| 596 |
+
model_name = Path(trainer.args.output_dir).name
|
| 597 |
+
if len(model_name) == 0:
|
| 598 |
+
model_name = finetuned_from
|
| 599 |
+
|
| 600 |
+
# Add `generated_from_trainer` to the tags
|
| 601 |
+
if tags is None:
|
| 602 |
+
tags = ["generated_from_trainer"]
|
| 603 |
+
elif isinstance(tags, str) and tags != "generated_from_trainer":
|
| 604 |
+
tags = [tags, "generated_from_trainer"]
|
| 605 |
+
elif "generated_from_trainer" not in tags:
|
| 606 |
+
tags.append("generated_from_trainer")
|
| 607 |
+
|
| 608 |
+
_, eval_lines, eval_results = parse_log_history(trainer.state.log_history)
|
| 609 |
+
hyperparameters = extract_hyperparameters_from_trainer(trainer)
|
| 610 |
+
|
| 611 |
+
return cls(
|
| 612 |
+
language=language,
|
| 613 |
+
license=license,
|
| 614 |
+
tags=tags,
|
| 615 |
+
model_name=model_name,
|
| 616 |
+
finetuned_from=finetuned_from,
|
| 617 |
+
tasks=tasks,
|
| 618 |
+
dataset=dataset,
|
| 619 |
+
dataset_tags=dataset_tags,
|
| 620 |
+
dataset_args=dataset_args,
|
| 621 |
+
dataset_metadata=dataset_metadata,
|
| 622 |
+
eval_results=eval_results,
|
| 623 |
+
eval_lines=eval_lines,
|
| 624 |
+
hyperparameters=hyperparameters,
|
| 625 |
+
)
|
| 626 |
+
|
| 627 |
+
@classmethod
|
| 628 |
+
def from_keras(
|
| 629 |
+
cls,
|
| 630 |
+
model,
|
| 631 |
+
model_name,
|
| 632 |
+
keras_history=None,
|
| 633 |
+
language=None,
|
| 634 |
+
license=None,
|
| 635 |
+
tags=None,
|
| 636 |
+
finetuned_from=None,
|
| 637 |
+
tasks=None,
|
| 638 |
+
dataset_tags=None,
|
| 639 |
+
dataset=None,
|
| 640 |
+
dataset_args=None,
|
| 641 |
+
):
|
| 642 |
+
# Infer default from dataset
|
| 643 |
+
if dataset is not None:
|
| 644 |
+
if is_hf_dataset(dataset) and (dataset_tags is None or dataset_args is None):
|
| 645 |
+
default_tag = dataset.builder_name
|
| 646 |
+
# Those are not real datasets from the Hub so we exclude them.
|
| 647 |
+
if default_tag not in ["csv", "json", "pandas", "parquet", "text"]:
|
| 648 |
+
if dataset_tags is None:
|
| 649 |
+
dataset_tags = [default_tag]
|
| 650 |
+
if dataset_args is None:
|
| 651 |
+
dataset_args = [dataset.config_name]
|
| 652 |
+
|
| 653 |
+
if dataset is None and dataset_tags is not None:
|
| 654 |
+
dataset = dataset_tags
|
| 655 |
+
|
| 656 |
+
# Infer default finetuned_from
|
| 657 |
+
if (
|
| 658 |
+
finetuned_from is None
|
| 659 |
+
and hasattr(model.config, "_name_or_path")
|
| 660 |
+
and not os.path.isdir(model.config._name_or_path)
|
| 661 |
+
):
|
| 662 |
+
finetuned_from = model.config._name_or_path
|
| 663 |
+
|
| 664 |
+
# Infer default task tag:
|
| 665 |
+
if tasks is None:
|
| 666 |
+
model_class_name = model.__class__.__name__
|
| 667 |
+
for task, mapping in TASK_MAPPING.items():
|
| 668 |
+
if model_class_name in _get_mapping_values(mapping):
|
| 669 |
+
tasks = task
|
| 670 |
+
|
| 671 |
+
# Add `generated_from_keras_callback` to the tags
|
| 672 |
+
if tags is None:
|
| 673 |
+
tags = ["generated_from_keras_callback"]
|
| 674 |
+
elif isinstance(tags, str) and tags != "generated_from_keras_callback":
|
| 675 |
+
tags = [tags, "generated_from_keras_callback"]
|
| 676 |
+
elif "generated_from_keras_callback" not in tags:
|
| 677 |
+
tags.append("generated_from_keras_callback")
|
| 678 |
+
|
| 679 |
+
if keras_history is not None:
|
| 680 |
+
_, eval_lines, eval_results = parse_keras_history(keras_history)
|
| 681 |
+
else:
|
| 682 |
+
eval_lines = []
|
| 683 |
+
eval_results = {}
|
| 684 |
+
hyperparameters = extract_hyperparameters_from_keras(model)
|
| 685 |
+
|
| 686 |
+
return cls(
|
| 687 |
+
language=language,
|
| 688 |
+
license=license,
|
| 689 |
+
tags=tags,
|
| 690 |
+
model_name=model_name,
|
| 691 |
+
finetuned_from=finetuned_from,
|
| 692 |
+
tasks=tasks,
|
| 693 |
+
dataset_tags=dataset_tags,
|
| 694 |
+
dataset=dataset,
|
| 695 |
+
dataset_args=dataset_args,
|
| 696 |
+
eval_results=eval_results,
|
| 697 |
+
eval_lines=eval_lines,
|
| 698 |
+
hyperparameters=hyperparameters,
|
| 699 |
+
source="keras",
|
| 700 |
+
)
|
| 701 |
+
|
| 702 |
+
|
| 703 |
+
def parse_keras_history(logs):
|
| 704 |
+
"""
|
| 705 |
+
Parse the `logs` of either a `keras.History` object returned by `model.fit()` or an accumulated logs `dict`
|
| 706 |
+
passed to the `PushToHubCallback`. Returns lines and logs compatible with those returned by `parse_log_history`.
|
| 707 |
+
"""
|
| 708 |
+
if hasattr(logs, "history"):
|
| 709 |
+
# This looks like a `History` object
|
| 710 |
+
if not hasattr(logs, "epoch"):
|
| 711 |
+
# This history looks empty, return empty results
|
| 712 |
+
return None, [], {}
|
| 713 |
+
logs.history["epoch"] = logs.epoch
|
| 714 |
+
logs = logs.history
|
| 715 |
+
else:
|
| 716 |
+
# Training logs is a list of dicts, let's invert it to a dict of lists to match a History object
|
| 717 |
+
logs = {log_key: [single_dict[log_key] for single_dict in logs] for log_key in logs[0]}
|
| 718 |
+
|
| 719 |
+
lines = []
|
| 720 |
+
for i in range(len(logs["epoch"])):
|
| 721 |
+
epoch_dict = {log_key: log_value_list[i] for log_key, log_value_list in logs.items()}
|
| 722 |
+
values = {}
|
| 723 |
+
for k, v in epoch_dict.items():
|
| 724 |
+
if k.startswith("val_"):
|
| 725 |
+
k = "validation_" + k[4:]
|
| 726 |
+
elif k != "epoch":
|
| 727 |
+
k = "train_" + k
|
| 728 |
+
splits = k.split("_")
|
| 729 |
+
name = " ".join([part.capitalize() for part in splits])
|
| 730 |
+
values[name] = v
|
| 731 |
+
lines.append(values)
|
| 732 |
+
|
| 733 |
+
eval_results = lines[-1]
|
| 734 |
+
|
| 735 |
+
return logs, lines, eval_results
|
| 736 |
+
|
| 737 |
+
|
| 738 |
+
def parse_log_history(log_history):
|
| 739 |
+
"""
|
| 740 |
+
Parse the `log_history` of a Trainer to get the intermediate and final evaluation results.
|
| 741 |
+
"""
|
| 742 |
+
idx = 0
|
| 743 |
+
while idx < len(log_history) and "train_runtime" not in log_history[idx]:
|
| 744 |
+
idx += 1
|
| 745 |
+
|
| 746 |
+
# If there are no training logs
|
| 747 |
+
if idx == len(log_history):
|
| 748 |
+
idx -= 1
|
| 749 |
+
while idx >= 0 and "eval_loss" not in log_history[idx]:
|
| 750 |
+
idx -= 1
|
| 751 |
+
|
| 752 |
+
if idx >= 0:
|
| 753 |
+
return None, None, log_history[idx]
|
| 754 |
+
else:
|
| 755 |
+
return None, None, None
|
| 756 |
+
|
| 757 |
+
# From now one we can assume we have training logs:
|
| 758 |
+
train_log = log_history[idx]
|
| 759 |
+
lines = []
|
| 760 |
+
training_loss = "No log"
|
| 761 |
+
for i in range(idx):
|
| 762 |
+
if "loss" in log_history[i]:
|
| 763 |
+
training_loss = log_history[i]["loss"]
|
| 764 |
+
if "eval_loss" in log_history[i]:
|
| 765 |
+
metrics = log_history[i].copy()
|
| 766 |
+
_ = metrics.pop("total_flos", None)
|
| 767 |
+
epoch = metrics.pop("epoch", None)
|
| 768 |
+
step = metrics.pop("step", None)
|
| 769 |
+
_ = metrics.pop("eval_runtime", None)
|
| 770 |
+
_ = metrics.pop("eval_samples_per_second", None)
|
| 771 |
+
_ = metrics.pop("eval_steps_per_second", None)
|
| 772 |
+
_ = metrics.pop("eval_jit_compilation_time", None)
|
| 773 |
+
values = {"Training Loss": training_loss, "Epoch": epoch, "Step": step}
|
| 774 |
+
for k, v in metrics.items():
|
| 775 |
+
if k == "eval_loss":
|
| 776 |
+
values["Validation Loss"] = v
|
| 777 |
+
else:
|
| 778 |
+
splits = k.split("_")
|
| 779 |
+
name = " ".join([part.capitalize() for part in splits[1:]])
|
| 780 |
+
values[name] = v
|
| 781 |
+
lines.append(values)
|
| 782 |
+
|
| 783 |
+
idx = len(log_history) - 1
|
| 784 |
+
while idx >= 0 and "eval_loss" not in log_history[idx]:
|
| 785 |
+
idx -= 1
|
| 786 |
+
|
| 787 |
+
if idx > 0:
|
| 788 |
+
eval_results = {}
|
| 789 |
+
for key, value in log_history[idx].items():
|
| 790 |
+
if key.startswith("eval_"):
|
| 791 |
+
key = key[5:]
|
| 792 |
+
if key not in ["runtime", "samples_per_second", "steps_per_second", "epoch", "step"]:
|
| 793 |
+
camel_cased_key = " ".join([part.capitalize() for part in key.split("_")])
|
| 794 |
+
eval_results[camel_cased_key] = value
|
| 795 |
+
return train_log, lines, eval_results
|
| 796 |
+
else:
|
| 797 |
+
return train_log, lines, None
|
| 798 |
+
|
| 799 |
+
|
| 800 |
+
def extract_hyperparameters_from_keras(model):
|
| 801 |
+
from .modeling_tf_utils import keras
|
| 802 |
+
|
| 803 |
+
hyperparameters = {}
|
| 804 |
+
if hasattr(model, "optimizer") and model.optimizer is not None:
|
| 805 |
+
hyperparameters["optimizer"] = model.optimizer.get_config()
|
| 806 |
+
else:
|
| 807 |
+
hyperparameters["optimizer"] = None
|
| 808 |
+
hyperparameters["training_precision"] = keras.mixed_precision.global_policy().name
|
| 809 |
+
|
| 810 |
+
return hyperparameters
|
| 811 |
+
|
| 812 |
+
|
| 813 |
+
def _maybe_round(v, decimals=4):
|
| 814 |
+
if isinstance(v, float) and len(str(v).split(".")) > 1 and len(str(v).split(".")[1]) > decimals:
|
| 815 |
+
return f"{v:.{decimals}f}"
|
| 816 |
+
return str(v)
|
| 817 |
+
|
| 818 |
+
|
| 819 |
+
def _regular_table_line(values, col_widths):
|
| 820 |
+
values_with_space = [f"| {v}" + " " * (w - len(v) + 1) for v, w in zip(values, col_widths)]
|
| 821 |
+
return "".join(values_with_space) + "|\n"
|
| 822 |
+
|
| 823 |
+
|
| 824 |
+
def _second_table_line(col_widths):
|
| 825 |
+
values = ["|:" + "-" * w + ":" for w in col_widths]
|
| 826 |
+
return "".join(values) + "|\n"
|
| 827 |
+
|
| 828 |
+
|
| 829 |
+
def make_markdown_table(lines):
|
| 830 |
+
"""
|
| 831 |
+
Create a nice Markdown table from the results in `lines`.
|
| 832 |
+
"""
|
| 833 |
+
if lines is None or len(lines) == 0:
|
| 834 |
+
return ""
|
| 835 |
+
col_widths = {key: len(str(key)) for key in lines[0].keys()}
|
| 836 |
+
for line in lines:
|
| 837 |
+
for key, value in line.items():
|
| 838 |
+
if col_widths[key] < len(_maybe_round(value)):
|
| 839 |
+
col_widths[key] = len(_maybe_round(value))
|
| 840 |
+
|
| 841 |
+
table = _regular_table_line(list(lines[0].keys()), list(col_widths.values()))
|
| 842 |
+
table += _second_table_line(list(col_widths.values()))
|
| 843 |
+
for line in lines:
|
| 844 |
+
table += _regular_table_line([_maybe_round(v) for v in line.values()], list(col_widths.values()))
|
| 845 |
+
return table
|
| 846 |
+
|
| 847 |
+
|
| 848 |
+
_TRAINING_ARGS_KEYS = [
|
| 849 |
+
"learning_rate",
|
| 850 |
+
"train_batch_size",
|
| 851 |
+
"eval_batch_size",
|
| 852 |
+
"seed",
|
| 853 |
+
]
|
| 854 |
+
|
| 855 |
+
|
| 856 |
+
def extract_hyperparameters_from_trainer(trainer):
|
| 857 |
+
hyperparameters = {k: getattr(trainer.args, k) for k in _TRAINING_ARGS_KEYS}
|
| 858 |
+
|
| 859 |
+
if trainer.args.parallel_mode not in [ParallelMode.NOT_PARALLEL, ParallelMode.NOT_DISTRIBUTED]:
|
| 860 |
+
hyperparameters["distributed_type"] = (
|
| 861 |
+
"multi-GPU" if trainer.args.parallel_mode == ParallelMode.DISTRIBUTED else trainer.args.parallel_mode.value
|
| 862 |
+
)
|
| 863 |
+
if trainer.args.world_size > 1:
|
| 864 |
+
hyperparameters["num_devices"] = trainer.args.world_size
|
| 865 |
+
if trainer.args.gradient_accumulation_steps > 1:
|
| 866 |
+
hyperparameters["gradient_accumulation_steps"] = trainer.args.gradient_accumulation_steps
|
| 867 |
+
|
| 868 |
+
total_train_batch_size = (
|
| 869 |
+
trainer.args.train_batch_size * trainer.args.world_size * trainer.args.gradient_accumulation_steps
|
| 870 |
+
)
|
| 871 |
+
if total_train_batch_size != hyperparameters["train_batch_size"]:
|
| 872 |
+
hyperparameters["total_train_batch_size"] = total_train_batch_size
|
| 873 |
+
total_eval_batch_size = trainer.args.eval_batch_size * trainer.args.world_size
|
| 874 |
+
if total_eval_batch_size != hyperparameters["eval_batch_size"]:
|
| 875 |
+
hyperparameters["total_eval_batch_size"] = total_eval_batch_size
|
| 876 |
+
|
| 877 |
+
if trainer.args.optim:
|
| 878 |
+
optimizer_name = trainer.args.optim
|
| 879 |
+
optimizer_args = trainer.args.optim_args if trainer.args.optim_args else "No additional optimizer arguments"
|
| 880 |
+
|
| 881 |
+
if "adam" in optimizer_name.lower():
|
| 882 |
+
hyperparameters["optimizer"] = (
|
| 883 |
+
f"Use {optimizer_name} with betas=({trainer.args.adam_beta1},{trainer.args.adam_beta2}) and"
|
| 884 |
+
f" epsilon={trainer.args.adam_epsilon} and optimizer_args={optimizer_args}"
|
| 885 |
+
)
|
| 886 |
+
else:
|
| 887 |
+
hyperparameters["optimizer"] = f"Use {optimizer_name} and the args are:\n{optimizer_args}"
|
| 888 |
+
|
| 889 |
+
hyperparameters["lr_scheduler_type"] = trainer.args.lr_scheduler_type.value
|
| 890 |
+
if trainer.args.warmup_ratio != 0.0:
|
| 891 |
+
hyperparameters["lr_scheduler_warmup_ratio"] = trainer.args.warmup_ratio
|
| 892 |
+
if trainer.args.warmup_steps != 0.0:
|
| 893 |
+
hyperparameters["lr_scheduler_warmup_steps"] = trainer.args.warmup_steps
|
| 894 |
+
if trainer.args.max_steps != -1:
|
| 895 |
+
hyperparameters["training_steps"] = trainer.args.max_steps
|
| 896 |
+
else:
|
| 897 |
+
hyperparameters["num_epochs"] = trainer.args.num_train_epochs
|
| 898 |
+
|
| 899 |
+
if trainer.args.fp16:
|
| 900 |
+
if trainer.use_apex:
|
| 901 |
+
hyperparameters["mixed_precision_training"] = f"Apex, opt level {trainer.args.fp16_opt_level}"
|
| 902 |
+
else:
|
| 903 |
+
hyperparameters["mixed_precision_training"] = "Native AMP"
|
| 904 |
+
|
| 905 |
+
if trainer.args.label_smoothing_factor != 0.0:
|
| 906 |
+
hyperparameters["label_smoothing_factor"] = trainer.args.label_smoothing_factor
|
| 907 |
+
|
| 908 |
+
return hyperparameters
|
.venv/lib/python3.11/site-packages/transformers/modeling_attn_mask_utils.py
ADDED
|
@@ -0,0 +1,481 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from dataclasses import dataclass
|
| 15 |
+
from typing import List, Optional, Tuple, Union
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
from .utils.import_utils import is_torchdynamo_compiling
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
|
| 23 |
+
class AttentionMaskConverter:
|
| 24 |
+
"""
|
| 25 |
+
A utility attention mask class that allows one to:
|
| 26 |
+
- Create a causal 4d mask
|
| 27 |
+
- Create a causal 4d mask with slided window
|
| 28 |
+
- Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
|
| 29 |
+
key_value_length) that can be multiplied with attention scores
|
| 30 |
+
|
| 31 |
+
Examples:
|
| 32 |
+
|
| 33 |
+
```python
|
| 34 |
+
>>> import torch
|
| 35 |
+
>>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
| 36 |
+
|
| 37 |
+
>>> converter = AttentionMaskConverter(True)
|
| 38 |
+
>>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32)
|
| 39 |
+
tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
|
| 40 |
+
[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
|
| 41 |
+
[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
|
| 42 |
+
[-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, -3.4028e+38],
|
| 43 |
+
[-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, 0.0000e+00]]]])
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
Parameters:
|
| 47 |
+
is_causal (`bool`):
|
| 48 |
+
Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
|
| 49 |
+
|
| 50 |
+
sliding_window (`int`, *optional*):
|
| 51 |
+
Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
is_causal: bool
|
| 55 |
+
sliding_window: int
|
| 56 |
+
|
| 57 |
+
def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
|
| 58 |
+
self.is_causal = is_causal
|
| 59 |
+
self.sliding_window = sliding_window
|
| 60 |
+
|
| 61 |
+
if self.sliding_window is not None and self.sliding_window <= 0:
|
| 62 |
+
raise ValueError(
|
| 63 |
+
f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
def to_causal_4d(
|
| 67 |
+
self,
|
| 68 |
+
batch_size: int,
|
| 69 |
+
query_length: int,
|
| 70 |
+
key_value_length: int,
|
| 71 |
+
dtype: torch.dtype,
|
| 72 |
+
device: Union[torch.device, "str"] = "cpu",
|
| 73 |
+
) -> Optional[torch.Tensor]:
|
| 74 |
+
"""
|
| 75 |
+
Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
|
| 76 |
+
bias to upper right hand triangular matrix (causal mask).
|
| 77 |
+
"""
|
| 78 |
+
if not self.is_causal:
|
| 79 |
+
raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")
|
| 80 |
+
|
| 81 |
+
# If shape is not cached, create a new causal mask and cache it
|
| 82 |
+
input_shape = (batch_size, query_length)
|
| 83 |
+
past_key_values_length = key_value_length - query_length
|
| 84 |
+
|
| 85 |
+
# create causal mask
|
| 86 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
| 87 |
+
causal_4d_mask = None
|
| 88 |
+
if input_shape[-1] > 1 or self.sliding_window is not None:
|
| 89 |
+
causal_4d_mask = self._make_causal_mask(
|
| 90 |
+
input_shape,
|
| 91 |
+
dtype,
|
| 92 |
+
device=device,
|
| 93 |
+
past_key_values_length=past_key_values_length,
|
| 94 |
+
sliding_window=self.sliding_window,
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
return causal_4d_mask
|
| 98 |
+
|
| 99 |
+
def to_4d(
|
| 100 |
+
self,
|
| 101 |
+
attention_mask_2d: torch.Tensor,
|
| 102 |
+
query_length: int,
|
| 103 |
+
dtype: torch.dtype,
|
| 104 |
+
key_value_length: Optional[int] = None,
|
| 105 |
+
) -> torch.Tensor:
|
| 106 |
+
"""
|
| 107 |
+
Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
|
| 108 |
+
key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
|
| 109 |
+
causal, a causal mask will be added.
|
| 110 |
+
"""
|
| 111 |
+
input_shape = (attention_mask_2d.shape[0], query_length)
|
| 112 |
+
|
| 113 |
+
# create causal mask
|
| 114 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
| 115 |
+
causal_4d_mask = None
|
| 116 |
+
if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
|
| 117 |
+
if key_value_length is None:
|
| 118 |
+
raise ValueError(
|
| 119 |
+
"This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
past_key_values_length = key_value_length - query_length
|
| 123 |
+
causal_4d_mask = self._make_causal_mask(
|
| 124 |
+
input_shape,
|
| 125 |
+
dtype,
|
| 126 |
+
device=attention_mask_2d.device,
|
| 127 |
+
past_key_values_length=past_key_values_length,
|
| 128 |
+
sliding_window=self.sliding_window,
|
| 129 |
+
)
|
| 130 |
+
elif self.sliding_window is not None:
|
| 131 |
+
raise NotImplementedError("Sliding window is currently only implemented for causal masking")
|
| 132 |
+
|
| 133 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
| 134 |
+
expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
|
| 135 |
+
attention_mask_2d.device
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
if causal_4d_mask is not None:
|
| 139 |
+
expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)
|
| 140 |
+
|
| 141 |
+
# expanded_attn_mask + causal_4d_mask can cause some overflow
|
| 142 |
+
expanded_4d_mask = expanded_attn_mask
|
| 143 |
+
|
| 144 |
+
return expanded_4d_mask
|
| 145 |
+
|
| 146 |
+
@staticmethod
|
| 147 |
+
def _make_causal_mask(
|
| 148 |
+
input_ids_shape: torch.Size,
|
| 149 |
+
dtype: torch.dtype,
|
| 150 |
+
device: torch.device,
|
| 151 |
+
past_key_values_length: int = 0,
|
| 152 |
+
sliding_window: Optional[int] = None,
|
| 153 |
+
):
|
| 154 |
+
"""
|
| 155 |
+
Make causal mask used for bi-directional self-attention.
|
| 156 |
+
"""
|
| 157 |
+
bsz, tgt_len = input_ids_shape
|
| 158 |
+
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
|
| 159 |
+
mask_cond = torch.arange(mask.size(-1), device=device)
|
| 160 |
+
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
| 161 |
+
|
| 162 |
+
mask = mask.to(dtype)
|
| 163 |
+
|
| 164 |
+
if past_key_values_length > 0:
|
| 165 |
+
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
|
| 166 |
+
|
| 167 |
+
# add lower triangular sliding window mask if necessary
|
| 168 |
+
if sliding_window is not None:
|
| 169 |
+
diagonal = past_key_values_length - sliding_window - 1
|
| 170 |
+
|
| 171 |
+
context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
|
| 172 |
+
# Recent changes in PyTorch prevent mutations on tensors converted with aten::_to_copy
|
| 173 |
+
# See https://github.com/pytorch/pytorch/issues/127571
|
| 174 |
+
if is_torchdynamo_compiling():
|
| 175 |
+
mask = mask.clone()
|
| 176 |
+
mask.masked_fill_(context_mask, torch.finfo(dtype).min)
|
| 177 |
+
|
| 178 |
+
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
| 179 |
+
|
| 180 |
+
@staticmethod
|
| 181 |
+
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
| 182 |
+
"""
|
| 183 |
+
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
| 184 |
+
"""
|
| 185 |
+
bsz, src_len = mask.size()
|
| 186 |
+
tgt_len = tgt_len if tgt_len is not None else src_len
|
| 187 |
+
|
| 188 |
+
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
| 189 |
+
|
| 190 |
+
inverted_mask = 1.0 - expanded_mask
|
| 191 |
+
|
| 192 |
+
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
| 193 |
+
|
| 194 |
+
@staticmethod
|
| 195 |
+
def _unmask_unattended(
|
| 196 |
+
expanded_mask: torch.FloatTensor,
|
| 197 |
+
min_dtype: float,
|
| 198 |
+
):
|
| 199 |
+
# fmt: off
|
| 200 |
+
"""
|
| 201 |
+
Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
|
| 202 |
+
using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
| 203 |
+
Details: https://github.com/pytorch/pytorch/issues/110213
|
| 204 |
+
|
| 205 |
+
`expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
|
| 206 |
+
`attention_mask` is [bsz, src_seq_len].
|
| 207 |
+
|
| 208 |
+
The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.
|
| 209 |
+
|
| 210 |
+
For example, if `expanded_mask` is (e.g. here left-padding case)
|
| 211 |
+
```
|
| 212 |
+
[[[[0, 0, 0],
|
| 213 |
+
[0, 0, 0],
|
| 214 |
+
[0, 0, 1]]],
|
| 215 |
+
[[[1, 0, 0],
|
| 216 |
+
[1, 1, 0],
|
| 217 |
+
[1, 1, 1]]],
|
| 218 |
+
[[[0, 0, 0],
|
| 219 |
+
[0, 1, 0],
|
| 220 |
+
[0, 1, 1]]]]
|
| 221 |
+
```
|
| 222 |
+
then the modified `expanded_mask` will be
|
| 223 |
+
```
|
| 224 |
+
[[[[1, 1, 1], <-- modified
|
| 225 |
+
[1, 1, 1], <-- modified
|
| 226 |
+
[0, 0, 1]]],
|
| 227 |
+
[[[1, 0, 0],
|
| 228 |
+
[1, 1, 0],
|
| 229 |
+
[1, 1, 1]]],
|
| 230 |
+
[[[1, 1, 1], <-- modified
|
| 231 |
+
[0, 1, 0],
|
| 232 |
+
[0, 1, 1]]]]
|
| 233 |
+
```
|
| 234 |
+
"""
|
| 235 |
+
# fmt: on
|
| 236 |
+
if expanded_mask.dtype == torch.bool:
|
| 237 |
+
raise ValueError(
|
| 238 |
+
"AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))
|
| 242 |
+
|
| 243 |
+
    @staticmethod
    def _ignore_causal_mask_sdpa(
        attention_mask: Optional[torch.Tensor],
        inputs_embeds: torch.Tensor,
        past_key_values_length: int,
        sliding_window: Optional[int] = None,
        is_training: bool = False,
    ) -> bool:
        """
        Detects whether the optional user-specified attention_mask & the automatically created causal mask can be
        ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument.

        In case no token is masked in the `attention_mask` argument, if `query_length == 1` or
        `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks,
        allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is
        passed).
        """

        # Only the sequence length of the embeddings is needed; the batch size is discarded.
        _, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
        key_value_length = query_length + past_key_values_length

        # Tracing (jit.trace, FX symbolic trace, torchdynamo) cannot capture
        # data-dependent control flow such as `torch.all(attention_mask == 1)`.
        is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy) or is_torchdynamo_compiling()

        ignore_causal_mask = False

        if attention_mask is None:
            # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input
            # shape, thus SDPA's `is_causal` argument is rightfully updated
            # (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using
            # `torch.export` or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is
            # hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True`
            # which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108).
            # Thus, we only set `ignore_causal_mask = True` if the model is set to training.
            #
            # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal`
            # ("TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor").
            if (
                (is_training or not is_tracing)
                and (query_length == 1 or key_value_length == query_length)
                and (sliding_window is None or key_value_length < sliding_window)
            ):
                ignore_causal_mask = True
        elif sliding_window is None or key_value_length < sliding_window:
            # A user-provided 4D mask is used as-is downstream; it can never be ignored.
            if len(attention_mask.shape) == 4:
                return False
            elif not is_tracing and torch.all(attention_mask == 1):
                if query_length == 1 or key_value_length == query_length:
                    # For query_length == 1, causal attention and bi-directional attention are the same.
                    ignore_causal_mask = True

                # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore
                # the attention mask, as SDPA causal mask generation may be wrong. We will set `is_causal=False` in
                # SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
                # Reference: https://github.com/pytorch/pytorch/issues/108108
                # TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3.

        return ignore_causal_mask
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def _prepare_4d_causal_attention_mask(
    attention_mask: Optional[torch.Tensor],
    input_shape: Union[torch.Size, Tuple, List],
    inputs_embeds: torch.Tensor,
    past_key_values_length: int,
    sliding_window: Optional[int] = None,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`.

    Args:
        attention_mask (`torch.Tensor` or `None`):
            A 2D attention mask of shape `(batch_size, key_value_length)`
        input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
            The input shape should be a tuple that defines `(batch_size, query_length)`.
        inputs_embeds (`torch.Tensor`):
            The embedded inputs as a torch Tensor.
        past_key_values_length (`int`):
            The length of the key value cache.
        sliding_window (`int`, *optional*):
            If the model uses windowed attention, a sliding window should be passed.
    """
    converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)

    key_value_length = input_shape[-1] + past_key_values_length

    mask_dim = None if attention_mask is None else len(attention_mask.shape)
    if mask_dim == 2:
        # 2D padding mask: expand and merge with the causal structure.
        attention_mask = converter.to_4d(
            attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype
        )
    elif mask_dim == 4:
        expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
        if tuple(attention_mask.shape) != expected_shape:
            raise ValueError(
                f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
            )
        # Correctly-shaped custom 4D mask: invert it and fill with negative infinity.
        inverted_mask = 1.0 - attention_mask
        attention_mask = inverted_mask.masked_fill(
            inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
        )
    else:
        # No (or unexpected-rank) mask: fall back to a pure causal 4D mask.
        attention_mask = converter.to_causal_4d(
            input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
        )

    return attention_mask
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
# Adapted from _prepare_4d_causal_attention_mask
|
| 355 |
+
def _prepare_4d_causal_attention_mask_for_sdpa(
    attention_mask: Optional[torch.Tensor],
    input_shape: Union[torch.Size, Tuple, List],
    inputs_embeds: torch.Tensor,
    past_key_values_length: int,
    sliding_window: Optional[int] = None,
):
    """
    Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`.

    In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and
    `key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks,
    allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
    """
    converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)

    key_value_length = input_shape[-1] + past_key_values_length

    # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
    # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
    # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
    is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy) or is_torchdynamo_compiling()

    if AttentionMaskConverter._ignore_causal_mask_sdpa(
        attention_mask=attention_mask,
        inputs_embeds=inputs_embeds,
        past_key_values_length=past_key_values_length,
        sliding_window=sliding_window,
    ):
        # SDPA's own `is_causal` argument covers this case; no explicit mask needed.
        return None

    if attention_mask is None:
        return converter.to_causal_4d(
            input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
        )

    if attention_mask.dim() == 4:
        # A custom 4D mask is passed through untouched.
        expanded_4d_mask = attention_mask
    else:
        expanded_4d_mask = converter.to_4d(
            attention_mask,
            input_shape[-1],
            dtype=inputs_embeds.dtype,
            key_value_length=key_value_length,
        )

    # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
    # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
    # Details: https://github.com/pytorch/pytorch/issues/110213
    if not is_tracing and expanded_4d_mask.device.type == "cuda":
        expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
            expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min
        )

    return expanded_4d_mask
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`.

    Args:
        mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)`
        dtype (`torch.dtype`):
            The torch dtype the created mask shall have.
        tgt_len (`int`):
            The target length or query length the created mask shall have.
    """
    # Thin wrapper: all the work happens in the shared expansion helper.
    return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`, or returns `None` when nothing is masked so SDPA can
    run without an explicit mask.

    Args:
        mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)`
        dtype (`torch.dtype`):
            The torch dtype the created mask shall have.
        tgt_len (`int`):
            The target length or query length the created mask shall have.
    """
    _, key_value_length = mask.shape
    if tgt_len is None:
        tgt_len = key_value_length

    is_tracing = torch.jit.is_tracing() or isinstance(mask, torch.fx.Proxy) or is_torchdynamo_compiling()

    # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture
    # data-dependent controlflows such as the `torch.all` check below.
    if not is_tracing and torch.all(mask == 1):
        return None
    return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
def _create_4d_causal_attention_mask(
    input_shape: Union[torch.Size, Tuple, List],
    dtype: torch.dtype,
    device: torch.device,
    past_key_values_length: int = 0,
    sliding_window: Optional[int] = None,
) -> Optional[torch.Tensor]:
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`.

    Args:
        input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
            The input shape should be a tuple that defines `(batch_size, query_length)`.
        dtype (`torch.dtype`):
            The torch dtype the created mask shall have.
        device (`int`):
            The torch device the created mask shall have.
        past_key_values_length (`int`, *optional*, defaults to 0):
            The length of the key value cache.
        sliding_window (`int`, *optional*):
            If the model uses windowed attention, a sliding window should be passed.
    """
    converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)

    # Past cache entries extend the key/value axis beyond the query length.
    key_value_length = past_key_values_length + input_shape[-1]
    return converter.to_causal_4d(
        input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device
    )
|