Add files using upload-large-folder tool
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -0
- docs/resources/grpo_clevr_count.png +3 -0
- docs/resources/grpo_countdown_1.png +3 -0
- docs/transformers/build/lib/transformers/models/deprecated/efficientformer/configuration_efficientformer.py +172 -0
- docs/transformers/build/lib/transformers/models/deprecated/efficientformer/image_processing_efficientformer.py +324 -0
- docs/transformers/build/lib/transformers/models/deprecated/efficientformer/modeling_efficientformer.py +807 -0
- docs/transformers/build/lib/transformers/models/deprecated/efficientformer/modeling_tf_efficientformer.py +1198 -0
- docs/transformers/build/lib/transformers/models/deprecated/ernie_m/__init__.py +28 -0
- docs/transformers/build/lib/transformers/models/deprecated/ernie_m/configuration_ernie_m.py +114 -0
- docs/transformers/build/lib/transformers/models/deprecated/ernie_m/modeling_ernie_m.py +1058 -0
- docs/transformers/build/lib/transformers/models/deprecated/ernie_m/tokenization_ernie_m.py +410 -0
- docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/__init__.py +28 -0
- docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/configuration_gptsan_japanese.py +157 -0
- docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/convert_gptsan_tf_checkpoint_to_pytorch.py +181 -0
- docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py +1337 -0
- docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/tokenization_gptsan_japanese.py +518 -0
- docs/transformers/build/lib/transformers/models/deprecated/graphormer/__init__.py +27 -0
- docs/transformers/build/lib/transformers/models/deprecated/graphormer/algos_graphormer.pyx +107 -0
- docs/transformers/build/lib/transformers/models/deprecated/graphormer/collating_graphormer.py +134 -0
- docs/transformers/build/lib/transformers/models/deprecated/graphormer/configuration_graphormer.py +220 -0
- docs/transformers/build/lib/transformers/models/deprecated/graphormer/modeling_graphormer.py +911 -0
- docs/transformers/build/lib/transformers/models/deprecated/jukebox/__init__.py +28 -0
- docs/transformers/build/lib/transformers/models/deprecated/jukebox/configuration_jukebox.py +613 -0
- docs/transformers/build/lib/transformers/models/deprecated/jukebox/convert_jukebox.py +279 -0
- docs/transformers/build/lib/transformers/models/deprecated/jukebox/modeling_jukebox.py +0 -0
- docs/transformers/build/lib/transformers/models/deprecated/jukebox/tokenization_jukebox.py +407 -0
- docs/transformers/build/lib/transformers/models/deprecated/mctct/__init__.py +29 -0
- docs/transformers/build/lib/transformers/models/deprecated/mctct/configuration_mctct.py +184 -0
- docs/transformers/build/lib/transformers/models/deprecated/mctct/feature_extraction_mctct.py +291 -0
- docs/transformers/build/lib/transformers/models/deprecated/mctct/modeling_mctct.py +791 -0
- docs/transformers/build/lib/transformers/models/deprecated/mctct/processing_mctct.py +146 -0
- docs/transformers/build/lib/transformers/models/deprecated/mega/__init__.py +27 -0
- docs/transformers/build/lib/transformers/models/deprecated/mega/configuration_mega.py +243 -0
- docs/transformers/build/lib/transformers/models/deprecated/mega/convert_mega_original_pytorch_checkpoint_to_pytorch.py +298 -0
- docs/transformers/build/lib/transformers/models/deprecated/mega/modeling_mega.py +0 -0
- docs/transformers/build/lib/transformers/models/deprecated/mmbt/__init__.py +27 -0
- docs/transformers/build/lib/transformers/models/deprecated/mmbt/configuration_mmbt.py +45 -0
- docs/transformers/build/lib/transformers/models/deprecated/mmbt/modeling_mmbt.py +410 -0
- docs/transformers/build/lib/transformers/models/deprecated/nat/__init__.py +27 -0
- docs/transformers/build/lib/transformers/models/deprecated/nat/configuration_nat.py +148 -0
- docs/transformers/build/lib/transformers/models/deprecated/nat/modeling_nat.py +953 -0
- docs/transformers/build/lib/transformers/models/deprecated/nezha/__init__.py +27 -0
- docs/transformers/build/lib/transformers/models/deprecated/nezha/configuration_nezha.py +105 -0
- docs/transformers/build/lib/transformers/models/deprecated/nezha/modeling_nezha.py +1697 -0
- docs/transformers/build/lib/transformers/models/deprecated/open_llama/__init__.py +27 -0
- docs/transformers/build/lib/transformers/models/deprecated/open_llama/configuration_open_llama.py +169 -0
- docs/transformers/build/lib/transformers/models/deprecated/open_llama/modeling_open_llama.py +975 -0
- docs/transformers/build/lib/transformers/models/deprecated/qdqbert/__init__.py +27 -0
- docs/transformers/build/lib/transformers/models/deprecated/qdqbert/configuration_qdqbert.py +123 -0
- docs/transformers/build/lib/transformers/models/deprecated/qdqbert/modeling_qdqbert.py +1749 -0
.gitattributes
CHANGED
|
@@ -53,3 +53,5 @@ docs/resources/grpo_geoqa.png filter=lfs diff=lfs merge=lfs -text
|
|
| 53 |
docs/resources/grpo_openr1_multimodal.png filter=lfs diff=lfs merge=lfs -text
|
| 54 |
docs/resources/web-ui-en.jpg filter=lfs diff=lfs merge=lfs -text
|
| 55 |
docs/resources/kto_data.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 53 |
docs/resources/grpo_openr1_multimodal.png filter=lfs diff=lfs merge=lfs -text
|
| 54 |
docs/resources/web-ui-en.jpg filter=lfs diff=lfs merge=lfs -text
|
| 55 |
docs/resources/kto_data.png filter=lfs diff=lfs merge=lfs -text
|
| 56 |
+
docs/resources/grpo_countdown_1.png filter=lfs diff=lfs merge=lfs -text
|
| 57 |
+
docs/resources/grpo_clevr_count.png filter=lfs diff=lfs merge=lfs -text
|
docs/resources/grpo_clevr_count.png
ADDED
|
Git LFS Details
|
docs/resources/grpo_countdown_1.png
ADDED
|
Git LFS Details
|
docs/transformers/build/lib/transformers/models/deprecated/efficientformer/configuration_efficientformer.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""EfficientFormer model configuration"""
|
| 16 |
+
|
| 17 |
+
from typing import List
|
| 18 |
+
|
| 19 |
+
from ....configuration_utils import PretrainedConfig
|
| 20 |
+
from ....utils import logging
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
logger = logging.get_logger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class EfficientFormerConfig(PretrainedConfig):
|
| 27 |
+
r"""
|
| 28 |
+
This is the configuration class to store the configuration of an [`EfficientFormerModel`]. It is used to
|
| 29 |
+
instantiate an EfficientFormer model according to the specified arguments, defining the model architecture.
|
| 30 |
+
Instantiating a configuration with the defaults will yield a similar configuration to that of the EfficientFormer
|
| 31 |
+
[snap-research/efficientformer-l1](https://huggingface.co/snap-research/efficientformer-l1) architecture.
|
| 32 |
+
|
| 33 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 34 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
depths (`List[int]`, *optional*, defaults to `[3, 2, 6, 4]`):
|
| 38 |
+
Depth of each stage.
|
| 39 |
+
hidden_sizes (`List[int]`, *optional*, defaults to `[48, 96, 224, 448]`):
|
| 40 |
+
Dimensionality of each stage.
|
| 41 |
+
downsamples (`List[bool]`, *optional*, defaults to `[True, True, True, True]`):
|
| 42 |
+
Whether or not to downsample inputs between two stages.
|
| 43 |
+
dim (`int`, *optional*, defaults to 448):
|
| 44 |
+
Number of channels in Meta3D layers
|
| 45 |
+
key_dim (`int`, *optional*, defaults to 32):
|
| 46 |
+
The size of the key in meta3D block.
|
| 47 |
+
attention_ratio (`int`, *optional*, defaults to 4):
|
| 48 |
+
Ratio of the dimension of the query and value to the dimension of the key in MSHA block
|
| 49 |
+
resolution (`int`, *optional*, defaults to 7):
|
| 50 |
+
Size of each patch
|
| 51 |
+
num_hidden_layers (`int`, *optional*, defaults to 5):
|
| 52 |
+
Number of hidden layers in the Transformer encoder.
|
| 53 |
+
num_attention_heads (`int`, *optional*, defaults to 8):
|
| 54 |
+
Number of attention heads for each attention layer in the 3D MetaBlock.
|
| 55 |
+
mlp_expansion_ratio (`int`, *optional*, defaults to 4):
|
| 56 |
+
Ratio of size of the hidden dimensionality of an MLP to the dimensionality of its input.
|
| 57 |
+
hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
|
| 58 |
+
The dropout probability for all fully connected layers in the embeddings and encoder.
|
| 59 |
+
patch_size (`int`, *optional*, defaults to 16):
|
| 60 |
+
The size (resolution) of each patch.
|
| 61 |
+
num_channels (`int`, *optional*, defaults to 3):
|
| 62 |
+
The number of input channels.
|
| 63 |
+
pool_size (`int`, *optional*, defaults to 3):
|
| 64 |
+
Kernel size of pooling layers.
|
| 65 |
+
downsample_patch_size (`int`, *optional*, defaults to 3):
|
| 66 |
+
The size of patches in downsampling layers.
|
| 67 |
+
downsample_stride (`int`, *optional*, defaults to 2):
|
| 68 |
+
The stride of convolution kernels in downsampling layers.
|
| 69 |
+
downsample_pad (`int`, *optional*, defaults to 1):
|
| 70 |
+
Padding in downsampling layers.
|
| 71 |
+
drop_path_rate (`float`, *optional*, defaults to 0.0):
|
| 72 |
+
Rate at which to increase dropout probability in DropPath.
|
| 73 |
+
num_meta3d_blocks (`int`, *optional*, defaults to 1):
|
| 74 |
+
The number of 3D MetaBlocks in the last stage.
|
| 75 |
+
distillation (`bool`, *optional*, defaults to `True`):
|
| 76 |
+
Whether to add a distillation head.
|
| 77 |
+
use_layer_scale (`bool`, *optional*, defaults to `True`):
|
| 78 |
+
Whether to scale outputs from token mixers.
|
| 79 |
+
layer_scale_init_value (`float`, *optional*, defaults to 1e-5):
|
| 80 |
+
Factor by which outputs from token mixers are scaled.
|
| 81 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
| 82 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
| 83 |
+
`"relu"`, `"selu"` and `"gelu_new"` are supported.
|
| 84 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 85 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 86 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
| 87 |
+
The epsilon used by the layer normalization layers.
|
| 88 |
+
image_size (`int`, *optional*, defaults to `224`):
|
| 89 |
+
The size (resolution) of each image.
|
| 90 |
+
|
| 91 |
+
Example:
|
| 92 |
+
|
| 93 |
+
```python
|
| 94 |
+
>>> from transformers import EfficientFormerConfig, EfficientFormerModel
|
| 95 |
+
|
| 96 |
+
>>> # Initializing an EfficientFormer efficientformer-l1 style configuration
|
| 97 |
+
>>> configuration = EfficientFormerConfig()
|
| 98 |
+
|
| 99 |
+
>>> # Initializing an EfficientFormerModel (with random weights) from the efficientformer-l1 style configuration
|
| 100 |
+
>>> model = EfficientFormerModel(configuration)
|
| 101 |
+
|
| 102 |
+
>>> # Accessing the model configuration
|
| 103 |
+
>>> configuration = model.config
|
| 104 |
+
```"""
|
| 105 |
+
|
| 106 |
+
model_type = "efficientformer"
|
| 107 |
+
|
| 108 |
+
def __init__(
|
| 109 |
+
self,
|
| 110 |
+
depths: List[int] = [3, 2, 6, 4],
|
| 111 |
+
hidden_sizes: List[int] = [48, 96, 224, 448],
|
| 112 |
+
downsamples: List[bool] = [True, True, True, True],
|
| 113 |
+
dim: int = 448,
|
| 114 |
+
key_dim: int = 32,
|
| 115 |
+
attention_ratio: int = 4,
|
| 116 |
+
resolution: int = 7,
|
| 117 |
+
num_hidden_layers: int = 5,
|
| 118 |
+
num_attention_heads: int = 8,
|
| 119 |
+
mlp_expansion_ratio: int = 4,
|
| 120 |
+
hidden_dropout_prob: float = 0.0,
|
| 121 |
+
patch_size: int = 16,
|
| 122 |
+
num_channels: int = 3,
|
| 123 |
+
pool_size: int = 3,
|
| 124 |
+
downsample_patch_size: int = 3,
|
| 125 |
+
downsample_stride: int = 2,
|
| 126 |
+
downsample_pad: int = 1,
|
| 127 |
+
drop_path_rate: float = 0.0,
|
| 128 |
+
num_meta3d_blocks: int = 1,
|
| 129 |
+
distillation: bool = True,
|
| 130 |
+
use_layer_scale: bool = True,
|
| 131 |
+
layer_scale_init_value: float = 1e-5,
|
| 132 |
+
hidden_act: str = "gelu",
|
| 133 |
+
initializer_range: float = 0.02,
|
| 134 |
+
layer_norm_eps: float = 1e-12,
|
| 135 |
+
image_size: int = 224,
|
| 136 |
+
batch_norm_eps: float = 1e-05,
|
| 137 |
+
**kwargs,
|
| 138 |
+
) -> None:
|
| 139 |
+
super().__init__(**kwargs)
|
| 140 |
+
|
| 141 |
+
self.hidden_act = hidden_act
|
| 142 |
+
self.hidden_dropout_prob = hidden_dropout_prob
|
| 143 |
+
self.hidden_sizes = hidden_sizes
|
| 144 |
+
self.num_hidden_layers = num_hidden_layers
|
| 145 |
+
self.num_attention_heads = num_attention_heads
|
| 146 |
+
self.initializer_range = initializer_range
|
| 147 |
+
self.layer_norm_eps = layer_norm_eps
|
| 148 |
+
self.patch_size = patch_size
|
| 149 |
+
self.num_channels = num_channels
|
| 150 |
+
self.depths = depths
|
| 151 |
+
self.mlp_expansion_ratio = mlp_expansion_ratio
|
| 152 |
+
self.downsamples = downsamples
|
| 153 |
+
self.dim = dim
|
| 154 |
+
self.key_dim = key_dim
|
| 155 |
+
self.attention_ratio = attention_ratio
|
| 156 |
+
self.resolution = resolution
|
| 157 |
+
self.pool_size = pool_size
|
| 158 |
+
self.downsample_patch_size = downsample_patch_size
|
| 159 |
+
self.downsample_stride = downsample_stride
|
| 160 |
+
self.downsample_pad = downsample_pad
|
| 161 |
+
self.drop_path_rate = drop_path_rate
|
| 162 |
+
self.num_meta3d_blocks = num_meta3d_blocks
|
| 163 |
+
self.distillation = distillation
|
| 164 |
+
self.use_layer_scale = use_layer_scale
|
| 165 |
+
self.layer_scale_init_value = layer_scale_init_value
|
| 166 |
+
self.image_size = image_size
|
| 167 |
+
self.batch_norm_eps = batch_norm_eps
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
__all__ = [
|
| 171 |
+
"EfficientFormerConfig",
|
| 172 |
+
]
|
docs/transformers/build/lib/transformers/models/deprecated/efficientformer/image_processing_efficientformer.py
ADDED
|
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Image processor class for EfficientFormer."""
|
| 16 |
+
|
| 17 |
+
from typing import Dict, List, Optional, Union
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
|
| 21 |
+
from ....image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
| 22 |
+
from ....image_transforms import (
|
| 23 |
+
get_resize_output_image_size,
|
| 24 |
+
resize,
|
| 25 |
+
to_channel_dimension_format,
|
| 26 |
+
)
|
| 27 |
+
from ....image_utils import (
|
| 28 |
+
IMAGENET_DEFAULT_MEAN,
|
| 29 |
+
IMAGENET_DEFAULT_STD,
|
| 30 |
+
ChannelDimension,
|
| 31 |
+
ImageInput,
|
| 32 |
+
PILImageResampling,
|
| 33 |
+
infer_channel_dimension_format,
|
| 34 |
+
is_batched,
|
| 35 |
+
is_scaled_image,
|
| 36 |
+
to_numpy_array,
|
| 37 |
+
valid_images,
|
| 38 |
+
validate_kwargs,
|
| 39 |
+
validate_preprocess_arguments,
|
| 40 |
+
)
|
| 41 |
+
from ....utils import TensorType, logging
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
logger = logging.get_logger(__name__)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class EfficientFormerImageProcessor(BaseImageProcessor):
|
| 48 |
+
r"""
|
| 49 |
+
Constructs an EfficientFormer image processor.
|
| 50 |
+
|
| 51 |
+
Args:
|
| 52 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
| 53 |
+
Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
|
| 54 |
+
size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
|
| 55 |
+
size (`dict`, *optional*, defaults to `{"height": 224, "width": 224}`):
|
| 56 |
+
Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
|
| 57 |
+
method.
|
| 58 |
+
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
|
| 59 |
+
Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
|
| 60 |
+
`preprocess` method.
|
| 61 |
+
do_center_crop (`bool`, *optional*, defaults to `True`):
|
| 62 |
+
Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
|
| 63 |
+
`preprocess` method.
|
| 64 |
+
crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
|
| 65 |
+
Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
|
| 66 |
+
method.
|
| 67 |
+
do_rescale (`bool`, *optional*, defaults to `True`):
|
| 68 |
+
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
|
| 69 |
+
parameter in the `preprocess` method.
|
| 70 |
+
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
| 71 |
+
Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
|
| 72 |
+
`preprocess` method.
|
| 73 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
| 74 |
+
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
|
| 75 |
+
method.
|
| 76 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
|
| 77 |
+
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
| 78 |
+
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
| 79 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
|
| 80 |
+
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
| 81 |
+
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
| 82 |
+
"""
|
| 83 |
+
|
| 84 |
+
model_input_names = ["pixel_values"]
|
| 85 |
+
|
| 86 |
+
def __init__(
|
| 87 |
+
self,
|
| 88 |
+
do_resize: bool = True,
|
| 89 |
+
size: Optional[Dict[str, int]] = None,
|
| 90 |
+
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
| 91 |
+
do_center_crop: bool = True,
|
| 92 |
+
do_rescale: bool = True,
|
| 93 |
+
rescale_factor: Union[int, float] = 1 / 255,
|
| 94 |
+
crop_size: Dict[str, int] = None,
|
| 95 |
+
do_normalize: bool = True,
|
| 96 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 97 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 98 |
+
**kwargs,
|
| 99 |
+
) -> None:
|
| 100 |
+
super().__init__(**kwargs)
|
| 101 |
+
size = size if size is not None else {"height": 224, "width": 224}
|
| 102 |
+
size = get_size_dict(size)
|
| 103 |
+
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
|
| 104 |
+
crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
|
| 105 |
+
|
| 106 |
+
self.do_resize = do_resize
|
| 107 |
+
self.do_rescale = do_rescale
|
| 108 |
+
self.do_normalize = do_normalize
|
| 109 |
+
self.do_center_crop = do_center_crop
|
| 110 |
+
self.crop_size = crop_size
|
| 111 |
+
self.size = size
|
| 112 |
+
self.resample = resample
|
| 113 |
+
self.rescale_factor = rescale_factor
|
| 114 |
+
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
|
| 115 |
+
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
|
| 116 |
+
self._valid_processor_keys = [
|
| 117 |
+
"images",
|
| 118 |
+
"do_resize",
|
| 119 |
+
"size",
|
| 120 |
+
"resample",
|
| 121 |
+
"do_center_crop",
|
| 122 |
+
"crop_size",
|
| 123 |
+
"do_rescale",
|
| 124 |
+
"rescale_factor",
|
| 125 |
+
"do_normalize",
|
| 126 |
+
"image_mean",
|
| 127 |
+
"image_std",
|
| 128 |
+
"return_tensors",
|
| 129 |
+
"data_format",
|
| 130 |
+
"input_data_format",
|
| 131 |
+
]
|
| 132 |
+
|
| 133 |
+
def resize(
|
| 134 |
+
self,
|
| 135 |
+
image: np.ndarray,
|
| 136 |
+
size: Dict[str, int],
|
| 137 |
+
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
| 138 |
+
data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 139 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 140 |
+
**kwargs,
|
| 141 |
+
) -> np.ndarray:
|
| 142 |
+
"""
|
| 143 |
+
Resize an image to `(size["height"], size["width"])`.
|
| 144 |
+
|
| 145 |
+
Args:
|
| 146 |
+
image (`np.ndarray`):
|
| 147 |
+
Image to resize.
|
| 148 |
+
size (`Dict[str, int]`):
|
| 149 |
+
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
|
| 150 |
+
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
|
| 151 |
+
`PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
|
| 152 |
+
data_format (`ChannelDimension` or `str`, *optional*):
|
| 153 |
+
The channel dimension format for the output image. If unset, the channel dimension format of the input
|
| 154 |
+
image is used. Can be one of:
|
| 155 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 156 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 157 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 158 |
+
The channel dimension format of the input image. If not provided, it will be inferred.
|
| 159 |
+
|
| 160 |
+
Returns:
|
| 161 |
+
`np.ndarray`: The resized image.
|
| 162 |
+
"""
|
| 163 |
+
size = get_size_dict(size)
|
| 164 |
+
|
| 165 |
+
if "shortest_edge" in size:
|
| 166 |
+
size = get_resize_output_image_size(
|
| 167 |
+
image, size=size["shortest_edge"], default_to_square=False, input_data_format=input_data_format
|
| 168 |
+
)
|
| 169 |
+
# size = get_resize_output_image_size(image, size["shortest_edge"], size["longest_edge"])
|
| 170 |
+
elif "height" in size and "width" in size:
|
| 171 |
+
size = (size["height"], size["width"])
|
| 172 |
+
else:
|
| 173 |
+
raise ValueError(f"Size must contain 'height' and 'width' keys or 'shortest_edge' key. Got {size.keys()}")
|
| 174 |
+
return resize(
|
| 175 |
+
image, size=size, resample=resample, data_format=data_format, input_data_format=input_data_format, **kwargs
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
def preprocess(
|
| 179 |
+
self,
|
| 180 |
+
images: ImageInput,
|
| 181 |
+
do_resize: Optional[bool] = None,
|
| 182 |
+
size: Dict[str, int] = None,
|
| 183 |
+
resample: PILImageResampling = None,
|
| 184 |
+
do_center_crop: Optional[bool] = None,
|
| 185 |
+
crop_size: Optional[int] = None,
|
| 186 |
+
do_rescale: Optional[bool] = None,
|
| 187 |
+
rescale_factor: Optional[float] = None,
|
| 188 |
+
do_normalize: Optional[bool] = None,
|
| 189 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 190 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 191 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 192 |
+
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
|
| 193 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 194 |
+
**kwargs,
|
| 195 |
+
) -> BatchFeature:
|
| 196 |
+
"""
|
| 197 |
+
Preprocess an image or batch of images.
|
| 198 |
+
|
| 199 |
+
Args:
|
| 200 |
+
images (`ImageInput`):
|
| 201 |
+
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
| 202 |
+
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
| 203 |
+
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
| 204 |
+
Whether to resize the image.
|
| 205 |
+
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
| 206 |
+
Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after
|
| 207 |
+
resizing.
|
| 208 |
+
resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`):
|
| 209 |
+
`PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BILINEAR`. Only has
|
| 210 |
+
an effect if `do_resize` is set to `True`.
|
| 211 |
+
do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
|
| 212 |
+
Whether to center crop the image.
|
| 213 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
| 214 |
+
Whether to rescale the image values between [0 - 1].
|
| 215 |
+
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
| 216 |
+
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
| 217 |
+
crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
|
| 218 |
+
Size of the center crop. Only has an effect if `do_center_crop` is set to `True`.
|
| 219 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
| 220 |
+
Whether to normalize the image.
|
| 221 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
| 222 |
+
Image mean to use if `do_normalize` is set to `True`.
|
| 223 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
| 224 |
+
Image standard deviation to use if `do_normalize` is set to `True`.
|
| 225 |
+
return_tensors (`str` or `TensorType`, *optional*):
|
| 226 |
+
The type of tensors to return. Can be one of:
|
| 227 |
+
- Unset: Return a list of `np.ndarray`.
|
| 228 |
+
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
| 229 |
+
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
| 230 |
+
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
| 231 |
+
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
| 232 |
+
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
| 233 |
+
The channel dimension format for the output image. Can be one of:
|
| 234 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 235 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 236 |
+
- Unset: Use the channel dimension format of the input image.
|
| 237 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 238 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
| 239 |
+
from the input image. Can be one of:
|
| 240 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 241 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 242 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
| 243 |
+
"""
|
| 244 |
+
do_resize = do_resize if do_resize is not None else self.do_resize
|
| 245 |
+
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
| 246 |
+
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
| 247 |
+
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
|
| 248 |
+
crop_size = crop_size if crop_size is not None else self.crop_size
|
| 249 |
+
crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True)
|
| 250 |
+
resample = resample if resample is not None else self.resample
|
| 251 |
+
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
| 252 |
+
image_mean = image_mean if image_mean is not None else self.image_mean
|
| 253 |
+
image_std = image_std if image_std is not None else self.image_std
|
| 254 |
+
|
| 255 |
+
size = size if size is not None else self.size
|
| 256 |
+
size_dict = get_size_dict(size)
|
| 257 |
+
|
| 258 |
+
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
|
| 259 |
+
|
| 260 |
+
if not is_batched(images):
|
| 261 |
+
images = [images]
|
| 262 |
+
|
| 263 |
+
if not valid_images(images):
|
| 264 |
+
raise ValueError(
|
| 265 |
+
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
| 266 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
| 267 |
+
)
|
| 268 |
+
validate_preprocess_arguments(
|
| 269 |
+
do_rescale=do_rescale,
|
| 270 |
+
rescale_factor=rescale_factor,
|
| 271 |
+
do_normalize=do_normalize,
|
| 272 |
+
image_mean=image_mean,
|
| 273 |
+
image_std=image_std,
|
| 274 |
+
do_center_crop=do_center_crop,
|
| 275 |
+
crop_size=crop_size,
|
| 276 |
+
do_resize=do_resize,
|
| 277 |
+
size=size,
|
| 278 |
+
resample=resample,
|
| 279 |
+
)
|
| 280 |
+
# All transformations expect numpy arrays.
|
| 281 |
+
images = [to_numpy_array(image) for image in images]
|
| 282 |
+
|
| 283 |
+
if do_rescale and is_scaled_image(images[0]):
|
| 284 |
+
logger.warning_once(
|
| 285 |
+
"It looks like you are trying to rescale already rescaled images. If the input"
|
| 286 |
+
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
if input_data_format is None:
|
| 290 |
+
# We assume that all images have the same channel dimension format.
|
| 291 |
+
input_data_format = infer_channel_dimension_format(images[0])
|
| 292 |
+
|
| 293 |
+
if do_resize:
|
| 294 |
+
images = [
|
| 295 |
+
self.resize(image=image, size=size_dict, resample=resample, input_data_format=input_data_format)
|
| 296 |
+
for image in images
|
| 297 |
+
]
|
| 298 |
+
|
| 299 |
+
if do_center_crop:
|
| 300 |
+
images = [
|
| 301 |
+
self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
|
| 302 |
+
]
|
| 303 |
+
|
| 304 |
+
if do_rescale:
|
| 305 |
+
images = [
|
| 306 |
+
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
| 307 |
+
for image in images
|
| 308 |
+
]
|
| 309 |
+
|
| 310 |
+
if do_normalize:
|
| 311 |
+
images = [
|
| 312 |
+
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
| 313 |
+
for image in images
|
| 314 |
+
]
|
| 315 |
+
|
| 316 |
+
images = [
|
| 317 |
+
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
| 318 |
+
]
|
| 319 |
+
|
| 320 |
+
data = {"pixel_values": images}
|
| 321 |
+
return BatchFeature(data=data, tensor_type=return_tensors)
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
__all__ = ["EfficientFormerImageProcessor"]
|
docs/transformers/build/lib/transformers/models/deprecated/efficientformer/modeling_efficientformer.py
ADDED
|
@@ -0,0 +1,807 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 Snapchat Research and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch EfficientFormer model."""
|
| 16 |
+
|
| 17 |
+
import itertools
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.utils.checkpoint
|
| 23 |
+
from torch import nn
|
| 24 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 25 |
+
|
| 26 |
+
from ....activations import ACT2FN
|
| 27 |
+
from ....modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
|
| 28 |
+
from ....modeling_utils import PreTrainedModel
|
| 29 |
+
from ....utils import (
|
| 30 |
+
ModelOutput,
|
| 31 |
+
add_code_sample_docstrings,
|
| 32 |
+
add_start_docstrings,
|
| 33 |
+
add_start_docstrings_to_model_forward,
|
| 34 |
+
logging,
|
| 35 |
+
)
|
| 36 |
+
from .configuration_efficientformer import EfficientFormerConfig
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
logger = logging.get_logger(__name__)
|
| 40 |
+
|
| 41 |
+
# General docstring
|
| 42 |
+
_CONFIG_FOR_DOC = "EfficientFormerConfig"
|
| 43 |
+
|
| 44 |
+
# Base docstring
|
| 45 |
+
_CHECKPOINT_FOR_DOC = "snap-research/efficientformer-l1-300"
|
| 46 |
+
_EXPECTED_OUTPUT_SHAPE = [1, 49, 448]
|
| 47 |
+
|
| 48 |
+
# Image classification docstring
|
| 49 |
+
_IMAGE_CLASS_CHECKPOINT = "snap-research/efficientformer-l1-300"
|
| 50 |
+
_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class EfficientFormerPatchEmbeddings(nn.Module):
    """
    Downsampling layer applied between two encoder stages.

    The input is a 4-D tensor `[batch_size, num_channels, height, width]`. It is passed through a convolution whose
    kernel size, stride and padding come from the configuration and which outputs `embed_dim` channels, followed by
    an optional batch normalization.
    """

    def __init__(self, config: EfficientFormerConfig, num_channels: int, embed_dim: int, apply_norm: bool = True):
        super().__init__()
        # Remembered so that `forward` can validate the channel dimension of its input.
        self.num_channels = num_channels

        self.projection = nn.Conv2d(
            num_channels,
            embed_dim,
            kernel_size=config.downsample_patch_size,
            stride=config.downsample_stride,
            padding=config.downsample_pad,
        )
        if apply_norm:
            self.norm = nn.BatchNorm2d(embed_dim, eps=config.batch_norm_eps)
        else:
            self.norm = nn.Identity()

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Project and normalize `pixel_values`.

        Raises:
            ValueError: if the input is not 4-D or its channel dimension differs from `num_channels`.
        """
        # Unpacking into four names also rejects inputs that are not 4-D.
        _, input_channels, _, _ = pixel_values.shape
        if input_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )

        return self.norm(self.projection(pixel_values))
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class EfficientFormerSelfAttention(nn.Module):
    """
    Multi-head self-attention with a learned bias per relative offset on a `resolution x resolution` grid.

    Every pair of grid points is mapped to the absolute offset `(|dx|, |dy|)` between them; each head learns one bias
    value per distinct offset (`attention_biases`). `attention_bias_idxs` is a `(resolution**2, resolution**2)` lookup
    table from point pairs to offset ids. In eval mode the gathered bias tensor is cached in `self.ab`.

    Args:
        dim: Size of the last dimension of the input and of the output.
        key_dim: Per-head size of queries and keys.
        num_heads: Number of attention heads.
        attention_ratio: Per-head value size is `int(attention_ratio * key_dim)`.
        resolution: Side length of the grid used to build the bias table.
    """

    def __init__(self, dim: int, key_dim: int, num_heads: int, attention_ratio: int, resolution: int):
        super().__init__()

        self.num_heads = num_heads
        self.key_dim = key_dim
        self.attention_ratio = attention_ratio
        self.scale = key_dim**-0.5
        self.total_key_dim = key_dim * num_heads
        self.expanded_key_dim = int(attention_ratio * key_dim)
        self.total_expanded_key_dim = int(self.expanded_key_dim * num_heads)
        # One fused projection producing queries, keys (key_dim each) and values (expanded_key_dim) for all heads.
        hidden_size = self.total_expanded_key_dim + self.total_key_dim * 2
        self.qkv = nn.Linear(dim, hidden_size)
        self.projection = nn.Linear(self.total_expanded_key_dim, dim)

        points = list(itertools.product(range(resolution), range(resolution)))
        num_points = len(points)
        attention_offsets = {}
        idxs = []
        for point_1 in points:
            for point_2 in points:
                offset = (abs(point_1[0] - point_2[0]), abs(point_1[1] - point_2[1]))
                # A new offset receives the next free id; a known offset reuses its id.
                idxs.append(attention_offsets.setdefault(offset, len(attention_offsets)))
        self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
        self.register_buffer("attention_bias_idxs", torch.LongTensor(idxs).view(num_points, num_points))

    @torch.no_grad()
    def train(self, mode=True):
        """
        Switch between training and eval mode and maintain the cached bias tensor `self.ab`.

        Entering training mode drops the cache (the bias is gathered from the parameter on every forward pass);
        entering eval mode rebuilds it. Returns `self`, as `nn.Module.train` does, so that `module.eval()` and
        `module.train()` can be chained.
        """
        super().train(mode)
        if mode:
            if hasattr(self, "ab"):
                del self.ab
        else:
            self.ab = self.attention_biases[:, self.attention_bias_idxs]
        return self

    def forward(self, hidden_states: torch.Tensor, output_attentions: bool = False) -> Tuple[torch.Tensor]:
        """
        Args:
            hidden_states: Tensor of shape `(batch_size, sequence_length, dim)`. The bias table is
                `(resolution**2, resolution**2)`, so `sequence_length` is expected to match `resolution**2`.
            output_attentions: Whether to also return the attention probabilities.

        Returns:
            `(context_layer,)` or `(context_layer, attention_probs)`.
        """
        batch_size, sequence_length, num_channels = hidden_states.shape
        qkv = self.qkv(hidden_states)
        query_layer, key_layer, value_layer = qkv.reshape(batch_size, sequence_length, self.num_heads, -1).split(
            [self.key_dim, self.key_dim, self.expanded_key_dim], dim=3
        )
        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        query_layer = query_layer.permute(0, 2, 1, 3)
        key_layer = key_layer.permute(0, 2, 1, 3)
        value_layer = value_layer.permute(0, 2, 1, 3)

        if self.training:
            attention_bias = self.attention_biases[:, self.attention_bias_idxs]
        else:
            # `model.to(device)` does not move the plain-attribute cache, and the cache may be missing when the
            # `training` flag was flipped without going through `train()` / `eval()`. Handle both here.
            cached_bias = getattr(self, "ab", None)
            if cached_bias is None:
                with torch.no_grad():
                    cached_bias = self.attention_biases[:, self.attention_bias_idxs]
            self.ab = cached_bias.to(self.attention_biases.device)
            attention_bias = self.ab

        attention_probs = (torch.matmul(query_layer, key_layer.transpose(-2, -1))) * self.scale + attention_bias
        attention_probs = attention_probs.softmax(dim=-1)

        context_layer = torch.matmul(attention_probs, value_layer).transpose(1, 2)
        context_layer = context_layer.reshape(batch_size, sequence_length, self.total_expanded_key_dim)
        context_layer = self.projection(context_layer)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class EfficientFormerConvStem(nn.Module):
    """
    Convolutional stem: two `3x3` stride-2 convolutions, each followed by batch normalization and ReLU.

    The first convolution maps `config.num_channels` to `out_channels // 2` channels, the second one maps those to
    `out_channels` channels.
    """

    def __init__(self, config: EfficientFormerConfig, out_channels: int):
        super().__init__()
        half_channels = out_channels // 2

        self.convolution1 = nn.Conv2d(config.num_channels, half_channels, kernel_size=3, stride=2, padding=1)
        self.batchnorm_before = nn.BatchNorm2d(half_channels, eps=config.batch_norm_eps)

        self.convolution2 = nn.Conv2d(half_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.batchnorm_after = nn.BatchNorm2d(out_channels, eps=config.batch_norm_eps)

        # A single ReLU instance is shared by both halves of the stem.
        self.activation = nn.ReLU()

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Run conv -> batch norm -> ReLU twice and return the resulting feature map."""
        first = self.convolution1(pixel_values)
        first = self.activation(self.batchnorm_before(first))

        second = self.convolution2(first)
        second = self.activation(self.batchnorm_after(second))

        return second
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class EfficientFormerPooling(nn.Module):
    """
    Pooling token mixer: a stride-1 average pooling from which the input itself is subtracted.

    Padding of `pool_size // 2` is applied; padded positions are excluded from the average
    (`count_include_pad=False`).
    """

    def __init__(self, pool_size: int):
        super().__init__()
        self.pool = nn.AvgPool2d(pool_size, stride=1, padding=pool_size // 2, count_include_pad=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `pool(hidden_states) - hidden_states`."""
        pooled = self.pool(hidden_states)
        return pooled - hidden_states
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class EfficientFormerDenseMlp(nn.Module):
    """
    Two-layer perceptron built from linear layers.

    Computes linear -> activation -> dropout -> linear -> dropout. The activation is looked up in `ACT2FN` by
    `config.hidden_act` and the dropout probability is `config.hidden_dropout_prob`.

    Args:
        config: Model configuration.
        in_features: Size of the input features.
        hidden_features: Size of the hidden layer; falls back to `in_features` when unset (or falsy).
        out_features: Size of the output features; falls back to `in_features` when unset (or falsy).
    """

    def __init__(
        self,
        config: EfficientFormerConfig,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
    ):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        self.linear_in = nn.Linear(in_features, hidden_features)
        self.activation = ACT2FN[config.hidden_act]
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.linear_out = nn.Linear(hidden_features, out_features)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to `hidden_states`."""
        # The same dropout module is applied after the activation and after the output projection.
        pipeline = (self.linear_in, self.activation, self.dropout, self.linear_out, self.dropout)
        for layer in pipeline:
            hidden_states = layer(hidden_states)

        return hidden_states
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
class EfficientFormerConvMlp(nn.Module):
    """
    Two-layer perceptron built from `1x1` convolutions with batch normalization.

    Computes conv -> batch norm -> activation -> dropout -> conv -> batch norm -> dropout. The activation is looked
    up in `ACT2FN` by `config.hidden_act`.

    Args:
        config: Model configuration.
        in_features: Number of input channels.
        hidden_features: Number of hidden channels; falls back to `in_features` when unset (or falsy).
        out_features: Number of output channels; falls back to `in_features` when unset (or falsy).
        drop: Dropout probability.
    """

    def __init__(
        self,
        config: EfficientFormerConfig,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        drop: float = 0.0,
    ):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        self.convolution1 = nn.Conv2d(in_features, hidden_features, 1)
        self.activation = ACT2FN[config.hidden_act]
        self.convolution2 = nn.Conv2d(hidden_features, out_features, 1)
        self.dropout = nn.Dropout(drop)

        self.batchnorm_before = nn.BatchNorm2d(hidden_features, eps=config.batch_norm_eps)
        self.batchnorm_after = nn.BatchNorm2d(out_features, eps=config.batch_norm_eps)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Apply the convolutional MLP to `hidden_state`."""
        # The same dropout module is applied after the activation and after the second batch norm.
        pipeline = (
            self.convolution1,
            self.batchnorm_before,
            self.activation,
            self.dropout,
            self.convolution2,
            self.batchnorm_after,
            self.dropout,
        )
        for layer in pipeline:
            hidden_state = layer(hidden_state)

        return hidden_state
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.

    Args:
        input: Tensor whose first dimension indexes the samples.
        drop_prob: Probability of dropping a sample. `0.0` and `None` disable dropping.
        training: Dropping is only applied when this is `True`.

    Returns:
        `input` itself when nothing is dropped. Otherwise a new tensor in which each sample is either zeroed or
        scaled by `1 / (1 - drop_prob)`. When `drop_prob >= 1` every sample is dropped and the result is all zeros.
    """
    # `not drop_prob` covers 0.0 as well as None, which is the default of `EfficientFormerDropPath`.
    if not drop_prob or not training:
        return input
    keep_prob = 1 - drop_prob
    if keep_prob <= 0:
        # Nothing survives; dividing by a non-positive keep probability would produce NaN or garbage.
        return torch.zeros_like(input)
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
    return input.div(keep_prob) * random_tensor
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
class EfficientFormerDropPath(nn.Module):
    """Module wrapper around `drop_path`: Stochastic Depth per sample, for the main path of residual blocks."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply `drop_path` with the stored probability and the module's current training flag."""
        return drop_path(hidden_states, drop_prob=self.drop_prob, training=self.training)

    def extra_repr(self) -> str:
        """Show the drop probability in the module's `repr`."""
        return f"p={self.drop_prob}"
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
class EfficientFormerFlat(nn.Module):
    """Flatten all dimensions from index 2 onwards into one and swap it with dimension 1."""

    def __init__(self):
        super().__init__()

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]:
        """Turn e.g. `(batch, channels, height, width)` into `(batch, height * width, channels)`."""
        flattened = hidden_states.flatten(start_dim=2)
        return flattened.transpose(1, 2)
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
class EfficientFormerMeta3D(nn.Module):
    """
    Transformer-style block operating on sequences: self-attention followed by a dense MLP.

    Both sub-layers are pre-normalized with layer normalization and wrapped in a residual connection with optional
    stochastic depth and optional learnable per-channel layer scale.

    Args:
        config: Model configuration.
        dim: Size of the last dimension of the hidden states processed by this block.
        drop_path: Stochastic depth probability; no drop path module is created when it is not positive.
    """

    def __init__(self, config: EfficientFormerConfig, dim: int, drop_path: float = 0.0):
        super().__init__()

        # The attention consumes the output of `layernorm1`, whose width is `dim`, so it is sized from `dim` too.
        # Sizing it from `config.dim` only works when both values coincide.
        self.token_mixer = EfficientFormerSelfAttention(
            dim=dim,
            key_dim=config.key_dim,
            num_heads=config.num_attention_heads,
            attention_ratio=config.attention_ratio,
            resolution=config.resolution,
        )

        self.layernorm1 = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.layernorm2 = nn.LayerNorm(dim, eps=config.layer_norm_eps)

        mlp_hidden_dim = int(dim * config.mlp_expansion_ratio)
        self.mlp = EfficientFormerDenseMlp(config, in_features=dim, hidden_features=mlp_hidden_dim)

        self.drop_path = EfficientFormerDropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.use_layer_scale = config.use_layer_scale
        if config.use_layer_scale:
            self.layer_scale_1 = nn.Parameter(config.layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            self.layer_scale_2 = nn.Parameter(config.layer_scale_init_value * torch.ones((dim)), requires_grad=True)

    def forward(self, hidden_states: torch.Tensor, output_attentions: bool = False) -> Tuple[torch.Tensor]:
        """
        Returns:
            `(layer_output,)`, followed by the attention probabilities when `output_attentions` is set.
        """
        self_attention_outputs = self.token_mixer(self.layernorm1(hidden_states), output_attentions)
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        if self.use_layer_scale:
            # The scales are 1-D over channels; two leading axes make them broadcast over batch and sequence.
            layer_output = hidden_states + self.drop_path(
                self.layer_scale_1.unsqueeze(0).unsqueeze(0) * attention_output
            )
            layer_output = layer_output + self.drop_path(
                self.layer_scale_2.unsqueeze(0).unsqueeze(0) * self.mlp(self.layernorm2(layer_output))
            )
        else:
            layer_output = hidden_states + self.drop_path(attention_output)
            layer_output = layer_output + self.drop_path(self.mlp(self.layernorm2(layer_output)))

        outputs = (layer_output,) + outputs

        return outputs
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
class EfficientFormerMeta3DLayers(nn.Module):
    """Stack of `config.num_meta3d_blocks` `EfficientFormerMeta3D` blocks of width `config.hidden_sizes[-1]`."""

    def __init__(self, config: EfficientFormerConfig):
        super().__init__()
        # NOTE(review): the rate is multiplied by an absolute block index and is not divided by the total depth,
        # so values above 1 are possible whenever `config.drop_path_rate` is non-zero — confirm this is intended.
        drop_paths = [
            config.drop_path_rate * (block_idx + sum(config.depths[:-1]))
            for block_idx in range(config.num_meta3d_blocks)
        ]
        self.blocks = nn.ModuleList(
            [EfficientFormerMeta3D(config, config.hidden_sizes[-1], drop_path=drop_path) for drop_path in drop_paths]
        )

    def forward(self, hidden_states: torch.Tensor, output_attentions: bool = False) -> Tuple[torch.Tensor]:
        """
        Run all blocks in sequence.

        Returns:
            Always a tuple: `(hidden_states,)`, followed by one attention tensor per block when `output_attentions`
            is set. This also holds when the stack contains no blocks, so callers can index the result uniformly.
        """
        # Accept a block-style tuple as input as well as a plain tensor.
        if isinstance(hidden_states, tuple):
            hidden_states = hidden_states[0]

        all_attention_outputs = ()

        for layer_module in self.blocks:
            layer_outputs = layer_module(hidden_states, output_attentions)
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attention_outputs = all_attention_outputs + (layer_outputs[1],)

        return (hidden_states,) + all_attention_outputs
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
class EfficientFormerMeta4D(nn.Module):
    """
    Convolutional block operating on feature maps: pooling token mixer followed by a convolutional MLP.

    Both sub-layers are wrapped in a residual connection with optional stochastic depth and optional learnable
    per-channel layer scale.

    Args:
        config: Model configuration; `config.pool_size` falls back to 3 when it is `None`.
        dim: Number of channels of the feature maps processed by this block.
        drop_path: Stochastic depth probability; no drop path module is created when it is not positive.
    """

    def __init__(self, config: EfficientFormerConfig, dim: int, drop_path: float = 0.0):
        super().__init__()
        pool_size = 3 if config.pool_size is None else config.pool_size
        self.token_mixer = EfficientFormerPooling(pool_size=pool_size)
        mlp_hidden_dim = int(dim * config.mlp_expansion_ratio)
        self.mlp = EfficientFormerConvMlp(
            config, in_features=dim, hidden_features=mlp_hidden_dim, drop=config.hidden_dropout_prob
        )

        if drop_path > 0.0:
            self.drop_path = EfficientFormerDropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.use_layer_scale = config.use_layer_scale
        if config.use_layer_scale:
            self.layer_scale_1 = nn.Parameter(config.layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            self.layer_scale_2 = nn.Parameter(config.layer_scale_init_value * torch.ones((dim)), requires_grad=True)

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]:
        """Apply the token mixer and the MLP, each with a residual connection, and return the resulting tensor."""
        mixed = self.token_mixer(hidden_states)

        if not self.use_layer_scale:
            layer_output = hidden_states + self.drop_path(mixed)
            return layer_output + self.drop_path(self.mlp(layer_output))

        # The scales are 1-D over channels; two trailing axes make them broadcast over height and width.
        scale_1 = self.layer_scale_1[:, None, None]
        layer_output = hidden_states + self.drop_path(scale_1 * mixed)

        scale_2 = self.layer_scale_2[:, None, None]
        layer_output = layer_output + self.drop_path(scale_2 * self.mlp(layer_output))

        return layer_output
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
class EfficientFormerMeta4DLayers(nn.Module):
    """
    Stack of `EfficientFormerMeta4D` blocks for one stage.

    The stage holds `config.depths[stage_idx]` blocks. For `stage_idx == -1` that count is reduced by
    `config.num_meta3d_blocks`.
    """

    def __init__(self, config: EfficientFormerConfig, stage_idx: int):
        super().__init__()
        num_layers = config.depths[stage_idx]
        if stage_idx == -1:
            num_layers = num_layers - config.num_meta3d_blocks

        # NOTE(review): the rate is multiplied by an absolute block index and is not divided by the total depth —
        # confirm this is intended.
        drop_paths = []
        for block_idx in range(num_layers):
            drop_paths.append(config.drop_path_rate * (block_idx + sum(config.depths[:stage_idx])))

        stage_dim = config.hidden_sizes[stage_idx]
        self.blocks = nn.ModuleList(
            [EfficientFormerMeta4D(config, stage_dim, drop_path=rate) for rate in drop_paths]
        )

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]:
        """Feed `hidden_states` through every block in order and return the result."""
        for block in self.blocks:
            hidden_states = block(hidden_states)
        return hidden_states
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
class EfficientFormerIntermediateStage(nn.Module):
    """Encoder stage made only of Meta4D blocks; `index` selects the stage in the configuration."""

    def __init__(self, config: EfficientFormerConfig, index: int):
        super().__init__()
        self.meta4D_layers = EfficientFormerMeta4DLayers(config, index)

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]:
        """Return the output of the stage's Meta4D layers."""
        return self.meta4D_layers(hidden_states)
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
class EfficientFormerLastStage(nn.Module):
    """Final encoder stage: Meta4D blocks, a flattening step, then Meta3D (attention) blocks."""

    def __init__(self, config: EfficientFormerConfig):
        super().__init__()
        # -1 selects the last entry of the per-stage configuration lists.
        self.meta4D_layers = EfficientFormerMeta4DLayers(config, -1)
        self.flat = EfficientFormerFlat()
        self.meta3D_layers = EfficientFormerMeta3DLayers(config)

    def forward(self, hidden_states: torch.Tensor, output_attentions: bool = False) -> Tuple[torch.Tensor]:
        """Run the three sub-modules in order; `output_attentions` is forwarded to the Meta3D layers."""
        feature_map = self.meta4D_layers(hidden_states)
        tokens = self.flat(feature_map)
        return self.meta3D_layers(tokens, output_attentions)
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
class EfficientFormerEncoder(nn.Module):
    """
    EfficientFormer encoder: a list of intermediate stages followed by the last stage.

    There is one intermediate stage for every entry of `config.depths` except the last. A downsampling
    `EfficientFormerPatchEmbeddings` layer is inserted after an intermediate stage when `config.downsamples` asks
    for it or when the hidden size changes between that stage and the next one.
    """

    def __init__(self, config: EfficientFormerConfig):
        super().__init__()
        self.config = config
        stage_count = len(config.depths) - 1
        needs_downsample = [
            config.downsamples[stage] or config.hidden_sizes[stage] != config.hidden_sizes[stage + 1]
            for stage in range(stage_count)
        ]

        stages = []
        for stage, downsample in enumerate(needs_downsample):
            stages.append(EfficientFormerIntermediateStage(config, stage))
            if downsample:
                stages.append(
                    EfficientFormerPatchEmbeddings(config, config.hidden_sizes[stage], config.hidden_sizes[stage + 1])
                )

        self.intermediate_stages = nn.ModuleList(stages)
        self.last_stage = EfficientFormerLastStage(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
        return_dict: bool = True,
    ) -> BaseModelOutput:
        """
        Args:
            hidden_states: Input of the first stage.
            output_hidden_states: Collect the input, the output of every intermediate module and the final output.
            output_attentions: Collect the attention tensors returned by the last stage.
            return_dict: Return a `BaseModelOutput` instead of a plain tuple.

        Returns:
            A `BaseModelOutput`, or a tuple of the non-`None` values among
            `(last_hidden_state, hidden_states, attentions)` when `return_dict` is `False`.
        """
        # The encoder input is the first recorded hidden state.
        all_hidden_states = (hidden_states,) if output_hidden_states else None

        for stage_module in self.intermediate_stages:
            hidden_states = stage_module(hidden_states)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        layer_output = self.last_stage(hidden_states, output_attentions=output_attentions)

        all_self_attentions = None
        if output_attentions:
            all_self_attentions = () + layer_output[1:]

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (layer_output[0],)

        if return_dict:
            return BaseModelOutput(
                last_hidden_state=layer_output[0],
                hidden_states=all_hidden_states,
                attentions=all_self_attentions,
            )

        candidates = [layer_output[0], all_hidden_states, all_self_attentions]
        return tuple(value for value in candidates if value is not None)
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
class EfficientFormerPreTrainedModel(PreTrainedModel):
    """
    Abstract base class that takes care of weight initialization and offers a simple interface for downloading and
    loading pretrained models.
    """

    config_class = EfficientFormerConfig
    base_model_prefix = "efficientformer"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = False

    def _init_weights(self, module: nn.Module):
        """Initialize the weights of a single sub-module."""
        if isinstance(module, nn.LayerNorm):
            # Layer norms start out as the identity transform: zero shift, unit scale.
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return

        if not isinstance(module, (nn.Linear, nn.Conv2d)):
            # Every other module type is left untouched.
            return

        # Linear and convolution weights are drawn from N(0, initializer_range); biases, when present, are zeroed.
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if module.bias is not None:
            module.bias.data.zero_()
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
EFFICIENTFORMER_START_DOCSTRING = r"""
|
| 520 |
+
This model is a PyTorch [nn.Module](https://pytorch.org/docs/stable/nn.html#nn.Module) subclass. Use it as a
|
| 521 |
+
regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.
|
| 522 |
+
|
| 523 |
+
Parameters:
|
| 524 |
+
config ([`EfficientFormerConfig`]): Model configuration class with all the parameters of the model.
|
| 525 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 526 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 527 |
+
"""
|
| 528 |
+
|
| 529 |
+
EFFICIENTFORMER_INPUTS_DOCSTRING = r"""
|
| 530 |
+
Args:
|
| 531 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
| 532 |
+
Pixel values. Pixel values can be obtained using [`ViTImageProcessor`]. See
|
| 533 |
+
[`ViTImageProcessor.preprocess`] for details.
|
| 534 |
+
output_attentions (`bool`, *optional*):
|
| 535 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 536 |
+
tensors for more detail.
|
| 537 |
+
output_hidden_states (`bool`, *optional*):
|
| 538 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 539 |
+
more detail.
|
| 540 |
+
return_dict (`bool`, *optional*):
|
| 541 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 542 |
+
"""
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
@add_start_docstrings(
    "The bare EfficientFormer Model transformer outputting raw hidden-states without any specific head on top.",
    EFFICIENTFORMER_START_DOCSTRING,
)
class EfficientFormerModel(EfficientFormerPreTrainedModel):
    # Bare backbone: convolutional stem -> encoder -> final layer norm.

    # Declared on the class rather than as a local inside `__init__`, where the assignment was dead code and the
    # attribute never existed. NOTE(review): presumably consumed by `PreTrainedModel` when splitting a model across
    # devices - confirm against the installed transformers version.
    _no_split_modules = ["EfficientFormerMeta4D"]

    def __init__(self, config: EfficientFormerConfig):
        super().__init__(config)
        self.config = config

        self.patch_embed = EfficientFormerConvStem(config, config.hidden_sizes[0])
        self.encoder = EfficientFormerEncoder(config)
        self.layernorm = nn.LayerNorm(config.hidden_sizes[-1], eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(EFFICIENTFORMER_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        # `forward` builds a `BaseModelOutput` (no pooled output), so that is the type documented here.
        output_type=BaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutput]:
        # Explicit arguments win; anything left as `None` falls back to the model configuration.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedding_output = self.patch_embed(pixel_values)
        # `return_dict` is deliberately not forwarded: the encoder then uses its own default, and its result is read
        # both by index and by attribute below.
        encoder_outputs = self.encoder(
            embedding_output, output_attentions=output_attentions, output_hidden_states=output_hidden_states
        )

        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)

        if not return_dict:
            # Tuple form: normalized last hidden state followed by whatever else the encoder returned.
            head_outputs = (sequence_output,)
            return head_outputs + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
|
| 603 |
+
|
| 604 |
+
|
| 605 |
+
@add_start_docstrings(
    """
    EfficientFormer Model transformer with an image classification head on top (a linear layer on top of the final
    hidden state of the [CLS] token) e.g. for ImageNet.
    """,
    EFFICIENTFORMER_START_DOCSTRING,
)
class EfficientFormerForImageClassification(EfficientFormerPreTrainedModel):
    def __init__(self, config: EfficientFormerConfig):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.efficientformer = EfficientFormerModel(config)

        # Classifier head: a linear projection, or a pass-through when the config asks for no labels.
        if config.num_labels > 0:
            self.classifier = nn.Linear(config.hidden_sizes[-1], config.num_labels)
        else:
            self.classifier = nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(EFFICIENTFORMER_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.efficientformer(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # The backbone output is averaged over its second-to-last axis before being classified.
        pooled_output = outputs[0].mean(-2)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            problem_type = self.config.problem_type
            if problem_type is None:
                # Infer the problem type once from the label count / dtype and remember it on the config.
                if self.num_labels == 1:
                    problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    problem_type = "single_label_classification"
                else:
                    problem_type = "multi_label_classification"
                self.config.problem_type = problem_type

            if problem_type == "regression":
                criterion = MSELoss()
                if self.num_labels == 1:
                    loss = criterion(logits.squeeze(), labels.squeeze())
                else:
                    loss = criterion(logits, labels)
            elif problem_type == "single_label_classification":
                criterion = CrossEntropyLoss()
                loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                criterion = BCEWithLogitsLoss()
                loss = criterion(logits, labels)

        if return_dict:
            return ImageClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        # Tuple form: the loss (when there is one) is placed in front of the logits and the extra backbone outputs.
        output = (logits,) + outputs[1:]
        return output if loss is None else (loss,) + output
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
@dataclass
class EfficientFormerForImageClassificationWithTeacherOutput(ModelOutput):
    """
    Container for the outputs of [`EfficientFormerForImageClassificationWithTeacher`].

    Args:
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Final prediction scores, i.e. the average of `cls_logits` and `distillation_logits`.
        cls_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Prediction scores produced by the classification head.
        distillation_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Prediction scores produced by the distillation head.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    """

    logits: Optional[torch.FloatTensor] = None
    cls_logits: Optional[torch.FloatTensor] = None
    distillation_logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 725 |
+
|
| 726 |
+
|
| 727 |
+
@add_start_docstrings(
    """
    EfficientFormer Model transformer with image classification heads on top (a linear layer on top of the final hidden
    state of the [CLS] token and a linear layer on top of the final hidden state of the distillation token) e.g. for
    ImageNet.

    <Tip warning={true}>

           This model supports inference-only. Fine-tuning with distillation (i.e. with a teacher) is not yet
           supported.

    </Tip>
    """,
    EFFICIENTFORMER_START_DOCSTRING,
)
class EfficientFormerForImageClassificationWithTeacher(EfficientFormerPreTrainedModel):
    def __init__(self, config: EfficientFormerConfig):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.efficientformer = EfficientFormerModel(config)

        # Both heads consume the backbone output, whose feature size is `config.hidden_sizes[-1]` (the size the
        # backbone's final layer norm is built with). Sizing them with `config.hidden_size` was inconsistent with
        # `EfficientFormerForImageClassification` and only works when that extra attribute exists and happens to
        # match.
        head_input_size = config.hidden_sizes[-1]

        # Classifier head
        self.classifier = nn.Linear(head_input_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
        # Distillation head
        self.distillation_classifier = (
            nn.Linear(head_input_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(EFFICIENTFORMER_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=EfficientFormerForImageClassificationWithTeacherOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, EfficientFormerForImageClassificationWithTeacherOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.efficientformer(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # Both heads see the same input: the backbone output averaged over its second-to-last axis.
        pooled_output = sequence_output.mean(-2)
        cls_logits = self.classifier(pooled_output)
        distillation_logits = self.distillation_classifier(pooled_output)

        # during inference, return the average of both classifier predictions
        logits = (cls_logits + distillation_logits) / 2

        if not return_dict:
            # Tuple form: the three logit tensors followed by the extra backbone outputs.
            output = (logits, cls_logits, distillation_logits) + outputs[1:]
            return output

        return EfficientFormerForImageClassificationWithTeacherOutput(
            logits=logits,
            cls_logits=cls_logits,
            distillation_logits=distillation_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
| 800 |
+
|
| 801 |
+
|
| 802 |
+
__all__ = [
|
| 803 |
+
"EfficientFormerForImageClassification",
|
| 804 |
+
"EfficientFormerForImageClassificationWithTeacher",
|
| 805 |
+
"EfficientFormerModel",
|
| 806 |
+
"EfficientFormerPreTrainedModel",
|
| 807 |
+
]
|
docs/transformers/build/lib/transformers/models/deprecated/efficientformer/modeling_tf_efficientformer.py
ADDED
|
@@ -0,0 +1,1198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 Snapchat Research and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""TensorFlow EfficientFormer model."""
|
| 16 |
+
|
| 17 |
+
import itertools
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import tensorflow as tf
|
| 22 |
+
|
| 23 |
+
from ....activations_tf import ACT2FN
|
| 24 |
+
from ....modeling_tf_outputs import (
|
| 25 |
+
TFBaseModelOutput,
|
| 26 |
+
TFBaseModelOutputWithPooling,
|
| 27 |
+
TFImageClassifierOutput,
|
| 28 |
+
)
|
| 29 |
+
from ....modeling_tf_utils import (
|
| 30 |
+
TFPreTrainedModel,
|
| 31 |
+
TFSequenceClassificationLoss,
|
| 32 |
+
get_initializer,
|
| 33 |
+
keras,
|
| 34 |
+
keras_serializable,
|
| 35 |
+
unpack_inputs,
|
| 36 |
+
)
|
| 37 |
+
from ....tf_utils import shape_list, stable_softmax
|
| 38 |
+
from ....utils import (
|
| 39 |
+
ModelOutput,
|
| 40 |
+
add_code_sample_docstrings,
|
| 41 |
+
add_start_docstrings,
|
| 42 |
+
add_start_docstrings_to_model_forward,
|
| 43 |
+
logging,
|
| 44 |
+
)
|
| 45 |
+
from .configuration_efficientformer import EfficientFormerConfig
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
logger = logging.get_logger(__name__)
|
| 49 |
+
|
| 50 |
+
# General docstring
|
| 51 |
+
_CONFIG_FOR_DOC = "EfficientFormerConfig"
|
| 52 |
+
|
| 53 |
+
# Base docstring
|
| 54 |
+
_CHECKPOINT_FOR_DOC = "snap-research/efficientformer-l1-300"
|
| 55 |
+
_EXPECTED_OUTPUT_SHAPE = [1, 49, 448]
|
| 56 |
+
|
| 57 |
+
# Image classification docstring
|
| 58 |
+
_IMAGE_CLASS_CHECKPOINT = "snap-research/efficientformer-l1-300"
|
| 59 |
+
_IMAGE_CLASS_EXPECTED_OUTPUT = "LABEL_281"
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class TFEfficientFormerPatchEmbeddings(keras.layers.Layer):
    """
    Downsampling layer used between two stages: zero padding, a strided convolution and an optional batch
    normalization. The input is expected channels-last, i.e. with shape [..., height, width, num_channels] (this is
    what the shape assertion in `call` checks); the output has `embed_dim` channels.

    Args:
        config: Provides `downsample_pad`, `downsample_patch_size`, `downsample_stride` and `batch_norm_eps`.
        num_channels: Size of the last (channel) axis of the input.
        embed_dim: Number of filters of the projection, i.e. number of output channels.
        apply_norm: When false the batch normalization is replaced by `tf.identity`.
    """

    def __init__(
        self, config: EfficientFormerConfig, num_channels: int, embed_dim: int, apply_norm: bool = True, **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.num_channels = num_channels

        self.padding = keras.layers.ZeroPadding2D(padding=config.downsample_pad)
        self.projection = keras.layers.Conv2D(
            filters=embed_dim,
            kernel_size=config.downsample_patch_size,
            strides=config.downsample_stride,
            padding="valid",
            name="projection",
        )
        # Use same default momentum and epsilon as PyTorch equivalent for BatchNormalization
        self.norm = (
            keras.layers.BatchNormalization(axis=-1, epsilon=config.batch_norm_eps, momentum=0.9, name="norm")
            if apply_norm
            else tf.identity
        )
        self.embed_dim = embed_dim

    def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Pad, project and (optionally) normalize `pixel_values`."""
        tf.debugging.assert_shapes(
            [(pixel_values, (..., None, None, self.num_channels))],
            message="Make sure that the channel dimension of the pixel values match with the one set in the configuration.",
        )
        embeddings = self.projection(self.padding(pixel_values))
        if isinstance(self.norm, keras.layers.Layer):
            embeddings = self.norm(embeddings, training=training)
        else:
            # `apply_norm=False` stores plain `tf.identity`, which has no `training` argument; passing it used to
            # raise a TypeError.
            embeddings = self.norm(embeddings)
        return embeddings

    def build(self, input_shape=None):
        """Build the sub-layers under their own name scopes; does nothing when already built."""
        if self.built:
            return
        self.built = True
        if getattr(self, "projection", None) is not None:
            with tf.name_scope(self.projection.name):
                self.projection.build([None, None, None, self.num_channels])
        # Only a real normalization layer has weights to build; the `tf.identity` fallback is skipped.
        if isinstance(getattr(self, "norm", None), keras.layers.Layer):
            with tf.name_scope(self.norm.name):
                self.norm.build([None, None, None, self.embed_dim])
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class TFEfficientFormerSelfAttention(keras.layers.Layer):
    """
    Multi-head self-attention with a learned bias per head and per absolute 2D offset between two positions of a
    `resolution` x `resolution` grid.

    Args:
        dim: Input feature size; also the output size of the final projection.
        key_dim: Per-head size of queries and keys.
        num_heads: Number of attention heads.
        attention_ratio: Per-head value size expressed as a multiple of `key_dim`.
        resolution: Side length of the grid used to enumerate positions for the attention biases.
        config: Provides `initializer_range` for the dense kernels.
    """

    def __init__(
        self,
        dim: int,
        key_dim: int,
        num_heads: int,
        attention_ratio: int,
        resolution: int,
        config: EfficientFormerConfig,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.num_heads = num_heads
        self.key_dim = key_dim
        self.attention_ratio = attention_ratio
        self.scale = key_dim**-0.5
        self.total_key_dim = key_dim * num_heads
        self.expanded_key_dim = int(attention_ratio * key_dim)
        self.total_expanded_key_dim = int(self.expanded_key_dim * num_heads)
        # A single dense layer emits queries, keys (key_dim each) and values (expanded_key_dim) for all heads.
        hidden_size = self.total_expanded_key_dim + self.total_key_dim * 2

        self.qkv = keras.layers.Dense(
            units=hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="qkv"
        )
        self.projection = keras.layers.Dense(
            units=dim, kernel_initializer=get_initializer(config.initializer_range), name="projection"
        )
        self.resolution = resolution
        self.dim = dim

    def build(self, input_shape: tf.TensorShape) -> None:
        """Create the attention-bias weights and build the dense sub-layers; does nothing when already built."""
        # Guard first: the weights below used to be created before this check, so a second `build` call would have
        # tried to add them again.
        if self.built:
            return
        self.built = True

        points = list(itertools.product(range(self.resolution), range(self.resolution)))
        num_points = len(points)
        attention_offsets = {}

        # For every ordered pair of grid points, the index of their absolute (row, column) offset; offsets are
        # numbered in order of first appearance.
        idxs = []
        for point_1 in points:
            for point_2 in points:
                offset = (abs(point_1[0] - point_2[0]), abs(point_1[1] - point_2[1]))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])

        self.attention_biases = self.add_weight(
            shape=(self.num_heads, len(attention_offsets)),
            initializer=keras.initializers.zeros(),
            trainable=True,
            name="attention_biases",
        )
        self.attention_bias_idxs = self.add_weight(
            shape=(num_points, num_points),
            trainable=False,
            dtype=tf.int32,
            name="attention_bias_idxs",
        )

        self.attention_bias_idxs.assign(tf.reshape(tf.cast(idxs, dtype=tf.int32), (num_points, num_points)))

        if getattr(self, "qkv", None) is not None:
            with tf.name_scope(self.qkv.name):
                self.qkv.build([None, None, self.dim])
        if getattr(self, "projection", None) is not None:
            with tf.name_scope(self.projection.name):
                self.projection.build([None, None, self.total_expanded_key_dim])

    def call(
        self, hidden_states: tf.Tensor, output_attentions: bool = False, training: bool = False
    ) -> Tuple[tf.Tensor]:
        """
        Apply self-attention to `hidden_states`.

        Returns:
            `(context_layer,)`, or `(context_layer, attention_probs)` when `output_attentions` is true.
        """
        batch_size, sequence_length, *_ = shape_list(hidden_states)
        qkv = self.qkv(inputs=hidden_states)

        # Split the fused projection per head into queries, keys and values along the last axis.
        query_layer, key_layer, value_layer = tf.split(
            tf.reshape(tensor=qkv, shape=(batch_size, sequence_length, self.num_heads, -1)),
            num_or_size_splits=[self.key_dim, self.key_dim, self.expanded_key_dim],
            axis=3,
        )

        # Move the head axis in front of the sequence axis.
        query_layer = tf.transpose(query_layer, perm=[0, 2, 1, 3])
        key_layer = tf.transpose(key_layer, perm=[0, 2, 1, 3])
        value_layer = tf.transpose(value_layer, perm=[0, 2, 1, 3])

        # Scaled dot-product scores; the variable only holds probabilities after the softmax below.
        attention_probs = tf.matmul(query_layer, tf.transpose(key_layer, perm=[0, 1, 3, 2]))
        scale = tf.cast(self.scale, dtype=attention_probs.dtype)
        attention_probs = tf.multiply(attention_probs, scale)

        # Look up each head's learned bias for every pair of positions and add it to the scores.
        attention_biases = tf.gather(params=self.attention_biases, indices=self.attention_bias_idxs, axis=1)
        attention_probs = attention_probs + attention_biases
        attention_probs = stable_softmax(logits=attention_probs, axis=-1)

        context_layer = tf.matmul(attention_probs, value_layer)
        context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])

        # Merge the heads back into one feature axis before the output projection.
        context_layer = tf.reshape(
            tensor=context_layer, shape=(batch_size, sequence_length, self.total_expanded_key_dim)
        )
        context_layer = self.projection(context_layer)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
class TFEfficientFormerConvStem(keras.layers.Layer):
    """Convolutional stem: two padded 3x3 stride-2 convolutions, each followed by batch norm and ReLU.

    The first convolution produces `out_channels // 2` filters and the second `out_channels`.
    """

    def __init__(self, config: EfficientFormerConfig, out_channels: int, **kwargs):
        super().__init__(**kwargs)

        # One zero-padding layer is shared by both convolutions.
        self.padding = keras.layers.ZeroPadding2D(padding=1)

        self.convolution1 = keras.layers.Conv2D(
            filters=out_channels // 2,
            kernel_size=3,
            strides=2,
            padding="valid",
            name="convolution1",
        )
        # Use same default momentum and epsilon as PyTorch equivalent for BatchNormalization
        self.batchnorm_before = keras.layers.BatchNormalization(
            axis=-1,
            epsilon=config.batch_norm_eps,
            momentum=0.9,
            name="batchnorm_before",
        )

        self.convolution2 = keras.layers.Conv2D(
            filters=out_channels, kernel_size=3, strides=2, padding="valid", name="convolution2"
        )
        # Use same default momentum and epsilon as PyTorch equivalent for BatchNormalization
        self.batchnorm_after = keras.layers.BatchNormalization(
            axis=-1,
            epsilon=config.batch_norm_eps,
            momentum=0.9,
            name="batchnorm_after",
        )

        self.activation = keras.layers.Activation(activation=keras.activations.relu, name="activation")
        self.out_channels = out_channels
        self.config = config

    def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Apply pad -> conv -> batch norm -> ReLU twice and return the resulting features."""
        features = self.padding(pixel_values)
        features = self.convolution1(features)
        features = self.batchnorm_before(features, training=training)
        features = self.activation(features)

        features = self.padding(features)
        features = self.convolution2(features)
        features = self.batchnorm_after(features, training=training)
        return self.activation(features)

    def build(self, input_shape=None):
        """Build every sub-layer once, each inside a name scope matching its own name."""
        if self.built:
            return
        self.built = True

        half_channels = self.out_channels // 2
        # (attribute name, shape handed to that sub-layer's build), in the original build order.
        build_plan = (
            ("convolution1", [None, None, None, self.config.num_channels]),
            ("batchnorm_before", [None, None, None, half_channels]),
            ("convolution2", [None, None, None, half_channels]),
            ("batchnorm_after", [None, None, None, self.out_channels]),
            ("activation", None),
        )
        for attribute, shape in build_plan:
            sublayer = getattr(self, attribute, None)
            if sublayer is None:
                continue
            with tf.name_scope(sublayer.name):
                sublayer.build(shape)
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
class TFEfficientFormerPooling(keras.layers.Layer):
    """Token mixer that returns the stride-1 "same"-padded average pool of its input minus the input."""

    def __init__(self, pool_size: int, **kwargs):
        super().__init__(**kwargs)
        self.pool = keras.layers.AveragePooling2D(pool_size=pool_size, strides=1, padding="same")

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        """Return `pool(hidden_states) - hidden_states`."""
        pooled = self.pool(hidden_states)
        return pooled - hidden_states
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
class TFEfficientFormerDenseMlp(keras.layers.Layer):
    """Two-layer dense MLP: linear -> activation -> dropout -> linear -> dropout.

    `hidden_features` and `out_features` fall back to `in_features` when they are falsy.
    """

    def __init__(
        self,
        config: EfficientFormerConfig,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        self.linear_in = keras.layers.Dense(
            units=hidden_features,
            kernel_initializer=get_initializer(config.initializer_range),
            name="linear_in",
        )
        self.activation = ACT2FN[config.hidden_act]
        # The same dropout layer is applied after the activation and after the output projection.
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)

        self.linear_out = keras.layers.Dense(
            units=out_features,
            kernel_initializer=get_initializer(config.initializer_range),
            name="linear_out",
        )
        self.hidden_features = hidden_features
        self.in_features = in_features

    def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Run the MLP on `hidden_states`; `training` is forwarded to both dropout applications."""
        expanded = self.activation(self.linear_in(inputs=hidden_states))
        expanded = self.dropout(inputs=expanded, training=training)

        projected = self.linear_out(inputs=expanded)
        return self.dropout(inputs=projected, training=training)

    def build(self, input_shape=None):
        """Build both dense layers once, each under a name scope matching its own name."""
        if self.built:
            return
        self.built = True

        for attribute, feature_count in (
            ("linear_in", self.in_features),
            ("linear_out", self.hidden_features),
        ):
            dense = getattr(self, attribute, None)
            if dense is None:
                continue
            with tf.name_scope(dense.name):
                dense.build([None, None, feature_count])
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
class TFEfficientFormerConvMlp(keras.layers.Layer):
    """MLP built from 1x1 convolutions.

    Pipeline: conv -> batch norm -> activation -> dropout -> conv -> batch norm -> dropout.
    `hidden_features` and `out_features` fall back to `in_features` when they are falsy.
    """

    def __init__(
        self,
        config: EfficientFormerConfig,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        drop: float = 0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        self.convolution1 = keras.layers.Conv2D(
            filters=hidden_features, kernel_size=1, name="convolution1", padding="valid"
        )

        self.activation = ACT2FN[config.hidden_act]

        self.convolution2 = keras.layers.Conv2D(
            filters=out_features, kernel_size=1, name="convolution2", padding="valid"
        )

        # A single dropout layer is reused at both dropout positions in `call`.
        self.dropout = keras.layers.Dropout(rate=drop)

        # Use same default momentum and epsilon as PyTorch equivalent for BatchNormalization
        self.batchnorm_before = keras.layers.BatchNormalization(
            axis=-1,
            epsilon=config.batch_norm_eps,
            momentum=0.9,
            name="batchnorm_before",
        )
        # Use same default momentum and epsilon as PyTorch equivalent for BatchNormalization
        self.batchnorm_after = keras.layers.BatchNormalization(
            axis=-1,
            epsilon=config.batch_norm_eps,
            momentum=0.9,
            name="batchnorm_after",
        )
        self.hidden_features = hidden_features
        self.in_features = in_features
        self.out_features = out_features

    def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Run the convolutional MLP; `training` is forwarded to the batch norm and dropout layers."""
        expanded = self.convolution1(hidden_state)
        expanded = self.batchnorm_before(expanded, training=training)
        expanded = self.activation(expanded)
        expanded = self.dropout(expanded, training=training)

        projected = self.convolution2(expanded)
        projected = self.batchnorm_after(projected, training=training)
        return self.dropout(projected, training=training)

    def build(self, input_shape=None):
        """Build the convolutions and batch norms once, each under a name scope matching its name."""
        if self.built:
            return
        self.built = True

        # (attribute name, channel count of the last axis), in the original build order.
        build_plan = (
            ("convolution1", self.in_features),
            ("convolution2", self.hidden_features),
            ("batchnorm_before", self.hidden_features),
            ("batchnorm_after", self.out_features),
        )
        for attribute, channels in build_plan:
            sublayer = getattr(self, attribute, None)
            if sublayer is None:
                continue
            with tf.name_scope(sublayer.name):
                sublayer.build([None, None, None, channels])
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
# Copied from transformers.models.convnext.modeling_tf_convnext.TFConvNextDropPath with ConvNext->EfficientFormer
|
| 406 |
+
class TFEfficientFormerDropPath(keras.layers.Layer):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    When `training` is truthy, a random 0/1 mask with one entry per element of the first axis
    multiplies the input, which is first divided by the keep probability. Otherwise the input
    is returned untouched.

    References:
        (1) github.com:rwightman/pytorch-image-models
    """

    def __init__(self, drop_path: float, **kwargs):
        super().__init__(**kwargs)
        self.drop_path = drop_path

    def call(self, x: tf.Tensor, training=None):
        if not training:
            return x

        keep_prob = 1 - self.drop_path
        # One mask value per entry of axis 0, broadcast across every remaining axis.
        mask_shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
        # floor(keep_prob + U[0, 1)) evaluates to 1 with probability keep_prob and to 0 otherwise.
        keep_mask = tf.floor(keep_prob + tf.random.uniform(mask_shape, 0, 1))
        return (x / keep_prob) * keep_mask
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
class TFEfficientFormerFlat(keras.layers.Layer):
    """Collapse the two middle axes of a 4-D tensor into a single axis."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, hidden_states: tf.Tensor) -> Tuple[tf.Tensor]:
        """Reshape `hidden_states` to `(batch_size, -1, in_channels)` and return that tensor.

        The first and last entries of the input's shape are kept; everything between is merged.
        """
        input_shape = shape_list(hidden_states)
        batch_size, _, _, in_channels = input_shape
        return tf.reshape(hidden_states, shape=[batch_size, -1, in_channels])
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
class TFEfficientFormerMeta3D(keras.layers.Layer):
    """Attention block: layer norm -> self-attention and layer norm -> dense MLP, each with a residual.

    Both residual branches go through `drop_path`. When `config.use_layer_scale` is set, each
    branch is additionally multiplied by a learned vector of length `dim`.
    """

    def __init__(self, config: EfficientFormerConfig, dim: int, drop_path: float = 0.0, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.dim = dim

        self.token_mixer = TFEfficientFormerSelfAttention(
            dim=config.dim,
            key_dim=config.key_dim,
            num_heads=config.num_attention_heads,
            attention_ratio=config.attention_ratio,
            resolution=config.resolution,
            name="token_mixer",
            config=config,
        )

        self.layernorm1 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm1")
        self.layernorm2 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm2")
        mlp_hidden_dim = int(dim * config.mlp_expansion_ratio)
        self.mlp = TFEfficientFormerDenseMlp(config, in_features=dim, hidden_features=mlp_hidden_dim, name="mlp")

        # Using `layers.Activation` instead of `tf.identity` to better control `training` behavior.
        # Both variants carry the name "drop_path", as in TFEfficientFormerMeta4D.
        self.drop_path = (
            TFEfficientFormerDropPath(drop_path, name="drop_path")
            if drop_path > 0.0
            else keras.layers.Activation("linear", name="drop_path")
        )

    def build(self, input_shape=None):
        """Create the optional layer-scale weights and build the sub-layers exactly once."""
        # Check the guard before creating any weight, so a repeated `build` call cannot add
        # a second pair of layer-scale variables.
        if self.built:
            return
        self.built = True

        self.layer_scale_1 = None
        self.layer_scale_2 = None

        if self.config.use_layer_scale:
            self.layer_scale_1 = self.add_weight(
                shape=(self.dim,),
                initializer=keras.initializers.Constant(value=self.config.layer_scale_init_value),
                trainable=True,
                name="layer_scale_1",
            )
            self.layer_scale_2 = self.add_weight(
                shape=(self.dim,),
                initializer=keras.initializers.Constant(value=self.config.layer_scale_init_value),
                trainable=True,
                name="layer_scale_2",
            )

        if getattr(self, "token_mixer", None) is not None:
            with tf.name_scope(self.token_mixer.name):
                self.token_mixer.build(None)
        if getattr(self, "layernorm1", None) is not None:
            with tf.name_scope(self.layernorm1.name):
                self.layernorm1.build([None, None, self.dim])
        if getattr(self, "layernorm2", None) is not None:
            with tf.name_scope(self.layernorm2.name):
                self.layernorm2.build([None, None, self.dim])
        if getattr(self, "mlp", None) is not None:
            with tf.name_scope(self.mlp.name):
                self.mlp.build(None)
        if getattr(self, "drop_path", None) is not None:
            with tf.name_scope(self.drop_path.name):
                self.drop_path.build(None)

    def call(
        self, hidden_states: tf.Tensor, output_attentions: bool = False, training: bool = False
    ) -> Tuple[tf.Tensor]:
        """Apply the block.

        Args:
            hidden_states: Input tensor; the layer-scale vectors are broadcast against it after
                two leading axes are added.
            output_attentions: Forwarded to the self-attention layer; when set, its extra outputs
                are appended to the returned tuple.
            training: Forwarded to the layer norms, attention, MLP and drop-path layers.

        Returns:
            `(layer_output,)` followed by whatever the self-attention layer returned after its
            first element.
        """
        self_attention_outputs = self.token_mixer(
            hidden_states=self.layernorm1(hidden_states, training=training),
            output_attentions=output_attentions,
            training=training,
        )

        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        if self.config.use_layer_scale:
            layer_output = hidden_states + self.drop_path(
                tf.expand_dims(tf.expand_dims(self.layer_scale_1, 0), 0) * attention_output,
                training=training,
            )
            layer_output = layer_output + self.drop_path(
                tf.expand_dims(tf.expand_dims(self.layer_scale_2, 0), 0)
                * self.mlp(hidden_states=self.layernorm2(inputs=layer_output, training=training), training=training),
                training=training,
            )
        else:
            layer_output = hidden_states + self.drop_path(attention_output, training=training)
            layer_output = layer_output + self.drop_path(
                self.mlp(hidden_states=self.layernorm2(inputs=layer_output, training=training), training=training),
                training=training,
            )

        return (layer_output,) + outputs
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
class TFEfficientFormerMeta3DLayers(keras.layers.Layer):
    """Sequence of `config.num_meta3d_blocks` Meta3D blocks operating on `config.hidden_sizes[-1]` channels."""

    def __init__(self, config: EfficientFormerConfig, **kwargs):
        super().__init__(**kwargs)
        # Block indices are offset by the total depth of every stage except the last one.
        depth_offset = sum(config.depths[:-1])
        self.blocks = [
            TFEfficientFormerMeta3D(
                config,
                config.hidden_sizes[-1],
                drop_path=config.drop_path_rate * (block_idx + depth_offset),
                name=f"blocks.{block_idx}",
            )
            for block_idx in range(config.num_meta3d_blocks)
        ]

    def call(
        self, hidden_states: tf.Tensor, output_attentions: bool = False, training: bool = False
    ) -> Tuple[tf.Tensor]:
        """Run the blocks in order.

        Returns:
            The last block's output tuple when `output_attentions` is falsy; otherwise
            `(last_hidden_state, attention_of_block_0, attention_of_block_1, ...)`.
        """
        collected_attentions = () if output_attentions else None

        for block in self.blocks:
            # Each block returns a tuple; only its first element feeds the next block.
            if isinstance(hidden_states, tuple):
                hidden_states = hidden_states[0]

            hidden_states = block(
                hidden_states=hidden_states, output_attentions=output_attentions, training=training
            )
            if output_attentions:
                collected_attentions = collected_attentions + (hidden_states[1],)

        if not output_attentions:
            return hidden_states
        return (hidden_states[0],) + collected_attentions

    def build(self, input_shape=None):
        """Build every block once, each under a name scope matching its own name."""
        if self.built:
            return
        self.built = True

        blocks = getattr(self, "blocks", None)
        if blocks is None:
            return
        for block in blocks:
            with tf.name_scope(block.name):
                block.build(None)
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
class TFEfficientFormerMeta4D(keras.layers.Layer):
    """Pooling block: a pooling token mixer and a convolutional MLP, each added back as a residual.

    Both residual branches go through `drop_path`. When `config.use_layer_scale` is set, each
    branch is additionally multiplied by a learned vector of length `dim`.
    """

    def __init__(self, config: EfficientFormerConfig, dim: int, drop_path: float = 0.0, **kwargs):
        super().__init__(**kwargs)
        # Fall back to a pool size of 3 when the config leaves it unset.
        pool_size = config.pool_size if config.pool_size is not None else 3
        self.token_mixer = TFEfficientFormerPooling(pool_size=pool_size, name="token_mixer")
        self.dim = dim
        mlp_hidden_dim = int(dim * config.mlp_expansion_ratio)
        self.mlp = TFEfficientFormerConvMlp(
            config=config, in_features=dim, hidden_features=mlp_hidden_dim, drop=config.hidden_dropout_prob, name="mlp"
        )

        self.drop_path = (
            TFEfficientFormerDropPath(drop_path, name="drop_path")
            if drop_path > 0.0
            else keras.layers.Activation("linear", name="drop_path")
        )
        self.config = config

    def build(self, input_shape=None):
        """Create the optional layer-scale weights and build the sub-layers exactly once."""
        # Check the guard before creating any weight, so a repeated `build` call cannot add
        # a second pair of layer-scale variables.
        if self.built:
            return
        self.built = True

        self.layer_scale_1 = None
        self.layer_scale_2 = None

        if self.config.use_layer_scale:
            # `(self.dim,)` is a real 1-tuple; the previous `(self.dim)` was just an int.
            self.layer_scale_1 = self.add_weight(
                shape=(self.dim,),
                initializer=keras.initializers.Constant(value=self.config.layer_scale_init_value),
                trainable=True,
                name="layer_scale_1",
            )
            self.layer_scale_2 = self.add_weight(
                shape=(self.dim,),
                initializer=keras.initializers.Constant(value=self.config.layer_scale_init_value),
                trainable=True,
                name="layer_scale_2",
            )

        if getattr(self, "token_mixer", None) is not None:
            with tf.name_scope(self.token_mixer.name):
                self.token_mixer.build(None)
        if getattr(self, "mlp", None) is not None:
            with tf.name_scope(self.mlp.name):
                self.mlp.build(None)
        if getattr(self, "drop_path", None) is not None:
            with tf.name_scope(self.drop_path.name):
                self.drop_path.build(None)

    def call(self, hidden_states: tf.Tensor, training: bool = False) -> Tuple[tf.Tensor]:
        """Apply the pooling mixer and the MLP, each as a residual branch, and return the result.

        `training` is forwarded to the MLP and to the drop-path layer.
        """
        outputs = self.token_mixer(hidden_states)

        if self.config.use_layer_scale:
            layer_output = hidden_states + self.drop_path(
                tf.expand_dims(tf.expand_dims(self.layer_scale_1, 0), 0) * outputs,
                training=training,
            )

            layer_output = layer_output + self.drop_path(
                tf.expand_dims(tf.expand_dims(self.layer_scale_2, 0), 0)
                * self.mlp(hidden_state=layer_output, training=training),
                training=training,
            )

        else:
            layer_output = hidden_states + self.drop_path(outputs, training=training)
            layer_output = layer_output + self.drop_path(
                self.mlp(hidden_state=layer_output, training=training), training=training
            )

        return layer_output
|
| 650 |
+
|
| 651 |
+
|
| 652 |
+
class TFEfficientFormerMeta4DLayers(keras.layers.Layer):
    """Sequence of Meta4D blocks for one stage.

    A stage has `config.depths[stage_idx]` blocks, except that `stage_idx == -1` subtracts
    `config.num_meta3d_blocks` from that count.
    """

    def __init__(self, config: EfficientFormerConfig, stage_idx: int, **kwargs):
        super().__init__(**kwargs)
        num_layers = config.depths[stage_idx]
        if stage_idx == -1:
            num_layers = num_layers - config.num_meta3d_blocks

        # Block indices are offset by the summed depth of `config.depths[:stage_idx]`.
        depth_offset = sum(config.depths[:stage_idx])
        drop_paths = [config.drop_path_rate * (block_idx + depth_offset) for block_idx in range(num_layers)]

        self.blocks = [
            TFEfficientFormerMeta4D(
                config=config, dim=config.hidden_sizes[stage_idx], drop_path=rate, name=f"blocks.{block_idx}"
            )
            for block_idx, rate in enumerate(drop_paths)
        ]

    def call(self, hidden_states: tf.Tensor, training: bool = False) -> Tuple[tf.Tensor]:
        """Feed `hidden_states` through every block in order and return the final tensor."""
        for block in self.blocks:
            hidden_states = block(hidden_states=hidden_states, training=training)
        return hidden_states

    def build(self, input_shape=None):
        """Build every block once, each under a name scope matching its own name."""
        if self.built:
            return
        self.built = True

        blocks = getattr(self, "blocks", None)
        if blocks is None:
            return
        for block in blocks:
            with tf.name_scope(block.name):
                block.build(None)
|
| 682 |
+
|
| 683 |
+
|
| 684 |
+
class TFEfficientFormerIntermediateStage(keras.layers.Layer):
    """Encoder stage that wraps the Meta4D block stack for stage `index`."""

    def __init__(self, config: EfficientFormerConfig, index: int, **kwargs):
        super().__init__(**kwargs)
        self.meta4D_layers = TFEfficientFormerMeta4DLayers(config=config, stage_idx=index, name="meta4D_layers")

    def call(self, hidden_states: tf.Tensor, training: bool = False) -> Tuple[tf.Tensor]:
        """Return the output of the wrapped Meta4D stack."""
        return self.meta4D_layers(hidden_states=hidden_states, training=training)

    def build(self, input_shape=None):
        """Build the wrapped stack once, under a name scope matching its own name."""
        if self.built:
            return
        self.built = True

        stack = getattr(self, "meta4D_layers", None)
        if stack is not None:
            with tf.name_scope(stack.name):
                stack.build(None)
|
| 700 |
+
|
| 701 |
+
|
| 702 |
+
class TFEfficientFormerLastStage(keras.layers.Layer):
    """Final encoder stage: Meta4D blocks, a flattening step, then Meta3D (attention) blocks."""

    def __init__(self, config: EfficientFormerConfig, **kwargs):
        super().__init__(**kwargs)
        self.meta4D_layers = TFEfficientFormerMeta4DLayers(config=config, stage_idx=-1, name="meta4D_layers")
        self.flat = TFEfficientFormerFlat(name="flat")
        self.meta3D_layers = TFEfficientFormerMeta3DLayers(config, name="meta3D_layers")

    def call(
        self, hidden_states: tf.Tensor, output_attentions: bool = False, training: bool = False
    ) -> Tuple[tf.Tensor]:
        """Chain the three sub-layers and return whatever the Meta3D stack returns."""
        pooled = self.meta4D_layers(hidden_states=hidden_states, training=training)
        flattened = self.flat(hidden_states=pooled)
        return self.meta3D_layers(
            hidden_states=flattened, output_attentions=output_attentions, training=training
        )

    def build(self, input_shape=None):
        """Build the three sub-layers once, each under a name scope matching its own name."""
        if self.built:
            return
        self.built = True

        for attribute in ("meta4D_layers", "flat", "meta3D_layers"):
            sublayer = getattr(self, attribute, None)
            if sublayer is None:
                continue
            with tf.name_scope(sublayer.name):
                sublayer.build(None)
|
| 733 |
+
|
| 734 |
+
|
| 735 |
+
class TFEfficientFormerEncoder(keras.layers.Layer):
    """Encoder made of intermediate stages (optionally followed by patch embeddings) and a last stage.

    A patch-embedding layer is inserted after intermediate stage `i` when `config.downsamples[i]`
    is truthy or the hidden size changes between stage `i` and stage `i + 1`. Entries of
    `intermediate_stages` are named `intermediate_stages.<position in the list>`.
    """

    def __init__(self, config: EfficientFormerConfig, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        stage_count = len(config.depths) - 1

        stages = []
        for stage_idx in range(stage_count):
            # The name suffix is the position the new layer is about to take in the list.
            stages.append(
                TFEfficientFormerIntermediateStage(config, stage_idx, name=f"intermediate_stages.{len(stages)}")
            )
            needs_downsample = (
                config.downsamples[stage_idx]
                or config.hidden_sizes[stage_idx] != config.hidden_sizes[stage_idx + 1]
            )
            if needs_downsample:
                stages.append(
                    TFEfficientFormerPatchEmbeddings(
                        config,
                        config.hidden_sizes[stage_idx],
                        config.hidden_sizes[stage_idx + 1],
                        name=f"intermediate_stages.{len(stages)}",
                    )
                )
        self.intermediate_stages = stages
        self.last_stage = TFEfficientFormerLastStage(config, name="last_stage")

    def call(
        self,
        hidden_states: tf.Tensor,
        output_hidden_states: bool,
        output_attentions: bool,
        return_dict: bool,
        training: bool = False,
    ) -> TFBaseModelOutput:
        """Run every stage and assemble the outputs.

        Args:
            hidden_states: Input to the first intermediate stage.
            output_hidden_states: When set, collect the input, the output of every intermediate
                layer and the last stage's first output.
            output_attentions: When set, collect every element after the first one of the last
                stage's output.
            return_dict: When falsy, return a plain tuple with the `None` entries dropped.
            training: Forwarded to every stage.

        Returns:
            A `TFBaseModelOutput`, or a tuple when `return_dict` is falsy.
        """
        # The collected hidden states start with the encoder input itself.
        all_hidden_states = (hidden_states,) if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for stage in self.intermediate_stages:
            hidden_states = stage(hidden_states, training=training)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        layer_output = self.last_stage(hidden_states, output_attentions=output_attentions, training=training)
        last_hidden_state = layer_output[0]

        if output_attentions:
            all_self_attentions = all_self_attentions + layer_output[1:]

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (last_hidden_state,)

        if return_dict:
            return TFBaseModelOutput(
                last_hidden_state=last_hidden_state,
                hidden_states=all_hidden_states,
                attentions=all_self_attentions,
            )

        candidates = [last_hidden_state, all_hidden_states, all_self_attentions]
        return tuple(value for value in candidates if value is not None)

    def build(self, input_shape=None):
        """Build the last stage first, then every intermediate layer, each under its own name scope."""
        if self.built:
            return
        self.built = True

        last_stage = getattr(self, "last_stage", None)
        if last_stage is not None:
            with tf.name_scope(last_stage.name):
                last_stage.build(None)

        for stage in self.intermediate_stages:
            with tf.name_scope(stage.name):
                stage.build(None)
|
| 813 |
+
|
| 814 |
+
|
| 815 |
+
@keras_serializable
class TFEfficientFormerMainLayer(keras.layers.Layer):
    """Core EfficientFormer layer: convolutional stem, encoder and a final layer normalization."""

    config_class = EfficientFormerConfig

    def __init__(self, config: EfficientFormerConfig, **kwargs) -> None:
        super().__init__(**kwargs)
        self.config = config

        self.patch_embed = TFEfficientFormerConvStem(config, config.hidden_sizes[0], name="patch_embed")
        self.encoder = TFEfficientFormerEncoder(config, name="encoder")
        self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")

    @unpack_inputs
    def call(
        self,
        pixel_values: Optional[tf.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor, ...]]:
        """Embed `pixel_values`, run the encoder and normalise the final hidden state.

        Args:
            pixel_values: Input in `(batch_size, num_channels, height, width)` layout; it is
                transposed to channels-last before the stem.
            output_attentions: Falls back to `config.output_attentions` when `None`.
            output_hidden_states: Falls back to `config.output_hidden_states` when `None`.
            return_dict: Falls back to `config.use_return_dict` when `None`.
            training: Forwarded to the stem, the encoder and the final layer norm.

        Returns:
            A `TFBaseModelOutput`, or a tuple when `return_dict` is falsy. In both forms every
            returned hidden state except the last is transposed back to channels-first.

        Raises:
            ValueError: If `pixel_values` is `None`.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # When running on CPU, keras.layers.Conv2D and keras.layers.AveragePool2D do not
        # support channels first NCHW format. A number of blocks contain both.
        # So change the input format from (batch_size, num_channels, height, width) to
        # (batch_size, height, width, num_channels) here.
        # shape = (batch_size, in_height, in_width, in_channels=num_channels)
        pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
        embedding_output = self.patch_embed(pixel_values, training=training)

        encoder_outputs = self.encoder(
            hidden_states=embedding_output,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output, training=training)

        # Change the hidden states from (batch_size, height, width, num_channels) to
        # (batch_size, num_channels, height, width).
        # The hidden states are in (batch_size, height, width, num_channels)
        # shape after all stages except the MB3D blocks.
        hidden_states = None
        if output_hidden_states:
            encoder_hidden_states = encoder_outputs[1]
            hidden_states = tuple(tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_hidden_states[:-1]) + (
                encoder_hidden_states[-1],
            )

        if not return_dict:
            if hidden_states is not None:
                # Return the transposed hidden states here as well, so the tuple output uses
                # the same layout as the dict output (previously the transposed copy was dropped).
                return (sequence_output, hidden_states) + tuple(encoder_outputs[2:])
            return (sequence_output,) + encoder_outputs[1:]

        return TFBaseModelOutput(
            last_hidden_state=sequence_output,
            hidden_states=hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def build(self, input_shape=None):
        """Build the stem, the encoder and the final layer norm once, each under its own name scope."""
        if self.built:
            return
        self.built = True
        if getattr(self, "patch_embed", None) is not None:
            with tf.name_scope(self.patch_embed.name):
                self.patch_embed.build(None)
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)
        if getattr(self, "layernorm", None) is not None:
            with tf.name_scope(self.layernorm.name):
                self.layernorm.build([None, None, self.config.hidden_sizes[-1]])
|
| 897 |
+
|
| 898 |
+
|
| 899 |
+
class TFEfficientFormerPreTrainedModel(TFPreTrainedModel):
    """
    Abstract base shared by the EfficientFormer TensorFlow models. It handles weights initialization and provides a
    simple interface for downloading and loading pretrained models.
    """

    # Configuration class associated with this model family.
    config_class = EfficientFormerConfig
    # Attribute name under which the main layer is stored on the concrete models.
    base_model_prefix = "efficientformer"
    # Name of the primary input argument.
    main_input_name = "pixel_values"
|
| 908 |
+
|
| 909 |
+
|
| 910 |
+
# Shared class-level docstring prepended to the EfficientFormer TensorFlow models.
EFFICIENTFORMER_START_DOCSTRING = r"""
    This model is a TensorFlow
    [keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer). Use it as a regular
    TensorFlow Module and refer to the TensorFlow documentation for all matter related to general usage and behavior.


    Parameters:
        config ([`EfficientFormerConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Shared argument documentation attached to the models' `call` methods.
EFFICIENTFORMER_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`EfficientFormerImageProcessor.__call__`] for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
|
| 936 |
+
|
| 937 |
+
|
| 938 |
+
@add_start_docstrings(
    "The bare EfficientFormer Model transformer outputting raw hidden-states without any specific head on top.",
    EFFICIENTFORMER_START_DOCSTRING,
)
class TFEfficientFormerModel(TFEfficientFormerPreTrainedModel):
    # Thin wrapper: all computation is delegated to the `efficientformer` main layer.
    # Documentation is kept in `#` comments on purpose, so the docstrings assembled by
    # the decorators on this class and on `call` are not altered.

    def __init__(self, config: EfficientFormerConfig, **kwargs) -> None:
        # Extra keyword arguments are forwarded unchanged to the base class.
        super().__init__(config, **kwargs)

        self.efficientformer = TFEfficientFormerMainLayer(config, name="efficientformer")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(EFFICIENTFORMER_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFBaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def call(
        self,
        pixel_values: Optional[tf.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[Tuple, TFBaseModelOutput]:
        # Every argument is passed straight through; the main layer's result is
        # returned without any post-processing.
        return self.efficientformer(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

    def build(self, input_shape=None):
        # Build at most once; the flag is set before building the sub-layer.
        if self.built:
            return
        self.built = True

        backbone = getattr(self, "efficientformer", None)
        if backbone is None:
            return
        # Build the main layer under a name scope equal to its own layer name.
        with tf.name_scope(backbone.name):
            backbone.build(None)
|
| 981 |
+
|
| 982 |
+
|
| 983 |
+
@add_start_docstrings(
    """
    EfficientFormer Model transformer with an image classification head on top of pooled last hidden state, e.g. for
    ImageNet.
    """,
    EFFICIENTFORMER_START_DOCSTRING,
)
class TFEfficientFormerForImageClassification(TFEfficientFormerPreTrainedModel, TFSequenceClassificationLoss):
    # Documentation for this class lives in the decorator strings above; only `#`
    # comments are used here so the assembled docstrings stay unchanged.

    def __init__(self, config: EfficientFormerConfig, **kwargs) -> None:
        # `**kwargs` is accepted and forwarded to the base class so this constructor
        # has the same signature shape as `TFEfficientFormerModel.__init__`.
        super().__init__(config, **kwargs)

        self.num_labels = config.num_labels
        self.efficientformer = TFEfficientFormerMainLayer(config, name="efficientformer")

        # Classifier head: a Dense projection when labels are configured, otherwise
        # a linear (pass-through) activation layer with the same layer name.
        self.classifier = (
            keras.layers.Dense(config.num_labels, name="classifier")
            if config.num_labels > 0
            else keras.layers.Activation("linear", name="classifier")
        )
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(EFFICIENTFORMER_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=TFImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def call(
        self,
        pixel_values: Optional[tf.Tensor] = None,
        labels: Optional[tf.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[tf.Tensor, TFImageClassifierOutput]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        # Fall back to the configuration default when the caller did not choose.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.efficientformer(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        sequence_output = outputs[0]

        # Average over the second-to-last axis, then project to logits.
        logits = self.classifier(tf.reduce_mean(sequence_output, axis=-2))

        # The loss is only computed when labels are supplied.
        loss = None if labels is None else self.hf_compute_loss(labels, logits)

        if not return_dict:
            # Tuple form: (loss?, logits, *remaining backbone outputs).
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TFImageClassifierOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )

    def build(self, input_shape=None):
        # Build at most once; the flag is set before building sub-layers.
        if self.built:
            return
        self.built = True
        if getattr(self, "efficientformer", None) is not None:
            with tf.name_scope(self.efficientformer.name):
                self.efficientformer.build(None)
        if getattr(self, "classifier", None) is not None:
            if hasattr(self.classifier, "name"):
                with tf.name_scope(self.classifier.name):
                    # Only the last dimension (the final hidden size) is fixed.
                    self.classifier.build([None, None, self.config.hidden_sizes[-1]])
|
| 1063 |
+
|
| 1064 |
+
|
| 1065 |
+
@dataclass
class TFEfficientFormerForImageClassificationWithTeacherOutput(ModelOutput):
    """
    Output type of [`TFEfficientFormerForImageClassificationWithTeacher`].

    Args:
        logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
            Prediction scores as the average of the cls_logits and distillation logits.
        cls_logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
            Prediction scores of the classification head (i.e. the linear layer on top of the final hidden state of the
            class token).
        distillation_logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
            Prediction scores of the distillation head (i.e. the linear layer on top of the final hidden state of the
            distillation token).
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when
        `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus
            the initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when
        `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    """

    # Every field defaults to None, so instances can be created with any subset set.
    logits: Optional[tf.Tensor] = None
    cls_logits: Optional[tf.Tensor] = None
    distillation_logits: Optional[tf.Tensor] = None
    hidden_states: Optional[Tuple[tf.Tensor]] = None
    attentions: Optional[Tuple[tf.Tensor]] = None
|
| 1095 |
+
|
| 1096 |
+
|
| 1097 |
+
@add_start_docstrings(
    """
    EfficientFormer Model transformer with image classification heads on top (a linear layer on top of the final hidden
    state and a linear layer on top of the final hidden state of the distillation token) e.g. for ImageNet.

    .. warning::
            This model supports inference-only. Fine-tuning with distillation (i.e. with a teacher) is not yet
            supported.
    """,
    EFFICIENTFORMER_START_DOCSTRING,
)
class TFEfficientFormerForImageClassificationWithTeacher(TFEfficientFormerPreTrainedModel):
    # Documentation for this class lives in the decorator strings above; only `#`
    # comments are used here so the assembled docstrings stay unchanged.

    def __init__(self, config: EfficientFormerConfig, **kwargs) -> None:
        # `**kwargs` is accepted and forwarded to the base class so this constructor
        # has the same signature shape as `TFEfficientFormerModel.__init__`.
        super().__init__(config, **kwargs)

        self.num_labels = config.num_labels
        self.efficientformer = TFEfficientFormerMainLayer(config, name="efficientformer")

        # Classifier heads: Dense projections when labels are configured, otherwise
        # linear (pass-through) activation layers with the same layer names.
        self.classifier = (
            keras.layers.Dense(config.num_labels, name="classifier")
            if config.num_labels > 0
            else keras.layers.Activation("linear", name="classifier")
        )
        self.distillation_classifier = (
            keras.layers.Dense(config.num_labels, name="distillation_classifier")
            if config.num_labels > 0
            else keras.layers.Activation("linear", name="distillation_classifier")
        )

    @unpack_inputs
    @add_start_docstrings_to_model_forward(EFFICIENTFORMER_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=TFEfficientFormerForImageClassificationWithTeacherOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def call(
        self,
        pixel_values: Optional[tf.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[tuple, TFEfficientFormerForImageClassificationWithTeacherOutput]:
        # Fall back to the configuration default when the caller did not choose.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Training is rejected outright: only inference is supported by this head.
        if training:
            raise Exception(
                "This model supports inference-only. Fine-tuning with distillation (i.e. with a teacher) is not yet supported."
            )

        outputs = self.efficientformer(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        sequence_output = outputs[0]

        # Both heads consume the same mean over the second-to-last axis, so it is
        # computed once and shared instead of being recomputed per head.
        pooled_output = tf.reduce_mean(sequence_output, axis=-2)
        cls_logits = self.classifier(pooled_output)
        distillation_logits = self.distillation_classifier(pooled_output)

        # Final prediction is the plain average of the two heads.
        logits = (cls_logits + distillation_logits) / 2

        if not return_dict:
            # Tuple form: (logits, cls_logits, distillation_logits, *remaining backbone outputs).
            output = (logits, cls_logits, distillation_logits) + outputs[1:]
            return output

        return TFEfficientFormerForImageClassificationWithTeacherOutput(
            logits=logits,
            cls_logits=cls_logits,
            distillation_logits=distillation_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def build(self, input_shape=None):
        # Build at most once; the flag is set before building sub-layers.
        if self.built:
            return
        self.built = True
        if getattr(self, "efficientformer", None) is not None:
            with tf.name_scope(self.efficientformer.name):
                self.efficientformer.build(None)
        if getattr(self, "classifier", None) is not None:
            if hasattr(self.classifier, "name"):
                with tf.name_scope(self.classifier.name):
                    # Only the last dimension (the final hidden size) is fixed.
                    self.classifier.build([None, None, self.config.hidden_sizes[-1]])
        if getattr(self, "distillation_classifier", None) is not None:
            if hasattr(self.distillation_classifier, "name"):
                with tf.name_scope(self.distillation_classifier.name):
                    self.distillation_classifier.build([None, None, self.config.hidden_sizes[-1]])
|
| 1191 |
+
|
| 1192 |
+
|
| 1193 |
+
__all__ = [
|
| 1194 |
+
"TFEfficientFormerForImageClassification",
|
| 1195 |
+
"TFEfficientFormerForImageClassificationWithTeacher",
|
| 1196 |
+
"TFEfficientFormerModel",
|
| 1197 |
+
"TFEfficientFormerPreTrainedModel",
|
| 1198 |
+
]
|
docs/transformers/build/lib/transformers/models/deprecated/ernie_m/__init__.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The HuggingFace and Baidu Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ....utils import _LazyModule
|
| 17 |
+
from ....utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_ernie_m import *
|
| 22 |
+
from .modeling_ernie_m import *
|
| 23 |
+
from .tokenization_ernie_m import *
|
| 24 |
+
else:
|
| 25 |
+
import sys
|
| 26 |
+
|
| 27 |
+
_file = globals()["__file__"]
|
| 28 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
docs/transformers/build/lib/transformers/models/deprecated/ernie_m/configuration_ernie_m.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 Xuan Ouyang, Shuohuan Wang, Chao Pang, Yu Sun, Hao Tian, Hua Wu, Haifeng Wang and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""ErnieM model configuration"""
|
| 16 |
+
# Adapted from original paddlenlp repository.(https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/transformers/ernie_m/configuration.py)
|
| 17 |
+
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
from typing import Dict
|
| 21 |
+
|
| 22 |
+
from ....configuration_utils import PretrainedConfig
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class ErnieMConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ErnieMModel`]. It is used to instantiate a
    Ernie-M model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the `Ernie-M`
    [susnato/ernie-m-base_pytorch](https://huggingface.co/susnato/ernie-m-base_pytorch) architecture.


    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 250002):
            Vocabulary size of `input_ids` in [`ErnieMModel`]. Also is the vocab size of token embedding matrix.
            Defines the number of different tokens that can be represented by the `input_ids` passed when calling
            [`ErnieMModel`].
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the embedding layer, encoder layers and pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the feed-forward (ff) layer in the encoder. Input tensors to feed-forward layers are
            firstly projected from hidden_size to intermediate_size, and then projected back to hidden_size. Typically
            intermediate_size is larger than hidden_size.
        hidden_act (`str`, *optional*, defaults to `"gelu"`):
            The non-linear activation function in the feed-forward layer. `"gelu"`, `"relu"` and any other torch
            supported activation functions are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings and encoder.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability used in `MultiHeadAttention` in all encoder layers to drop some attention target.
        max_position_embeddings (`int`, *optional*, defaults to 514):
            The maximum value of the dimensionality of position encoding, which dictates the maximum supported length
            of an input sequence.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the normal initializer for initializing all weight matrices.
        pad_token_id (`int`, *optional*, defaults to 1):
            Padding token id, i.e. the index of the padding token in the token vocabulary.
        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the layer normalization layers.
        classifier_dropout (`float`, *optional*):
            The dropout ratio for the classification head.
        act_dropout (`float`, *optional*, defaults to 0.0):
            This dropout probability is used in `ErnieMEncoderLayer` after activation.

    A normal_initializer initializes weight matrices as normal distributions. See
    `ErnieMPretrainedModel._init_weights()` for how weights are initialized in `ErnieMModel`.
    """

    model_type = "ernie_m"
    # Alternate attribute names mapped onto the canonical ones.
    attribute_map: Dict[str, str] = {"dropout": "classifier_dropout", "num_classes": "num_labels"}

    def __init__(
        self,
        vocab_size: int = 250002,
        hidden_size: int = 768,
        num_hidden_layers: int = 12,
        num_attention_heads: int = 12,
        intermediate_size: int = 3072,
        hidden_act: str = "gelu",
        hidden_dropout_prob: float = 0.1,
        attention_probs_dropout_prob: float = 0.1,
        max_position_embeddings: int = 514,
        initializer_range: float = 0.02,
        pad_token_id: int = 1,
        layer_norm_eps: float = 1e-05,
        classifier_dropout=None,
        act_dropout: float = 0.0,
        **kwargs,
    ):
        # `pad_token_id` and any remaining keyword arguments are handled by the base
        # class; every other argument is stored below as a same-named attribute.
        super().__init__(pad_token_id=pad_token_id, **kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.classifier_dropout = classifier_dropout
        self.act_dropout = act_dropout
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
__all__ = ["ErnieMConfig"]
|
docs/transformers/build/lib/transformers/models/deprecated/ernie_m/modeling_ernie_m.py
ADDED
|
@@ -0,0 +1,1058 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 Xuan Ouyang, Shuohuan Wang, Chao Pang, Yu Sun, Hao Tian, Hua Wu, Haifeng Wang and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch ErnieM model."""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from typing import List, Optional, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.utils.checkpoint
|
| 22 |
+
from torch import nn, tensor
|
| 23 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 24 |
+
|
| 25 |
+
from ....activations import ACT2FN
|
| 26 |
+
from ....modeling_outputs import (
|
| 27 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
| 28 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
| 29 |
+
MultipleChoiceModelOutput,
|
| 30 |
+
QuestionAnsweringModelOutput,
|
| 31 |
+
SequenceClassifierOutput,
|
| 32 |
+
TokenClassifierOutput,
|
| 33 |
+
)
|
| 34 |
+
from ....modeling_utils import PreTrainedModel
|
| 35 |
+
from ....pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
| 36 |
+
from ....utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
| 37 |
+
from .configuration_ernie_m import ErnieMConfig
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
logger = logging.get_logger(__name__)
|
| 41 |
+
|
| 42 |
+
_CHECKPOINT_FOR_DOC = "susnato/ernie-m-base_pytorch"
|
| 43 |
+
_CONFIG_FOR_DOC = "ErnieMConfig"
|
| 44 |
+
_TOKENIZER_FOR_DOC = "ErnieMTokenizer"
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# Adapted from paddlenlp.transformers.ernie_m.modeling.ErnieEmbeddings
|
| 48 |
+
class ErnieMEmbeddings(nn.Module):
    """Construct the embeddings from word and position embeddings.

    The token embeddings (or caller-supplied ``inputs_embeds``) are summed with
    learned position embeddings, then passed through layer normalization and
    dropout.
    """

    def __init__(self, config):
        """Create the embedding tables, layer norm and dropout from ``config``.

        Reads ``vocab_size``, ``hidden_size``, ``max_position_embeddings``,
        ``pad_token_id``, ``layer_norm_eps`` and ``hidden_dropout_prob``.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size, padding_idx=config.pad_token_id
        )
        self.layer_norm = nn.LayerNorm(normalized_shape=config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(p=config.hidden_dropout_prob)
        self.padding_idx = config.pad_token_id

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        """Return the normalized, dropout-regularized sum of token and position embeddings.

        Args:
            input_ids: Token ids; only used when ``inputs_embeds`` is ``None``.
            position_ids: Optional explicit position ids. When omitted, positions
                ``0, 1, 2, ...`` are generated along dimension 1. The tensor passed
                by the caller is never modified.
            inputs_embeds: Optional precomputed token embeddings, used instead of
                looking up ``input_ids``.
            past_key_values_length: Added to the position ids when greater than 0.

        Returns:
            The embeddings after layer normalization and dropout.
        """
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        if position_ids is None:
            # Default positions: a running count along dim 1, shifted to start at 0.
            input_shape = inputs_embeds.size()[:-1]
            ones = torch.ones(input_shape, dtype=torch.int64, device=inputs_embeds.device)
            position_ids = torch.cumsum(ones, dim=1) - ones

        if past_key_values_length > 0:
            position_ids = position_ids + past_key_values_length

        # to mimic paddlenlp implementation
        # Out-of-place add: an in-place `+=` here would silently shift a tensor the
        # caller passed in, so reusing it across calls would drift by 2 every time.
        position_ids = position_ids + 2

        position_embeddings = self.position_embeddings(position_ids)
        embeddings = inputs_embeds + position_embeddings
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)

        return embeddings
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class ErnieMSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention with optional relative position scores and key/value caching."""

    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        evenly_divisible = config.hidden_size % config.num_attention_heads == 0
        if not evenly_divisible and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.q_proj = nn.Linear(config.hidden_size, self.all_head_size)
        self.k_proj = nn.Linear(config.hidden_size, self.all_head_size)
        self.v_proj = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type in ("relative_key", "relative_key_query"):
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Split the last dimension into (heads, head_size) and swap dims 1 and 2 of the 4-D result."""
        split_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        return x.view(split_shape).permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Run attention over `hidden_states` (or over `encoder_hidden_states` for cross-attention).

        Returns:
            A tuple `(context,)`, followed by the attention probabilities when `output_attentions` is truthy,
            followed by the `(key, value)` pair when the module was built with `config.is_decoder`.
        """
        # Remember whether a cache was supplied before `past_key_value` is possibly overwritten below.
        use_cache = past_key_value is not None

        if encoder_hidden_states is not None:
            # Cross-attention: keys/values come from the encoder and its mask replaces the self-attention mask.
            attention_mask = encoder_attention_mask
            if use_cache:
                # Reuse the cached cross-attention keys/values.
                key_layer, value_layer = past_key_value[0], past_key_value[1]
            else:
                key_layer = self.transpose_for_scores(self.k_proj(encoder_hidden_states))
                value_layer = self.transpose_for_scores(self.v_proj(encoder_hidden_states))
        else:
            key_layer = self.transpose_for_scores(self.k_proj(hidden_states))
            value_layer = self.transpose_for_scores(self.v_proj(hidden_states))
            if use_cache:
                # Prepend the cached keys/values along dim 2.
                key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
                value_layer = torch.cat([past_key_value[1], value_layer], dim=2)

        query_layer = self.transpose_for_scores(self.q_proj(hidden_states))

        if self.is_decoder:
            # Expose the keys/values that were used so that later calls can pass them back as
            # `past_key_value`.
            past_key_value = (key_layer, value_layer)

        # Raw attention scores: dot product between "query" and "key".
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type in ("relative_key", "relative_key_query"):
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            device = hidden_states.device
            if use_cache:
                left = torch.tensor(key_length - 1, dtype=torch.long, device=device).view(-1, 1)
            else:
                left = torch.arange(query_length, dtype=torch.long, device=device).view(-1, 1)
            right = torch.arange(key_length, dtype=torch.long, device=device).view(1, -1)
            distance = left - right

            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                attention_scores = attention_scores + torch.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding
                )
            elif self.position_embedding_type == "relative_key_query":
                from_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                from_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + from_query + from_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # The mask is additive and is applied to the scores before the softmax.
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # Dropout here removes whole attention links (entire tokens to attend to).
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        # Merge the head dimension back into the hidden dimension.
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        merged_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(merged_shape)

        if output_attentions:
            outputs = (context_layer, attention_probs)
        else:
            outputs = (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
class ErnieMAttention(nn.Module):
    """Self-attention followed by an output projection, with support for pruning attention heads."""

    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        self.self_attn = ErnieMSelfAttention(config, position_embedding_type=position_embedding_type)
        self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads from the q/k/v and output projections."""
        if len(heads) == 0:
            return
        attn = self.self_attn
        heads, index = find_pruneable_heads_and_indices(
            heads, attn.num_attention_heads, attn.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        for proj_name in ("q_proj", "k_proj", "v_proj"):
            setattr(attn, proj_name, prune_linear_layer(getattr(attn, proj_name), index))
        self.out_proj = prune_linear_layer(self.out_proj, index, dim=1)

        # Update hyper params and store pruned heads
        attn.num_attention_heads = attn.num_attention_heads - len(heads)
        attn.all_head_size = attn.attention_head_size * attn.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Apply self-attention and project its first output.

        Returns:
            A tuple whose first element is the projected attention output; any further elements returned by
            `self_attn` are passed through unchanged.
        """
        attn_results = self.self_attn(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        projected = self.out_proj(attn_results[0])
        # Keep whatever extra outputs self_attn produced after the context tensor.
        return (projected,) + attn_results[1:]
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
class ErnieMEncoderLayer(nn.Module):
    """One post-layer-norm transformer encoder layer: self-attention block followed by a feed-forward block."""

    def __init__(self, config):
        super().__init__()
        # to mimic paddlenlp implementation
        dropout = 0.1 if config.hidden_dropout_prob is None else config.hidden_dropout_prob
        # Fall back to the resolved `dropout` so a `None` hidden_dropout_prob never reaches nn.Dropout.
        act_dropout = dropout if config.act_dropout is None else config.act_dropout

        self.self_attn = ErnieMAttention(config)
        self.linear1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.dropout = nn.Dropout(act_dropout)
        self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size)
        self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        if isinstance(config.hidden_act, str):
            self.activation = ACT2FN[config.hidden_act]
        else:
            self.activation = config.hidden_act

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = True,
    ):
        """Apply the attention and feed-forward sub-blocks, each with a residual connection and layer norm.

        Returns:
            `(hidden_states, attention_weights)` when `output_attentions` is truthy, otherwise just
            `hidden_states`.
        """
        residual = hidden_states
        attn_outputs = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
        )
        # `self_attn` always returns a tuple whose length depends on `output_attentions` and on the decoder
        # flag, so index into it instead of unpacking a fixed number of items.
        hidden_states = attn_outputs[0]
        attention_opt_weights = attn_outputs[1] if output_attentions else None

        hidden_states = residual + self.dropout1(hidden_states)
        hidden_states = self.norm1(hidden_states)
        residual = hidden_states

        hidden_states = self.linear1(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.linear2(hidden_states)
        hidden_states = residual + self.dropout2(hidden_states)
        hidden_states = self.norm2(hidden_states)

        if output_attentions:
            return hidden_states, attention_opt_weights
        return hidden_states
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
class ErnieMEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` encoder layers applied one after another."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([ErnieMEncoderLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(
        self,
        input_embeds: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        """Run every layer in order, optionally collecting per-layer hidden states and attention weights.

        Returns:
            A `BaseModelOutputWithPastAndCrossAttentions` when `return_dict` is truthy; otherwise a tuple of the
            non-`None` values among (last hidden state, all hidden states, all attentions).
        """
        # The embedding output is recorded as the first hidden state.
        collected_states = (input_embeds,) if output_hidden_states else None
        collected_attns = () if output_attentions else None

        current = input_embeds
        for idx, encoder_layer in enumerate(self.layers):
            per_layer_head_mask = None if head_mask is None else head_mask[idx]
            per_layer_past = None if past_key_values is None else past_key_values[idx]

            # The layer is called with its default `output_attentions`, so it returns a pair.
            current, layer_attn = encoder_layer(
                hidden_states=current,
                attention_mask=attention_mask,
                head_mask=per_layer_head_mask,
                past_key_value=per_layer_past,
            )

            if output_hidden_states:
                collected_states = collected_states + (current,)
            if output_attentions:
                collected_attns = collected_attns + (layer_attn,)

        if return_dict:
            return BaseModelOutputWithPastAndCrossAttentions(
                last_hidden_state=current, hidden_states=collected_states, attentions=collected_attns
            )
        return tuple(item for item in [current, collected_states, collected_attns] if item is not None)
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
class ErnieMPooler(nn.Module):
    """Pool a sequence by projecting the hidden state at index 0 of dim 1 through a dense layer and tanh."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `tanh(dense(hidden_states[:, 0]))`."""
        # "Pooling" here just selects the hidden state of the first token.
        return self.activation(self.dense(hidden_states[:, 0]))
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
class ErnieMPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = ErnieMConfig
    base_model_prefix = "ernie_m"

    def _init_weights(self, module):
        """Initialize the weights"""
        init_std = self.config.initializer_range
        if isinstance(module, nn.LayerNorm):
            # Layer norm starts out as the identity transform.
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=init_std)
            # Keep the padding row at zero after the random init.
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=init_std)
            if module.bias is not None:
                module.bias.data.zero_()
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
ERNIE_M_START_DOCSTRING = r"""
|
| 424 |
+
|
| 425 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 426 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 427 |
+
etc.)
|
| 428 |
+
|
| 429 |
+
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
|
| 430 |
+
it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
|
| 431 |
+
behavior.
|
| 432 |
+
|
| 433 |
+
Parameters:
|
| 434 |
+
config ([`ErnieMConfig`]): Model configuration class with all the parameters of the model.
|
| 435 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 436 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 437 |
+
"""
|
| 438 |
+
|
| 439 |
+
ERNIE_M_INPUTS_DOCSTRING = r"""
|
| 440 |
+
Args:
|
| 441 |
+
input_ids (`torch.LongTensor` of shape `({0})`):
|
| 442 |
+
Indices of input sequence tokens in the vocabulary.
|
| 443 |
+
|
| 444 |
+
Indices can be obtained using [`ErnieMTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 445 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 446 |
+
|
| 447 |
+
[What are input IDs?](../glossary#input-ids)
|
| 448 |
+
attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
|
| 449 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 450 |
+
|
| 451 |
+
- 1 for tokens that are **not masked**,
|
| 452 |
+
- 0 for tokens that are **masked**.
|
| 453 |
+
|
| 454 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 455 |
+
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
|
| 456 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 457 |
+
config.max_position_embeddings - 1]`.
|
| 458 |
+
|
| 459 |
+
[What are position IDs?](../glossary#position-ids)
|
| 460 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
| 461 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
| 462 |
+
|
| 463 |
+
- 1 indicates the head is **not masked**,
|
| 464 |
+
- 0 indicates the head is **masked**.
|
| 465 |
+
|
| 466 |
+
inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
|
| 467 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 468 |
+
is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
|
| 469 |
+
model's internal embedding lookup matrix.
|
| 470 |
+
output_attentions (`bool`, *optional*):
|
| 471 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 472 |
+
tensors for more detail.
|
| 473 |
+
output_hidden_states (`bool`, *optional*):
|
| 474 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 475 |
+
more detail.
|
| 476 |
+
return_dict (`bool`, *optional*):
|
| 477 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 478 |
+
"""
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
@add_start_docstrings(
    "The bare ErnieM Model transformer outputting raw hidden-states without any specific head on top.",
    ERNIE_M_START_DOCSTRING,
)
class ErnieMModel(ErnieMPreTrainedModel):
    # Embeddings -> encoder stack -> optional pooler. Documentation is attached by the decorators, so this class
    # and its `forward` are described with comments rather than docstrings to leave the generated docs unchanged.

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.initializer_range = config.initializer_range
        self.embeddings = ErnieMEmbeddings(config)
        self.encoder = ErnieMEncoder(config)
        self.pooler = ErnieMPooler(config) if add_pooling_layer else None
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layers[layer].self_attn.prune_heads(heads)

    @add_start_docstrings_to_model_forward(ERNIE_M_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPastAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        # NOTE: `use_cache` is accepted for interface compatibility but is not read in this method.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time.")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # init the default bool value
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        past_key_values_length = 0
        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]

        # Adapted from paddlenlp.transformers.ernie_m.ErnieMModel
        if attention_mask is None:
            if input_ids is not None:
                # Padding positions get the most negative float so they vanish after the softmax.
                attention_mask = (input_ids == self.config.pad_token_id).to(torch.float32)
            else:
                # Without token ids padding cannot be detected, so nothing is masked.
                attention_mask = torch.zeros(
                    inputs_embeds.shape[:-1], dtype=torch.float32, device=inputs_embeds.device
                )
            attention_mask *= torch.finfo(attention_mask.dtype).min
            if past_key_values is not None:
                batch_size = past_key_values[0][0].shape[0]
                # Cached positions are never masked. The mask is still 2-D at this point, so the past part
                # must be 2-D as well and live on the same device for the concatenation to be valid.
                past_mask = torch.zeros(
                    [batch_size, past_key_values_length],
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )
                attention_mask = torch.concat([past_mask, attention_mask], dim=-1)
        # For 2D attention_mask from tokenizer
        elif attention_mask.ndim == 2:
            attention_mask = attention_mask.to(torch.float32)
            attention_mask = 1.0 - attention_mask
            attention_mask *= torch.finfo(attention_mask.dtype).min

        # Insert two singleton dims at position 1 so the mask broadcasts against the attention scores.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            sequence_output = encoder_outputs[0]
            pooler_output = self.pooler(sequence_output) if self.pooler is not None else None
            return (sequence_output, pooler_output) + encoder_outputs[1:]

        sequence_output = encoder_outputs["last_hidden_state"]
        pooler_output = self.pooler(sequence_output) if self.pooler is not None else None
        hidden_states = None if not output_hidden_states else encoder_outputs["hidden_states"]
        attentions = None if not output_attentions else encoder_outputs["attentions"]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooler_output,
            hidden_states=hidden_states,
            attentions=attentions,
        )
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
@add_start_docstrings(
    """ErnieM Model transformer with a sequence classification/regression head on top (a linear layer on top of
    the pooled output) e.g. for GLUE tasks.""",
    ERNIE_M_START_DOCSTRING,
)
class ErnieMForSequenceClassification(ErnieMPreTrainedModel):
    # Backbone + dropout + linear classifier applied to the second backbone output (`outputs[1]`).

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.ernie_m = ErnieMModel(config)
        # Prefer the dedicated classifier dropout and fall back to the hidden dropout when it is unset.
        if config.classifier_dropout is not None:
            classifier_dropout = config.classifier_dropout
        else:
            classifier_dropout = config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(ERNIE_M_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.Tensor]] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = True,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple[torch.FloatTensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.ernie_m(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=return_dict,
        )

        logits = self.classifier(self.dropout(outputs[1]))

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                # Infer the problem type once from the label count and label dtype, and store it on the config.
                if self.num_labels == 1:
                    inferred_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    inferred_type = "single_label_classification"
                else:
                    inferred_type = "multi_label_classification"
                self.config.problem_type = inferred_type

            problem_type = self.config.problem_type
            if problem_type == "regression":
                criterion = MSELoss()
                if self.num_labels == 1:
                    loss = criterion(logits.squeeze(), labels.squeeze())
                else:
                    loss = criterion(logits, labels)
            elif problem_type == "single_label_classification":
                criterion = CrossEntropyLoss()
                loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                criterion = BCEWithLogitsLoss()
                loss = criterion(logits, labels)

        if return_dict:
            return SequenceClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        # Tuple output: logits first, then everything after the first two backbone outputs.
        output = (logits,) + outputs[2:]
        if loss is not None:
            return (loss,) + output
        return output
|
| 693 |
+
|
| 694 |
+
|
| 695 |
+
@add_start_docstrings(
    """ErnieM Model with a multiple choice classification head on top (a linear layer on top of
    the pooled output and a softmax) e.g. for RocStories/SWAG tasks.""",
    ERNIE_M_START_DOCSTRING,
)
class ErnieMForMultipleChoice(ErnieMPreTrainedModel):
    # Head layout: encoder -> dropout on the pooled output -> single-logit linear layer.
    # The per-row logits are regrouped as (-1, num_choices) before the loss is computed.

    def __init__(self, config):
        super().__init__(config)

        self.ernie_m = ErnieMModel(config)

        # Use the classifier-specific dropout rate when configured, else the hidden dropout rate.
        dropout_rate = config.classifier_dropout
        if dropout_rate is None:
            dropout_rate = config.hidden_dropout_prob
        self.dropout = nn.Dropout(dropout_rate)

        # One score per flattened (example, choice) row.
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(ERNIE_M_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple[torch.FloatTensor], MultipleChoiceModelOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # The number of choices is dimension 1 of whichever input representation was supplied.
        shape_source = input_ids if input_ids is not None else inputs_embeds
        num_choices = shape_source.shape[1]

        # Collapse the leading dimensions so every choice becomes its own row for the encoder.
        if input_ids is not None:
            input_ids = input_ids.view(-1, input_ids.size(-1))
        if attention_mask is not None:
            attention_mask = attention_mask.view(-1, attention_mask.size(-1))
        if position_ids is not None:
            position_ids = position_ids.view(-1, position_ids.size(-1))
        if inputs_embeds is not None:
            inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))

        outputs = self.ernie_m(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 1 of the encoder output is the pooled representation.
        pooled_output = self.dropout(outputs[1])
        logits = self.classifier(pooled_output)
        # Regroup the per-row scores so that each row holds the scores of one example's choices.
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss = CrossEntropyLoss()(reshaped_logits, labels)

        if return_dict:
            return MultipleChoiceModelOutput(
                loss=loss,
                logits=reshaped_logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        # Tuple form: optional loss, then logits, then whatever the encoder returned after index 1.
        tuple_output = (reshaped_logits,) + outputs[2:]
        if loss is None:
            return tuple_output
        return (loss,) + tuple_output
|
| 782 |
+
|
| 783 |
+
|
| 784 |
+
@add_start_docstrings(
    """ErnieM Model with a token classification head on top (a linear layer on top of
    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks.""",
    ERNIE_M_START_DOCSTRING,
)
class ErnieMForTokenClassification(ErnieMPreTrainedModel):
    # Head layout: encoder (built without a pooling layer) -> dropout -> per-token linear classifier.

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.ernie_m = ErnieMModel(config, add_pooling_layer=False)

        # Use the classifier-specific dropout rate when configured, else the hidden dropout rate.
        dropout_rate = config.classifier_dropout
        if dropout_rate is None:
            dropout_rate = config.hidden_dropout_prob
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(ERNIE_M_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.Tensor]] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = True,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple[torch.FloatTensor], TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.ernie_m(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 0 of the encoder output holds the per-token hidden states.
        token_states = self.dropout(outputs[0])
        logits = self.classifier(token_states)

        loss = None
        if labels is not None:
            # Flatten tokens and labels so the loss sees one row per token.
            flat_logits = logits.view(-1, self.num_labels)
            flat_labels = labels.view(-1)
            loss = CrossEntropyLoss()(flat_logits, flat_labels)

        if return_dict:
            return TokenClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        # Tuple form: optional loss, then logits, then whatever the encoder returned after index 1.
        tuple_output = (logits,) + outputs[2:]
        if loss is None:
            return tuple_output
        return (loss,) + tuple_output
|
| 862 |
+
|
| 863 |
+
|
| 864 |
+
@add_start_docstrings(
    """ErnieM Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).""",
    ERNIE_M_START_DOCSTRING,
)
class ErnieMForQuestionAnswering(ErnieMPreTrainedModel):
    # Head layout: encoder (built without a pooling layer) -> linear layer with `config.num_labels`
    # outputs per token, which `forward` unpacks into a start score and an end score.

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.ernie_m = ErnieMModel(config, add_pooling_layer=False)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(ERNIE_M_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple[torch.FloatTensor], QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.ernie_m(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 0 of the encoder output holds the per-token hidden states.
        span_logits = self.qa_outputs(outputs[0])
        # Split the last dimension into width-1 chunks; the two-name unpacking needs exactly two chunks.
        start_logits, end_logits = span_logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)

            # Out-of-range positions are clamped onto `ignored_index` (the size of dim 1 of the
            # logits), which is the very value the loss is told to ignore.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            span_loss = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = span_loss(start_logits, start_positions)
            end_loss = span_loss(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if return_dict:
            return QuestionAnsweringModelOutput(
                loss=total_loss,
                start_logits=start_logits,
                end_logits=end_logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        # Tuple form: optional loss, both logit tensors, then whatever the encoder returned after index 1.
        tuple_output = (start_logits, end_logits) + outputs[2:]
        if total_loss is None:
            return tuple_output
        return (total_loss,) + tuple_output
|
| 958 |
+
|
| 959 |
+
|
| 960 |
+
@add_start_docstrings(
    """ErnieMForInformationExtraction is a Ernie-M Model with two linear layer on top of the hidden-states output to
    compute `start_prob` and `end_prob`, designed for Universal Information Extraction.""",
    ERNIE_M_START_DOCSTRING,
)
class ErnieMForInformationExtraction(ErnieMPreTrainedModel):
    # Head layout: encoder -> two independent single-logit linear layers (start / end) applied to
    # every token. `forward` returns raw logits; `self.sigmoid` is kept as an attribute for callers
    # but is not applied inside `forward`.

    def __init__(self, config):
        super().__init__(config)
        self.ernie_m = ErnieMModel(config)
        self.linear_start = nn.Linear(config.hidden_size, 1)
        self.linear_end = nn.Linear(config.hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

        # Initialize weights and apply final processing
        self.post_init()

    # The inputs go to the encoder without any flattening, so they are documented as
    # (batch_size, sequence_length) like the other per-token heads in this file.
    @add_start_docstrings_to_model_forward(ERNIE_M_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple[torch.FloatTensor], QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for position (index) for computing the start_positions loss. Position outside of the sequence are
            not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for position (index) for computing the end_positions loss. Position outside of the sequence are not
            taken into account for computing the loss.
        """
        # Resolve `None` against the config, as the sibling heads in this file do, so that the
        # encoder call and the branches below always agree on the output format.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        result = self.ernie_m(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Integer indexing is how the sibling heads read the sequence output in both output formats.
        sequence_output = result[0]

        start_logits = self.linear_start(sequence_output)
        start_logits = start_logits.squeeze(-1)
        end_logits = self.linear_end(sequence_output)
        end_logits = end_logits.squeeze(-1)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            # NOTE(review): BCEWithLogitsLoss expects floating-point targets with the same shape as
            # the logits; the squeeze/clamp steps above look inherited from the index-based QA head.
            # Confirm the intended label format before changing the loss computation.
            loss_fct = BCEWithLogitsLoss()
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            # `result` is a plain tuple here, so it has no `.hidden_states` / `.attentions`
            # attributes; slice it the same way the other heads in this file do.
            output = (start_logits, end_logits) + result[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=result.hidden_states,
            attentions=result.attentions,
        )
|
| 1048 |
+
|
| 1049 |
+
|
| 1050 |
+
# Public names of this module (the names picked up by `from <module> import *`).
__all__ = [
    "ErnieMForMultipleChoice",
    "ErnieMForQuestionAnswering",
    "ErnieMForSequenceClassification",
    "ErnieMForTokenClassification",
    "ErnieMModel",
    "ErnieMPreTrainedModel",
    "ErnieMForInformationExtraction",
]
|
docs/transformers/build/lib/transformers/models/deprecated/ernie_m/tokenization_ernie_m.py
ADDED
|
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 Xuan Ouyang, Shuohuan Wang, Chao Pang, Yu Sun, Hao Tian, Hua Wu, Haifeng Wang and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Tokenization classes for Ernie-M."""
|
| 16 |
+
|
| 17 |
+
import io
|
| 18 |
+
import os
|
| 19 |
+
import unicodedata
|
| 20 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 21 |
+
|
| 22 |
+
import sentencepiece as spm
|
| 23 |
+
|
| 24 |
+
from ....tokenization_utils import PreTrainedTokenizer
|
| 25 |
+
from ....utils import logging
|
| 26 |
+
from ....utils.import_utils import requires
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Module-level logger.
logger = logging.get_logger(__name__)

# Marker character found at the start of sentencepiece pieces; the tokenizer strips it when
# computing offsets and turns it back into a space when joining tokens into a string.
SPIECE_UNDERLINE = "▁"

# Assigned to `ErnieMTokenizer.vocab_files_names`.
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "sentencepiece_model_ckpt": "sentencepiece.bpe.model"}

# Assigned to `ErnieMTokenizer.resource_files_names`. Note that the sentencepiece entry uses the key
# "sentencepiece_model_file" here, whereas VOCAB_FILES_NAMES uses "sentencepiece_model_ckpt".
RESOURCE_FILES_NAMES = {
    "sentencepiece_model_file": "sentencepiece.bpe.model",
    "vocab_file": "vocab.txt",
}
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# Adapted from paddlenlp.transformers.ernie_m.tokenizer.ErnieMTokenizer
|
| 42 |
+
@requires(backends=("sentencepiece",))
|
| 43 |
+
class ErnieMTokenizer(PreTrainedTokenizer):
|
| 44 |
+
r"""
|
| 45 |
+
Constructs a Ernie-M tokenizer. It uses the `sentencepiece` tools to cut the words to sub-words.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
sentencepiece_model_ckpt (`str`):
|
| 49 |
+
The file path of sentencepiece model.
|
| 50 |
+
vocab_file (`str`, *optional*):
|
| 51 |
+
The file path of the vocabulary.
|
| 52 |
+
do_lower_case (`bool`, *optional*, defaults to `False`):
|
| 53 |
+
Whether or not to lowercase the input when tokenizing.
|
| 54 |
+
unk_token (`str`, *optional*, defaults to `"[UNK]"`):
|
| 55 |
+
A special token representing the `unknown (out-of-vocabulary)` token. An unknown token is set to be
|
| 56 |
+
`unk_token` in order to be converted to an ID.
|
| 57 |
+
sep_token (`str`, *optional*, defaults to `"[SEP]"`):
|
| 58 |
+
A special token separating two different sentences in the same input.
|
| 59 |
+
pad_token (`str`, *optional*, defaults to `"[PAD]"`):
|
| 60 |
+
A special token used to make arrays of tokens the same size for batching purposes.
|
| 61 |
+
cls_token (`str`, *optional*, defaults to `"[CLS]"`):
|
| 62 |
+
A special token used for sequence classification. It is the first token of the sequence when built with
|
| 63 |
+
special tokens.
|
| 64 |
+
mask_token (`str`, *optional*, defaults to `"[MASK]"`):
|
| 65 |
+
A special token representing a masked token. This is the token used in the masked language modeling task
|
| 66 |
+
which the model tries to predict the original unmasked ones.
|
| 67 |
+
"""
|
| 68 |
+
|
| 69 |
+
# Ernie-M model doesn't have token_type embedding.
|
| 70 |
+
model_input_names: List[str] = ["input_ids"]
|
| 71 |
+
|
| 72 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
| 73 |
+
resource_files_names = RESOURCE_FILES_NAMES
|
| 74 |
+
|
| 75 |
+
def __init__(
    self,
    sentencepiece_model_ckpt,
    vocab_file=None,
    do_lower_case=False,
    encoding="utf8",
    unk_token="[UNK]",
    sep_token="[SEP]",
    pad_token="[PAD]",
    cls_token="[CLS]",
    mask_token="[MASK]",
    sp_model_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> None:
    """
    Load the sentencepiece model, build the vocabulary tables, then initialise the base tokenizer.

    Args:
        sentencepiece_model_ckpt: Path handed to `SentencePieceProcessor.Load`.
        vocab_file: Optional vocabulary file read through `self.load_vocab`. When omitted, the
            vocabulary is derived from the pieces of the loaded sentencepiece model.
        do_lower_case: Stored on the instance and forwarded to the base class.
        encoding: Forwarded to the base class.
        unk_token, sep_token, pad_token, cls_token, mask_token: Special tokens forwarded to the base class.
        sp_model_kwargs: Keyword arguments for the `SentencePieceProcessor` constructor (`{}` when `None`).
        **kwargs: Extra arguments forwarded to the base class.
    """
    # All attributes below are assigned before the base-class initialiser runs.
    self.sp_model_kwargs = sp_model_kwargs if sp_model_kwargs is not None else {}

    self.do_lower_case = do_lower_case
    self.sentencepiece_model_ckpt = sentencepiece_model_ckpt
    self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
    self.sp_model.Load(sentencepiece_model_ckpt)

    # to mimic paddlenlp.transformers.ernie_m.tokenizer.ErnieMTokenizer functioning
    if vocab_file is None:
        # No vocabulary file: enumerate the pieces known to the sentencepiece model.
        self.vocab = {
            self.sp_model.id_to_piece(piece_id): piece_id for piece_id in range(self.sp_model.get_piece_size())
        }
    else:
        self.vocab = self.load_vocab(filepath=vocab_file)
    # Inverse table used for id -> token lookups.
    self.reverse_vocab = {index: token for token, index in self.vocab.items()}

    super().__init__(
        do_lower_case=do_lower_case,
        unk_token=unk_token,
        sep_token=sep_token,
        pad_token=pad_token,
        cls_token=cls_token,
        mask_token=mask_token,
        vocab_file=vocab_file,
        encoding=encoding,
        sp_model_kwargs=self.sp_model_kwargs,
        **kwargs,
    )
|
| 118 |
+
|
| 119 |
+
def get_offset_mapping(self, text):
    """
    Compute, for every token of `text`, a `(start, end)` pair of character positions in `text`.

    Args:
        text: The raw string to tokenize, or `None`.

    Returns:
        `None` when `text` is `None`; otherwise a list with one `(start, end)` tuple per token returned
        by `self.tokenize(text)`, where `end` is one past the position of the token's last character.

    Raises:
        ValueError: If a token cannot be found in the normalized text after the previous token.
    """
    if text is None:
        return None

    pieces = self.tokenize(text)

    # Normalize character by character, dropping whitespace, and remember for every normalized
    # character which position of the raw text it came from.
    kept_chars, origin_positions = [], []
    for position, raw_char in enumerate(text):
        if raw_char in self.SP_CHAR_MAPPING:
            mapped = self.SP_CHAR_MAPPING.get(raw_char)
        else:
            mapped = unicodedata.normalize("NFKC", raw_char)
        if not self.is_whitespace(mapped):
            kept_chars.append(mapped)
            # One raw character may normalize to several characters; all point back to `position`.
            origin_positions.extend([position] * len(mapped))

    searchable = "".join(kept_chars)
    if self.do_lower_case:
        searchable = searchable.lower()

    # Locate each token left to right, never searching before the end of the previous match.
    spans, cursor = [], 0
    for piece in pieces:
        needle = piece[1:] if piece[:1] == "▁" else piece
        begin = searchable.index(needle, cursor)
        finish = begin + len(needle)

        spans.append((origin_positions[begin], origin_positions[finish - 1] + 1))
        cursor = finish
    return spans
|
| 150 |
+
|
| 151 |
+
@property
def vocab_size(self):
    """Number of entries in `self.vocab`."""
    base_vocab = self.vocab
    return len(base_vocab)
|
| 154 |
+
|
| 155 |
+
def get_vocab(self):
    """Return a new dict holding `self.vocab` overlaid with `self.added_tokens_encoder`."""
    merged = dict(self.vocab)
    # Added tokens come last and win on key collisions.
    merged.update(self.added_tokens_encoder)
    return merged
|
| 157 |
+
|
| 158 |
+
def __getstate__(self):
|
| 159 |
+
state = self.__dict__.copy()
|
| 160 |
+
state["sp_model"] = None
|
| 161 |
+
return state
|
| 162 |
+
|
| 163 |
+
def __setstate__(self, d):
    """Restore the instance dict from `d` and reload the sentencepiece processor."""
    self.__dict__ = d

    # for backward compatibility
    try:
        processor_kwargs = self.sp_model_kwargs
    except AttributeError:
        processor_kwargs = self.sp_model_kwargs = {}

    # The processor is recreated from the checkpoint path stored on the instance.
    self.sp_model = spm.SentencePieceProcessor(**processor_kwargs)
    self.sp_model.Load(self.sentencepiece_model_ckpt)
|
| 172 |
+
|
| 173 |
+
def clean_text(self, text):
    """Replace every character found in `self.SP_CHAR_MAPPING` by its mapped value; keep all others."""
    replaced = [self.SP_CHAR_MAPPING.get(char, char) for char in text]
    return "".join(replaced)
|
| 176 |
+
|
| 177 |
+
def _tokenize(self, text, enable_sampling=False, nbest_size=64, alpha=0.1):
    """
    Tokenize a string into sentencepiece pieces, then split those pieces further.

    Each piece is split so that every character accepted by `self.is_ch_char` or `self.is_punct`
    becomes its own token, and so that runs of digits are separated from adjacent non-digits.

    Args:
        text: The string to tokenize.
        enable_sampling: Use `SampleEncodeAsPieces` instead of `EncodeAsPieces`. Forced to `True`
            when `sp_model_kwargs["enable_sampling"]` is `True`.
        nbest_size: Sampling parameter; overridden by `sp_model_kwargs["nbest_size"]` when set.
        alpha: Sampling parameter; overridden by `sp_model_kwargs["alpha"]` when set.

    Returns:
        `List[str]`: The resulting tokens.
    """
    # Values configured through `sp_model_kwargs` take precedence over the call arguments.
    if self.sp_model_kwargs.get("enable_sampling") is True:
        enable_sampling = True
    if self.sp_model_kwargs.get("alpha") is not None:
        alpha = self.sp_model_kwargs.get("alpha")
    if self.sp_model_kwargs.get("nbest_size") is not None:
        nbest_size = self.sp_model_kwargs.get("nbest_size")

    if not enable_sampling:
        pieces = self.sp_model.EncodeAsPieces(text)
    else:
        pieces = self.sp_model.SampleEncodeAsPieces(text, nbest_size, alpha)

    new_pieces = []
    last_index = len(pieces) - 1
    for pi, piece in enumerate(pieces):
        if piece == SPIECE_UNDERLINE:
            # A bare marker is kept only when it is neither the first piece nor followed by a piece
            # that carries its own marker. The bounds check comes first so that a bare marker in
            # last position is dropped instead of raising IndexError on `pieces[pi + 1]`.
            if pi != 0 and pi < last_index and not pieces[pi + 1].startswith(SPIECE_UNDERLINE):
                new_pieces.append(SPIECE_UNDERLINE)
            continue

        # `lst_i` is the start of the part of `piece` that has not been emitted yet.
        lst_i = 0
        for i, chunk in enumerate(piece):
            if chunk == SPIECE_UNDERLINE:
                continue
            if self.is_ch_char(chunk) or self.is_punct(chunk):
                # Flush what precedes the character (unless it is only the marker), then emit the
                # character as a token of its own.
                if i > lst_i and piece[lst_i:i] != SPIECE_UNDERLINE:
                    new_pieces.append(piece[lst_i:i])
                new_pieces.append(chunk)
                lst_i = i + 1
            elif chunk.isdigit() and i > 0 and not piece[i - 1].isdigit():
                # Boundary non-digit -> digit: flush the pending part, the digit starts a new one.
                if i > lst_i and piece[lst_i:i] != SPIECE_UNDERLINE:
                    new_pieces.append(piece[lst_i:i])
                lst_i = i
            elif not chunk.isdigit() and i > 0 and piece[i - 1].isdigit():
                # Boundary digit -> non-digit: flush the pending digits.
                if i > lst_i and piece[lst_i:i] != SPIECE_UNDERLINE:
                    new_pieces.append(piece[lst_i:i])
                lst_i = i
        # Emit whatever is left of the piece.
        if len(piece) > lst_i:
            new_pieces.append(piece[lst_i:])
    return new_pieces
|
| 219 |
+
|
| 220 |
+
def convert_tokens_to_string(self, tokens):
    """Join sub-word tokens into one string, turning each sentencepiece marker into a space."""
    glued = "".join(tokens)
    spaced = glued.replace(SPIECE_UNDERLINE, " ")
    # Drop the leading/trailing whitespace left over from the markers.
    return spaced.strip()
|
| 224 |
+
|
| 225 |
+
def convert_ids_to_string(self, ids):
    """
    Converts a sequence of token ids into a single string.

    The ids are mapped to tokens with `convert_ids_to_tokens` and the tokens are then joined by
    `convert_tokens_to_string`, so both entry points share the same joining rules.

    Args:
        ids: The token ids to convert.

    Returns:
        `str`: The decoded string.
    """
    tokens = self.convert_ids_to_tokens(ids)
    return self.convert_tokens_to_string(tokens)
|
| 232 |
+
|
| 233 |
+
# to mimic paddlenlp.transformers.ernie_m.tokenizer.ErnieMTokenizer functioning
|
| 234 |
+
def _convert_token_to_id(self, token):
|
| 235 |
+
return self.vocab.get(token, self.vocab.get(self.unk_token))
|
| 236 |
+
|
| 237 |
+
# to mimic paddlenlp.transformers.ernie_m.tokenizer.ErnieMTokenizer functioning
|
| 238 |
+
def _convert_id_to_token(self, index):
|
| 239 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
| 240 |
+
return self.reverse_vocab.get(index, self.unk_token)
|
| 241 |
+
|
| 242 |
+
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
    r"""
    Wrap one sequence, or a pair of sequences, of token ids with the `[CLS]` and `[SEP]` ids.

    - single sequence: `[CLS] X [SEP]`
    - pair of sequences: `[CLS] A [SEP] [SEP] B [SEP]`

    Args:
        token_ids_0 (`List[int]`):
            Ids of the first (or only) sequence.
        token_ids_1 (`List[int]`, *optional*):
            Ids of the second sequence when building a pair.
    Returns:
        `List[int]`: A new list of ids with the special tokens inserted.
    """
    if token_ids_1 is not None:
        cls_ids = [self.cls_token_id]
        sep_ids = [self.sep_token_id]
        # The two sequences are separated by a doubled [SEP].
        return cls_ids + token_ids_0 + sep_ids + sep_ids + token_ids_1 + sep_ids
    return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
| 263 |
+
|
| 264 |
+
def build_offset_mapping_with_special_tokens(self, offset_mapping_0, offset_mapping_1=None):
    r"""
    Insert a `(0, 0)` offset at every position where a special token is added to the sequence(s).

    - single sequence: `(0,0) X (0,0)`
    - pair of sequences: `(0,0) A (0,0) (0,0) B (0,0)`

    Args:
        offset_mapping_0 (`List[tuple]`):
            Char offsets of the first (or only) sequence.
        offset_mapping_1 (`List[tuple]`, *optional*):
            Char offsets of the second sequence when building a pair.
    Returns:
        `List[tuple]`: A new list of offsets including the `(0, 0)` entries of the special tokens.
    """
    special = (0, 0)
    combined = [special] + offset_mapping_0 + [special]
    if offset_mapping_1 is not None:
        # The second sequence is wrapped as well, which yields two consecutive (0, 0) in the middle.
        combined = combined + [special] + offset_mapping_1 + [special]
    return combined
|
| 284 |
+
|
| 285 |
+
def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
    r"""
    Build a mask marking which positions hold special tokens.

    Args:
        token_ids_0 (`List[int]`):
            Ids of the first sequence.
        token_ids_1 (`List[int]`, *optional*):
            Ids of the second sequence, for sequence pairs.
        already_has_special_tokens (`bool`, *optional*, defaults to `False`):
            Whether `token_ids_0` already contains the special tokens. In that case the mask is derived by
            comparing each id against `self.sep_token_id` and `self.cls_token_id`.

    Returns:
        `List[int]`: A list of 0/1 values: 1 for a special token, 0 for a sequence token.

    Raises:
        ValueError: If `already_has_special_tokens` is set and `token_ids_1` is also supplied.
    """
    if not already_has_special_tokens:
        # Layout mirrors `[CLS] A [SEP]` or `[CLS] A [SEP] [SEP] B [SEP]`.
        first_part = [0] * len(token_ids_0)
        if token_ids_1 is None:
            return [1] + first_part + [1]
        second_part = [0] * len(token_ids_1)
        return [1] + first_part + [1, 1] + second_part + [1]

    if token_ids_1 is not None:
        raise ValueError(
            "You should not supply a second sequence if the provided sequence of "
            "ids is already formatted with special tokens for the model."
        )
    return [int(token in (self.sep_token_id, self.cls_token_id)) for token in token_ids_0]
|
| 313 |
+
|
| 314 |
+
def create_token_type_ids_from_sequences(
    self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
    """
    Create the token type IDs for one sequence or a pair of sequences. [What are token type
    IDs?](../glossary#token-type-ids)

    The lengths account for the special tokens added by `build_inputs_with_special_tokens`, since this method is
    called when `add_special_tokens` is True.

    Args:
        token_ids_0 (`List[int]`):
            The first tokenized sequence.
        token_ids_1 (`List[int]`, *optional*):
            The second tokenized sequence.

    Returns:
        `List[int]`: The token type ids.
    """
    first_len = len(token_ids_0)
    if token_ids_1 is not None:
        # [CLS] A [SEP] [SEP] B [SEP]: `[CLS] A` is typed 0, everything from the first [SEP] onwards is typed 1.
        return [0] * (first_len + 1) + [1] * (len(token_ids_1) + 3)
    # [CLS] X [SEP]: all zeros.
    return [0] * (first_len + 2)
|
| 337 |
+
|
| 338 |
+
def is_ch_char(self, char):
    """
    Return whether `char` compares within the range U+4E00..U+9FFF (inclusive).
    """
    return "\u4e00" <= char <= "\u9fff"
|
| 345 |
+
|
| 346 |
+
def is_alpha(self, char):
    """
    Return whether `char` compares within `a`..`z` or `A`..`Z` (ASCII letters).
    """
    lowercase = "a" <= char <= "z"
    return lowercase or ("A" <= char <= "Z")
|
| 353 |
+
|
| 354 |
+
def is_punct(self, char):
    """
    Return whether `char` occurs in the fixed set of ASCII and full-width punctuation marks.

    The check is a substring test against the punctuation string.
    """
    return char in ",;:.?!~,;:。?!《》【】"
|
| 361 |
+
|
| 362 |
+
def is_whitespace(self, char):
    """
    Return whether `char` is whitespace.

    A space, tab, newline or carriage return counts as whitespace, as does any single character whose Unicode
    category is `Zs` (space separator).
    """
    if char in (" ", "\t", "\n", "\r"):
        return True
    # Only single characters can be looked up in the Unicode database.
    return len(char) == 1 and unicodedata.category(char) == "Zs"
|
| 373 |
+
|
| 374 |
+
def load_vocab(self, filepath):
    """
    Read a vocabulary file with one token per line and map each token to its zero-based line index.

    Only the trailing newline is stripped from each line. If a token appears on several lines, the last index
    wins.

    Args:
        filepath: Path of the UTF-8 encoded vocabulary file.

    Returns:
        `dict`: Mapping from token string to line index.
    """
    with io.open(filepath, "r", encoding="utf-8") as vocab_handle:
        return {line.rstrip("\n"): int(line_no) for line_no, line in enumerate(vocab_handle)}
|
| 382 |
+
|
| 383 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
    """
    Write the vocabulary (one token per line, ordered by index) and the serialized SentencePiece model to disk.

    Args:
        save_directory (`str`):
            Target directory. If it is not an existing directory, the value itself (with the optional prefix
            prepended) is used as the vocabulary file path.
        filename_prefix (`str`, *optional*):
            Prefix joined with `-` in front of the vocabulary file name.

    Returns:
        `Tuple[str]`: One-element tuple holding the path of the written vocabulary file.
    """
    prefix = filename_prefix + "-" if filename_prefix else ""
    if os.path.isdir(save_directory):
        vocab_file = os.path.join(save_directory, prefix + VOCAB_FILES_NAMES["vocab_file"])
    else:
        vocab_file = prefix + save_directory

    expected_index = 0
    with open(vocab_file, "w", encoding="utf-8") as writer:
        for token, token_index in sorted(self.vocab.items(), key=lambda entry: entry[1]):
            # A gap or repeat in the indices is reported but does not stop the export.
            if token_index != expected_index:
                logger.warning(
                    f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
                    " Please check that the vocabulary is not corrupted!"
                )
            writer.write(token + "\n")
            expected_index = token_index + 1

    # The SentencePiece model is always written under `save_directory`.
    tokenizer_model_file = os.path.join(save_directory, "sentencepiece.bpe.model")
    with open(tokenizer_model_file, "wb") as model_handle:
        model_handle.write(self.sp_model.serialized_model_proto())

    return (vocab_file,)
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
__all__ = ["ErnieMTokenizer"]
|
docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/__init__.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ....utils import _LazyModule
|
| 17 |
+
from ....utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Under static type checking, import every submodule eagerly so analyzers can see the public names.
if TYPE_CHECKING:
    from .configuration_gptsan_japanese import *
    from .modeling_gptsan_japanese import *
    from .tokenization_gptsan_japanese import *
else:
    import sys

    # At runtime, replace this module's entry in `sys.modules` with a `_LazyModule` built from the import
    # structure derived from this file (presumably so submodules are only imported on first use).
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/configuration_gptsan_japanese.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023, HuggingFace Inc.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""GPTSAN-japanese model configuration"""
|
| 16 |
+
|
| 17 |
+
from ....configuration_utils import PretrainedConfig
|
| 18 |
+
from ....utils import logging
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
logger = logging.get_logger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class GPTSanJapaneseConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`GPTSanJapaneseModel`]. It is used to instantiate
    a GPTSANJapanese model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the GPTSANJapanese
    [Tanrei/GPTSAN-japanese](https://huggingface.co/Tanrei/GPTSAN-japanese) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Arguments:
        vocab_size (`int`, *optional*, defaults to 36000):
            Vocabulary size of the GPTSANJapanese model. Defines the number of different tokens that can be represented
            by the `inputs_ids` passed when calling [`GPTSanJapaneseModel`].
        max_position_embeddings (`int`, *optional*, defaults to 1280):
            The maximum sequence length that this model might ever be used with. Defaults set this to 1280.
        d_model (`int`, *optional*, defaults to 1024):
            Size of the encoder layers and the pooler layer.
        d_ff (`int`, *optional*, defaults to 8192):
            Size of the intermediate feed forward layer in each `SwitchTransformersBlock`.
        d_ext (`int`, *optional*, defaults to 4096):
            Size of the intermediate feed forward layer in each Extra-layers.
        d_spout (`int`, *optional*, defaults to 128):
            Size of the `spout` vector.
        num_switch_layers (`int`, *optional*, defaults to 10):
            Number of layers in the Switch Transformer layer.
        num_ext_layers (`int`, *optional*, defaults to 0):
            Number of layers in the Extra-layers.
        num_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_experts (`int`, *optional*, defaults to 16):
            Number of experts for each SwitchTransformer layer.
        expert_capacity (`int`, *optional*, defaults to 128):
            Number of tokens that can be stored in each expert. If set to 1, the model will behave like a regular
            Transformer.
        dropout_rate (`float`, *optional*, defaults to 0.0):
            The ratio for all dropout layers.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
            The epsilon used by the layer normalization layers.
        router_bias (`bool`, *optional*, defaults to `False`):
            Whether to add a bias to the router.
        router_jitter_noise (`float`, *optional*, defaults to 0.0):
            Amount of noise to add to the router. Set it to 0.0 during prediction or set small value (usually 1e-2)
            during training.
        router_dtype (`str`, *optional*, defaults to `"float32"`):
            The `dtype` used for the routers. It is preferable to keep the `dtype` to `"float32"` as specified in the
            *selective precision* discussion in [the paper](https://arxiv.org/abs/2101.03961).
        router_ignore_padding_tokens (`bool`, *optional*, defaults to `False`):
            Whether to ignore padding tokens when routing.
        output_hidden_states (`bool`, *optional*, defaults to `False`):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        output_attentions (`bool`, *optional*, defaults to `False`):
            Whether or not to return the attentions tensors of all attention layers.
        initializer_factor (`float`, *optional*, defaults to 0.002):
            A factor for initializing all weight matrices.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether or not to return the router logits of all experts.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models)
        separator_token_id (`int`, *optional*, defaults to 35998):
            Id of the separator token. Forwarded to [`PretrainedConfig`].
        pad_token_id (`int`, *optional*, defaults to 35995):
            Id of the padding token. Forwarded to [`PretrainedConfig`].
        eos_token_id (`int`, *optional*, defaults to 35999):
            Id of the end-of-sequence token. Forwarded to [`PretrainedConfig`].
    """

    model_type = "gptsan-japanese"
    keys_to_ignore_at_inference = [
        "past_key_values",
    ]
    # Maps generic attribute names onto the names this config stores them under
    # (presumably resolved by `PretrainedConfig`; `num_layers` is set in `__init__`).
    attribute_map = {
        "hidden_size": "d_model",
        "num_attention_heads": "num_heads",
        "num_hidden_layers": "num_layers",
    }

    def __init__(
        self,
        vocab_size=36000,
        max_position_embeddings=1280,
        d_model=1024,
        d_ff=8192,
        d_ext=4096,
        d_spout=128,
        num_switch_layers=10,
        num_ext_layers=0,
        num_heads=16,
        num_experts=16,
        expert_capacity=128,
        dropout_rate=0.0,
        layer_norm_epsilon=1e-5,
        router_bias=False,
        router_jitter_noise=0.0,
        router_dtype="float32",
        router_ignore_padding_tokens=False,
        output_hidden_states=False,
        output_attentions=False,
        initializer_factor=0.002,
        output_router_logits=False,
        use_cache=True,
        separator_token_id=35998,
        pad_token_id=35995,
        eos_token_id=35999,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.d_model = d_model
        self.d_ff = d_ff
        self.d_ext = d_ext
        self.d_spout = d_spout
        self.num_switch_layers = num_switch_layers
        self.num_ext_layers = num_ext_layers
        # Total depth is derived, not configurable on its own: switch layers plus extra layers.
        self.num_layers = num_switch_layers + num_ext_layers
        self.num_heads = num_heads
        self.num_experts = num_experts
        self.expert_capacity = expert_capacity
        self.dropout_rate = dropout_rate
        self.layer_norm_epsilon = layer_norm_epsilon
        self.router_bias = router_bias
        self.router_jitter_noise = router_jitter_noise
        self.router_dtype = router_dtype
        self.router_ignore_padding_tokens = router_ignore_padding_tokens
        self.output_hidden_states = output_hidden_states
        self.output_attentions = output_attentions
        self.initializer_factor = initializer_factor
        self.output_router_logits = output_router_logits
        self.use_cache = use_cache

        # Token ids and any remaining keyword arguments are handled by the base class.
        super().__init__(
            separator_token_id=separator_token_id,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
__all__ = ["GPTSanJapaneseConfig"]
|
docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/convert_gptsan_tf_checkpoint_to_pytorch.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Convert GPTSANJapanese checkpoints from the original repository to pytorch model."""
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
import json
|
| 20 |
+
import os
|
| 21 |
+
from collections import OrderedDict
|
| 22 |
+
|
| 23 |
+
import numpy as np
|
| 24 |
+
import tensorflow as tf
|
| 25 |
+
import torch
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def convert_tf_gptsan_to_pt(args):
    """
    Convert a GPTSAN TensorFlow checkpoint directory into a PyTorch state dict saved with `torch.save`.

    Every checkpoint variable is read as float16, renamed to the corresponding PyTorch parameter name and, where
    needed, reshaped/transposed. Adam optimizer slots and variables with unrecognised names are skipped.

    Args:
        args: Namespace with two attributes:
            - `tf_model_dir`: directory holding the TensorFlow checkpoint and a `parameters.json` file.
            - `output`: destination path; `.pt` is appended (and written back to `args.output`) if missing.

    Raises:
        ValueError: If `parameters.json` parses to an empty value.
    """
    parameter_file = os.path.join(args.tf_model_dir, "parameters.json")
    # Context manager so the handle is closed instead of being left to the garbage collector.
    with open(parameter_file) as parameter_handle:
        params = json.loads(parameter_handle.read())
    if not params:
        raise ValueError(
            f"It seems that the json file at {parameter_file} is empty. Make sure you have a correct json file."
        )
    if not args.output.endswith(".pt"):
        args.output = args.output + ".pt"
    new_state = OrderedDict()
    with tf.device("/CPU:0"):
        reader = tf.train.load_checkpoint(args.tf_model_dir)
        shapes = reader.get_variable_to_shape_map()
        for key_name in shapes.keys():
            # Skip optimizer slots before paying for the tensor read and the float16 cast.
            if key_name.endswith(("/adam_m", "/adam_v")):
                continue
            vnp = reader.get_tensor(key_name).astype(np.float16)
            if key_name.startswith("pasts/"):
                if key_name.startswith("pasts/mlp"):
                    player = int(key_name[9])
                elif key_name.startswith("pasts/out"):
                    player = 8
                else:
                    # Unknown `pasts/*` variable: skip it like any other unrecognised key rather than
                    # reusing a stale (or undefined) layer index.
                    continue
                name = "model.sqout.%d.weight" % (player * 2)  # layers sit in an nn.Sequential with Tanh, so step by 2
                state = vnp.transpose([1, 0]).copy()  # kernel is transposed to get the PyTorch weight layout
                new_state[name] = torch.tensor(state)
            elif key_name.startswith("model/moe"):
                player = int(key_name[9:].split("/")[0])
                if key_name.endswith("/switch_gating/kernel"):
                    name = "model.blocks.%d.feed_forward.mlp.router.classifier.weight" % player
                    state = vnp.transpose([1, 0]).copy()  # kernel is transposed to get the PyTorch weight layout
                    new_state[name] = torch.tensor(state)
                elif key_name.endswith("/softmlp/kernel"):
                    name = "model.blocks.%d.feed_forward.soft_bypass_mlp.weight" % player
                    state = vnp.transpose([1, 0]).copy()  # kernel is transposed to get the PyTorch weight layout
                    new_state[name] = torch.tensor(state)
                elif key_name.endswith(("/wo/kernel", "/wi/kernel")):
                    nlayer = key_name[-9:-7]  # "wo" or "wi"
                    # All experts are stored in one array; split it along the leading (expert) axis.
                    # The expert count comes from the tensor itself instead of a hard-coded 16.
                    for i in range(vnp.shape[0]):
                        name = "model.blocks.%d.feed_forward.mlp.experts.expert_%d.%s.weight" % (player, i, nlayer)
                        state = vnp[i].transpose([1, 0]).copy()
                        new_state[name] = torch.tensor(state)
            elif key_name.startswith("model/mlp"):
                player = int(key_name[9:].split("/")[0])
                if key_name.endswith("/p1/kernel"):
                    name = "model.blocks.%d.feed_forward.mlp.wi.weight" % player
                    state = vnp.transpose([1, 0]).copy()  # kernel is transposed to get the PyTorch weight layout
                    new_state[name] = torch.tensor(state)
                elif key_name.endswith("/p1/bias"):
                    name = "model.blocks.%d.feed_forward.mlp.wi.bias" % player
                    state = vnp.copy()  # one-dimensional, copied as-is
                    new_state[name] = torch.tensor(state)
                elif key_name.endswith("/p2/kernel"):
                    name = "model.blocks.%d.feed_forward.mlp.wo.weight" % player
                    state = vnp.transpose([1, 0]).copy()  # kernel is transposed to get the PyTorch weight layout
                    new_state[name] = torch.tensor(state)
                elif key_name.endswith("/p2/bias"):
                    name = "model.blocks.%d.feed_forward.mlp.wo.bias" % player
                    state = vnp.copy()  # one-dimensional, copied as-is
                    new_state[name] = torch.tensor(state)
            elif key_name.startswith("model/ln"):
                player = int(key_name[8:].split("/")[0])
                if key_name.endswith("/b"):
                    name = "model.blocks.%d.feed_forward.norm.bias" % player
                    state = vnp.copy()  # one-dimensional, copied as-is
                    new_state[name] = torch.tensor(state)
                elif key_name.endswith("/g"):
                    name = "model.blocks.%d.feed_forward.norm.weight" % player
                    state = vnp.copy()  # one-dimensional, copied as-is
                    new_state[name] = torch.tensor(state)
            elif key_name.startswith("model/att"):
                player = int(key_name[9:].split("/")[0])
                if key_name.endswith("/qkv/kernel"):
                    state = vnp.copy()
                    # Axis 1 selects q/k/v; the two trailing axes of each slice are merged, then the
                    # resulting matrix is transposed.
                    state_q = state[:, 0, :, :]
                    state_k = state[:, 1, :, :]
                    state_v = state[:, 2, :, :]
                    state_q = (
                        state_q.reshape([state_q.shape[0], state_q.shape[1] * state_q.shape[2]])
                        .transpose([1, 0])
                        .copy()
                    )
                    state_k = (
                        state_k.reshape([state_k.shape[0], state_k.shape[1] * state_k.shape[2]])
                        .transpose([1, 0])
                        .copy()
                    )
                    state_v = (
                        state_v.reshape([state_v.shape[0], state_v.shape[1] * state_v.shape[2]])
                        .transpose([1, 0])
                        .copy()
                    )
                    name = "model.blocks.%d.self_attn.self_attn.q_proj.weight" % player
                    new_state[name] = torch.tensor(state_q)
                    name = "model.blocks.%d.self_attn.self_attn.k_proj.weight" % player
                    new_state[name] = torch.tensor(state_k)
                    name = "model.blocks.%d.self_attn.self_attn.v_proj.weight" % player
                    new_state[name] = torch.tensor(state_v)
                elif key_name.endswith("/o/kernel"):
                    name = "model.blocks.%d.self_attn.self_attn.out_proj.weight" % player
                    # Merge the two leading axes, then transpose.
                    state = vnp.reshape([vnp.shape[0] * vnp.shape[1], vnp.shape[2]]).transpose([1, 0]).copy()
                    new_state[name] = torch.tensor(state)
            elif key_name.startswith("model/an"):
                player = int(key_name[8:].split("/")[0])
                if key_name.endswith("/b"):
                    name = "model.blocks.%d.self_attn.norm.bias" % player
                    state = vnp.copy()  # one-dimensional, copied as-is
                    new_state[name] = torch.tensor(state)
                elif key_name.endswith("/g"):
                    name = "model.blocks.%d.self_attn.norm.weight" % player
                    state = vnp.copy()  # one-dimensional, copied as-is
                    new_state[name] = torch.tensor(state)
            elif key_name.startswith(("model/wte", "model/wpe", "model/ete")):
                nlayer = {"wte": "embed_tokens", "wpe": "position_embeddings", "ete": "extra_position_embeddings"}[
                    key_name[-3:]
                ]
                name = "model.%s.weight" % nlayer
                state = vnp.copy()  # embeddings are copied unchanged
                new_state[name] = torch.tensor(state)
                if key_name.startswith("model/wte"):
                    # The token embedding matrix is also stored as the LM head weight.
                    name = "lm_head.weight"
                    state = vnp.copy()
                    new_state[name] = torch.tensor(state)
            elif key_name.startswith("model/wob"):
                name = "final_logits_bias"
                state = vnp.copy()
                state = state.reshape((1, -1))  # stored as a single row
                new_state[name] = torch.tensor(state)
            elif key_name == "model/dense/kernel":
                name = "model.last_project.weight"
                state = vnp.transpose([1, 0]).copy()  # kernel is transposed to get the PyTorch weight layout
                new_state[name] = torch.tensor(state)
            # NOTE(review): the bias is read from "dense_1" while the kernel comes from "dense" — confirm
            # against the original checkpoint naming.
            elif key_name == "model/dense_1/bias":
                name = "model.last_project.bias"
                state = vnp.copy()  # one-dimensional, copied as-is
                new_state[name] = torch.tensor(state)
    torch.save(new_state, args.output)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
if __name__ == "__main__":
    # Command-line entry point: both paths are mandatory.
    parser = argparse.ArgumentParser(
        description="model converter.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    for flag, help_text in (("--tf_model_dir", "import model"), ("--output", "output model")):
        parser.add_argument(flag, metavar="PATH", type=str, required=True, help=help_text)
    args = parser.parse_args()
    convert_tf_gptsan_to_pt(args)
|
docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py
ADDED
|
@@ -0,0 +1,1337 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 Toshiyuki Sakamoto(tanreinama) and HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch GPTSANJapanese model."""
|
| 16 |
+
|
| 17 |
+
import copy
|
| 18 |
+
from typing import List, Optional, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
|
| 23 |
+
from ....activations import ACT2FN
|
| 24 |
+
from ....modeling_outputs import MoECausalLMOutputWithPast, MoEModelOutputWithPastAndCrossAttentions
|
| 25 |
+
from ....modeling_utils import PreTrainedModel
|
| 26 |
+
from ....utils import (
|
| 27 |
+
DUMMY_INPUTS,
|
| 28 |
+
DUMMY_MASK,
|
| 29 |
+
add_start_docstrings,
|
| 30 |
+
add_start_docstrings_to_model_forward,
|
| 31 |
+
is_torch_fx_proxy,
|
| 32 |
+
logging,
|
| 33 |
+
)
|
| 34 |
+
from .configuration_gptsan_japanese import GPTSanJapaneseConfig
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
logger = logging.get_logger(__name__)
|
| 38 |
+
|
| 39 |
+
_CONFIG_FOR_DOC = "GPTSanJapaneseConfig"
|
| 40 |
+
_CHECKPOINT_FOR_DOC = "Tanrei/GPTSAN-japanese"
|
| 41 |
+
|
| 42 |
+
####################################################
|
| 43 |
+
# This dict contains ids and associated url
|
| 44 |
+
# for the pretrained weights provided with the models
|
| 45 |
+
####################################################
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def router_z_loss_func(router_logits: torch.Tensor) -> torch.Tensor:
    r"""
    Compute the router z-loss implemented in PyTorch.

    The router z-loss was introduced in [Designing Effective Sparse Expert Models](https://arxiv.org/abs/2202.08906).
    It encourages router logits to remain small in an effort to improve stability.

    Args:
        router_logits (`torch.Tensor`):
            Input logits of shape [batch_size, sequence_length, num_experts]. Any number of leading dimensions is
            accepted; the last dimension is always reduced as the expert axis.

    Returns:
        `torch.Tensor`: Scalar (0-dim) router z-loss, i.e. the mean over all tokens of `logsumexp(logits) ** 2`.
    """
    # One log-partition value per token: only the trailing (expert) axis is reduced.
    log_z = torch.logsumexp(router_logits, dim=-1)
    z_loss = log_z**2
    # `log_z.numel()` is the number of tokens (num_groups * tokens_per_group for a 3-D input).
    return torch.sum(z_loss) / log_z.numel()
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.Tensor) -> torch.Tensor:
    r"""
    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.

    See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
    experts is too unbalanced.

    Args:
        router_probs (`torch.Tensor`):
            Probability assigned to each expert per token. Shape: [batch_size, sequence_length, num_experts].
        expert_indices (`torch.Tensor`):
            Indices tensor of shape [batch_size, sequence_length] identifying the selected expert for a given token.
            Any integer dtype is accepted; it is cast to `int64` internally.

    Returns:
        `torch.Tensor`: The scalar (0-dim) auxiliary loss.
    """
    num_experts = router_probs.shape[-1]

    # cast the expert indices to int64, otherwise one-hot encoding will fail
    if expert_indices.dtype != torch.int64:
        expert_indices = expert_indices.to(torch.int64)

    # Add a "selected experts per token" axis so that 2-D indices become [batch, seq, 1].
    if len(expert_indices.shape) == 2:
        expert_indices = expert_indices.unsqueeze(2)

    expert_mask = torch.nn.functional.one_hot(expert_indices, num_experts)

    # For a given token, determine if it was routed to a given expert.
    expert_mask = torch.max(expert_mask, dim=-2).values

    # cast to float32 otherwise mean will fail
    expert_mask = expert_mask.to(torch.float32)
    # Fraction of tokens dispatched to each expert, per group.
    tokens_per_group_and_expert = torch.mean(expert_mask, dim=-2)

    # Mean router probability received by each expert, per group.
    router_prob_per_group_and_expert = torch.mean(router_probs, dim=-2)
    return torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) * (num_experts**2)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class GPTSanJapaneseDenseActDense(nn.Module):
    """
    Two-layer feed-forward network (Linear -> activation -> dropout -> Linear).

    GPTSAN can mix Switch Transformer layers and normal Transformer layers. This class serves both as an expert in
    Switch Transformer layers and as the FFN in regular (extra) Transformer layers. With `ext_layer=False` the block
    uses `config.d_ff`, ReLU, dropout and bias-free projections; with `ext_layer=True` it uses `config.d_ext`, Swish,
    no dropout and biased projections.
    """

    def __init__(self, config: GPTSanJapaneseConfig, ext_layer=False):
        super().__init__()
        # Pick the inner width and activation name for the requested flavour of the block.
        if ext_layer:
            inner_dim = config.d_ext
            activation_name = "swish"
        else:
            inner_dim = config.d_ff
            activation_name = "relu"

        self.wi = nn.Linear(config.d_model, inner_dim, bias=ext_layer)
        self.wo = nn.Linear(inner_dim, config.d_model, bias=ext_layer)
        # Extra layers skip dropout entirely; an Identity keeps `forward` branch-free.
        if ext_layer:
            self.dropout = nn.Identity()
        else:
            self.dropout = nn.Dropout(config.dropout_rate)
        self.act = ACT2FN[activation_name]

    def forward(self, hidden_states):
        r"""
        Args:
            hidden_states (`torch.Tensor`) :
                [num_groups, tokens_per_group, hidden_dim] inputs to send to experts.
        Returns:
            torch.Tensor[num_groups, tokens_per_group, hidden_dim]

        """
        projected = self.wi(hidden_states)
        activated = self.dropout(self.act(projected))
        return self.wo(activated)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class GPTSanJapaneseTop1Router(nn.Module):
    """
    Router using tokens choose top-1 experts assignment.

    This router uses the same mechanism as in Switch Transformer (https://arxiv.org/abs/2101.03961) and V-MoE
    (https://arxiv.org/abs/2106.05974): tokens choose their top expert. Within each sequence, tokens are admitted to
    their chosen expert in sequence order until the expert's `expert_capacity` is reached. **There is no guarantee
    that each token is processed by an expert**, or that each expert receives at least one token.

    """

    def __init__(self, config: GPTSanJapaneseConfig):
        super().__init__()
        self.num_experts = config.num_experts
        self.expert_capacity = config.expert_capacity
        self.classifier = nn.Linear(config.hidden_size, self.num_experts, bias=config.router_bias)
        self.jitter_noise = config.router_jitter_noise
        # Stored from the config; not read by any method of this class.
        self.ignore_padding_tokens = config.router_ignore_padding_tokens
        # `config.router_dtype` is the name of a torch dtype (e.g. "float32").
        self.dtype = getattr(torch, config.router_dtype)

    def _compute_router_probabilities(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Computes router probabilities from input hidden states.

        The input tensor is never modified: jitter noise is applied to a new tensor.

        Args:
            hidden_states (`torch.Tensor`):
                (batch_size, sequence_length, hidden_dim) from which router probabilities are computed.
        Returns:
            router_probabilities (`torch.Tensor`):
                Tensor of shape (batch_size, sequence_length, num_experts) corresponding to the probabilities for each
                token and expert. Used for routing tokens to experts.
            router_logits (`torch.Tensor`):
                Logits tensor of shape (batch_size, sequence_length, num_experts) corresponding to raw router logits.
                This is used later for computing router z-loss.
        """
        # float32 is used to ensure stability. See the discussion of "selective precision" in
        # https://arxiv.org/abs/2101.03961.
        # We also store the previous dtype to cast back the output to the previous dtype
        self.input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(self.dtype)

        if self.training and self.jitter_noise > 0:
            # Multiply the token inputs by the uniform distribution - adding some noise.
            # This must be out-of-place: when the input already has `self.dtype`, `.to()` returns the very same
            # tensor, and an in-place `*=` would leak the noise into the caller's hidden states.
            noise = torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
            hidden_states = hidden_states * noise

        # Shape: [num_groups, tokens_per_group, num_experts]
        self._cast_classifier()
        router_logits = self.classifier(hidden_states)

        # Apply Softmax and cast back to the original `dtype`
        router_probabilities = nn.functional.softmax(router_logits, dim=-1, dtype=self.dtype).to(self.input_dtype)
        return router_probabilities, router_logits

    def _cast_classifier(self):
        r"""
        `bitsandbytes` `Linear8bitLt` layers does not support manual casting Therefore we need to check if they are an
        instance of the `Linear8bitLt` class by checking special attributes.
        """
        if not (hasattr(self.classifier, "SCB") or hasattr(self.classifier, "CB")):
            self.classifier = self.classifier.to(self.dtype)

    def forward(self, hidden_states: torch.Tensor) -> Tuple:
        r"""
        Route every token to its highest-probability expert, dropping tokens that exceed the expert capacity.

        The router probabilities and logits are computed from the hidden states, the arg-max expert is selected for
        each token, and a running per-expert count along the sequence axis is compared against `expert_capacity` to
        mask out overflowing tokens.

        Args:
            hidden_states (`torch.Tensor`) :
                [num_groups, tokens_per_group, hidden_dim] inputs to send to experts.
        Returns:
            Tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`] Tuple containing:
            the one-hot expert assignment of shape [num_groups, tokens_per_group, num_experts] (all zeros for a token
            dropped because of capacity), the probability of the selected expert of shape [num_groups,
            tokens_per_group, 1], and the raw router logits. The router probabilities and logits are required to
            compute the loss.
        """
        router_probs, router_logits = self._compute_router_probabilities(hidden_states)

        expert_index = torch.argmax(router_probs, dim=-1)
        expert_index = torch.nn.functional.one_hot(expert_index, num_classes=self.num_experts)

        # Mask tokens outside expert capacity. Sum over each sequence
        token_priority = torch.cumsum(expert_index, dim=-2)
        # mask if the token routed to the expert will overflow
        expert_capacity_mask = token_priority <= self.expert_capacity
        expert_index = expert_index * expert_capacity_mask

        # Keep only the probability of the winning expert, with a trailing axis for broadcasting.
        router_probs = torch.max(router_probs, dim=-1).values.unsqueeze(-1)
        return expert_index, router_probs, router_logits
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
class GPTSanJapaneseSparseMLP(nn.Module):
    r"""
    Mixture-of-experts feed-forward module in the style of Switch Transformers: a top-1 router plus `num_experts`
    independent expert networks.
    """

    def __init__(self, config: GPTSanJapaneseConfig, expert_class: nn.Module = GPTSanJapaneseDenseActDense):
        super().__init__()
        # The router decides which expert (if any) handles each token.
        self.router = GPTSanJapaneseTop1Router(config)

        # Experts are registered as "expert_0", "expert_1", ... in creation order.
        self.experts = nn.ModuleDict()
        expert_count = config.num_experts
        for expert_id in range(expert_count):
            self.experts[f"expert_{expert_id}"] = expert_class(config)

    def forward(self, hidden_states):
        r"""
        Run the routed experts over the tokens.

        1- The router returns a one-hot `router_mask` of shape `(batch_size, sequence_length, num_expert)`, the
        probability of the selected expert for every token, and the raw router logits.

        2- Each expert is applied to the tokens the mask assigns to it. Tokens not assigned to any expert keep their
        incoming hidden state.

        3- Every token (processed or not) is scaled by its router probability.

        Returns:
            A tuple `(hidden_states, (router_logits, expert_index))`, where `expert_index` is the arg-max of the
            router mask over the expert axis.
        """
        dispatch_mask, gate_values, router_logits = self.router(hidden_states)
        expert_index = dispatch_mask.argmax(dim=-1)

        # Start from a copy so that tokens no expert picks up pass through unchanged.
        routed_states = hidden_states.clone()
        for position, expert in enumerate(self.experts.values()):
            picked = dispatch_mask[:, :, position].bool()
            expert_result = expert(hidden_states[picked])
            routed_states[picked] = expert_result.to(routed_states.dtype)

        return gate_values * routed_states, (router_logits, expert_index)
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
class GPTSanJapaneseLayerSparseFF(nn.Module):
    r"""
    Sparse feed-forward sub-layer: a Mixture of Experts module combined with a tanh-squashed linear bypass, followed
    by layer normalization and a residual connection.

    Parameters:
        config : ([`GPTSanJapaneseConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
    """

    def __init__(self, config: GPTSanJapaneseConfig):
        super().__init__()
        width = config.d_model
        self.mlp = GPTSanJapaneseSparseMLP(config)
        self.soft_bypass_mlp = nn.Linear(width, width, bias=False)
        self.norm = nn.LayerNorm(width, eps=config.layer_norm_epsilon)

    def forward(self, hidden_states, output_router_logits):
        r"""
        Args:
            hidden_states (`torch.Tensor`) :
                [num_groups, tokens_per_group, hidden_dim] inputs to send to experts.
            output_router_logits (`bool`) :
                output experts router output.
        Returns:
            torch.Tensor[num_groups, tokens_per_group, hidden_dim], or a tuple of that tensor and the router tuple
            when `output_router_logits` is truthy.

        """
        expert_states, router_tuple = self.mlp(hidden_states)
        # Dense bypass path added on top of the expert output before normalization.
        expert_states += torch.tanh(self.soft_bypass_mlp(hidden_states))
        output = hidden_states + self.norm(expert_states)

        if not output_router_logits or router_tuple is None:
            return output
        return output, router_tuple
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
class GPTSanJapaneseLayerDenseFF(nn.Module):
    r"""
    Dense feed-forward sub-layer used by the extra layers: FFN, layer normalization, then a residual connection.

    Parameters:
        config : ([`GPTSanJapaneseConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
    """

    def __init__(self, config: GPTSanJapaneseConfig):
        super().__init__()
        # Always the dense ("extra layer") flavour of the FFN.
        self.mlp = GPTSanJapaneseDenseActDense(config, ext_layer=True)
        self.norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)

    def forward(self, hidden_states):
        r"""
        Args:
            hidden_states (`torch.Tensor`) :
                [num_groups, tokens_per_group, hidden_dim] inputs to send to experts.
        Returns:
            torch.Tensor[num_groups, tokens_per_group, hidden_dim]

        """
        normalized = self.norm(self.mlp(hidden_states))
        return hidden_states + normalized
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
class GPTSanJapaneseAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[GPTSanJapaneseConfig] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        # The embedding must split evenly across the heads.
        if self.embed_dim != (self.head_dim * num_heads):
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # (bsz, seq_len, embed_dim) -> (bsz, num_heads, seq_len, head_dim)
        per_head = tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""
        batch, q_len, _ = hidden_states.size()
        heads = self.num_heads

        # Queries are scaled up front, before the dot product with the keys.
        query_states = self.q_proj(hidden_states) * self.scaling

        if key_value_states is not None:
            # Cross-attention. A cached key/value pair is reused only when its sequence length matches the
            # provided `key_value_states` (this length check supports prefix tuning).
            cache_usable = past_key_value is not None and past_key_value[0].shape[2] == key_value_states.shape[1]
            if cache_usable:
                key_states = past_key_value[0]
                value_states = past_key_value[1]
            else:
                key_states = self._shape(self.k_proj(key_value_states), -1, batch)
                value_states = self._shape(self.v_proj(key_value_states), -1, batch)
        else:
            # Self-attention: project the current tokens, then prepend any cached keys/values.
            key_states = self._shape(self.k_proj(hidden_states), -1, batch)
            value_states = self._shape(self.v_proj(hidden_states), -1, batch)
            if past_key_value is not None:
                key_states = torch.cat([past_key_value[0], key_states], dim=2)
                value_states = torch.cat([past_key_value[1], value_states], dim=2)

        if self.is_decoder:
            # Decoders hand back the (possibly extended) keys/values so later calls can reuse them; otherwise
            # `past_key_value` is returned exactly as it was received.
            past_key_value = (key_states, value_states)

        # Fold the head axis into the batch axis for the batched matmuls.
        folded_shape = (batch * heads, -1, self.head_dim)
        query_states = self._shape(query_states, q_len, batch).view(*folded_shape)
        key_states = key_states.reshape(*folded_shape)
        value_states = value_states.reshape(*folded_shape)

        kv_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        folded_weights_size = (batch * heads, q_len, kv_len)
        if attn_weights.size() != folded_weights_size:
            raise ValueError(
                f"Attention weights should be of size {folded_weights_size}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            required_mask_size = (batch, 1, q_len, kv_len)
            if attention_mask.size() != required_mask_size:
                raise ValueError(
                    f"Attention mask should be of size {required_mask_size}, but is {attention_mask.size()}"
                )
            # The mask is additive and broadcast over the head axis.
            masked = attn_weights.view(batch, heads, q_len, kv_len) + attention_mask
            attn_weights = masked.view(batch * heads, q_len, kv_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            per_head_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(batch, heads, q_len, kv_len)
            attn_weights = per_head_weights.view(batch * heads, q_len, kv_len)

        attn_weights_reshaped = None
        if output_attentions:
            # The returned weights and the ones used below are views of the same tensor, so the returned
            # attention keeps its gradient.
            attn_weights_reshaped = attn_weights.view(batch, heads, q_len, kv_len)
            attn_weights = attn_weights_reshaped.view(batch * heads, q_len, kv_len)

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = torch.bmm(attn_probs, value_states)

        folded_output_size = (batch * heads, q_len, self.head_dim)
        if attn_output.size() != folded_output_size:
            raise ValueError(
                f"`attn_output` should be of size {folded_output_size}, but is"
                f" {attn_output.size()}"
            )

        # Unfold the heads and merge them back into the channel axis. `self.embed_dim` (rather than the size of
        # `hidden_states`) is used because `attn_output` can be partitioned across GPUs under tensor-parallelism.
        attn_output = attn_output.view(batch, heads, q_len, self.head_dim).transpose(1, 2)
        attn_output = attn_output.reshape(batch, q_len, self.embed_dim)

        return self.out_proj(attn_output), attn_weights_reshaped, past_key_value
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
class GPTSanJapaneseLayerSelfAttention(nn.Module):
    """
    Self Attention and Normalization Unit
    """

    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.self_attn = GPTSanJapaneseAttention(
            embed_dim=config.d_model,
            num_heads=config.num_heads,
            is_decoder=True,
            # NOTE: despite its name, this flag only toggles the bias term of the attention projections.
            bias=has_relative_attention_bias,
        )
        self.norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        r"""
        Self-attention and normalize block.

        Computes `hidden_states + LayerNorm(self_attention(hidden_states))`.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Input hidden states of the block.
            past_key_value (`tuple(torch.FloatTensor)`, *optional*):
                Cached key and value states from previous decoding steps. Only the first two entries are used;
                they are concatenated in front of the keys/values computed for the current `hidden_states`.
            attention_mask (`torch.FloatTensor`, *optional*):
                Multiplicative mask with values in `[0, 1]`:

                - 1 for positions that are **not masked**,
                - 0 for positions that are **masked**.

                It is converted to an additive mask before being handed to the attention module, which requires
                the shape `(batch_size, 1, target_length, source_length)`. When `None`, no masking is applied.
            head_mask (`torch.FloatTensor` of shape `(num_heads,)`, *optional*):
                Mask to nullify selected heads of the attention module. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            use_cache (`bool`, *optional*):
                If set to `True`, the present key/value states are returned and can be used to speed up decoding.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attention weights of the attention layer.
        Returns:
            A tuple `(hidden,)`, followed by the present key/value pair if `use_cache` is set, followed by the
            attention weights if `output_attentions` is set.
        """
        # Self Attention
        # Only the first two cache entries (key, value) belong to the self-attention.
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None

        # Turn the 1/0 "keep" mask into an additive mask: 0 where attention is allowed, the most negative
        # representable value where it is not. A missing mask is forwarded as-is (no masking) instead of
        # failing on `1 - None`.
        additive_mask = None
        if attention_mask is not None:
            additive_mask = (1 - attention_mask) * torch.finfo(hidden_states.dtype).min

        atten_out = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=self_attn_past_key_value,
            attention_mask=additive_mask,
            layer_head_mask=head_mask,
            output_attentions=output_attentions,
        )
        if output_attentions:
            attn_weights = (atten_out[1],)
        else:
            attn_weights = ()

        attention_output = atten_out[0]

        # Residual connection around the normalized attention output.
        hidden = hidden_states + self.norm(attention_output)

        if use_cache:
            outputs = (hidden, atten_out[2])  # hidden, present, (attentions)
        else:
            outputs = (hidden,)  # hidden, (attentions)

        return outputs + attn_weights
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
class GPTSanJapaneseBlock(nn.Module):
    """
    A single GPTSAN transformer block: a self-attention sub-layer followed by a feed-forward sub-layer.
    """

    def __init__(self, config, ext_layer=False):
        super().__init__()
        self.self_attn = GPTSanJapaneseLayerSelfAttention(config)
        # Extra ("ext") layers get the dense feed-forward; every other layer gets the sparse one.
        if ext_layer:
            self.feed_forward = GPTSanJapaneseLayerDenseFF(config)
        else:
            self.feed_forward = GPTSanJapaneseLayerSparseFF(config)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        output_router_tuple: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        r"""
        Run self-attention and then the feed-forward sub-layer of this block.

        Args:
            hidden_states (`torch.FloatTensor`):
                Input hidden states, forwarded unchanged to the self-attention sub-layer.
            past_key_value (`Tuple[torch.Tensor]`, *optional*):
                Cached key/value states for this layer, forwarded to the self-attention sub-layer.
            attention_mask (`torch.FloatTensor`, *optional*):
                Attention mask forwarded to the self-attention sub-layer (1 for tokens that are **not masked**, 0
                for tokens that are **masked**).
            head_mask (`torch.FloatTensor`, *optional*):
                Mask used to nullify selected attention heads, forwarded to the self-attention sub-layer.
            use_cache (`bool`, *optional*):
                Forwarded to the self-attention sub-layer, which then also returns its key/value states.
            output_attentions (`bool`, *optional*):
                Forwarded to the self-attention sub-layer, which then also returns its attention probabilities.
            output_router_tuple (`bool`, *optional*):
                If truthy and this block uses the sparse feed-forward, the router tuple produced by the
                feed-forward is appended as the last element of the returned tuple.

        Returns:
            `tuple`: `(hidden,)`, followed by every extra element returned by the self-attention sub-layer,
            followed by the router tuple when it was requested and this block is sparse.
        """
        attn_results = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        # First element is the attended hidden state; the remainder is passed through untouched.
        attn_hidden = attn_results[0]
        attn_extras = attn_results[1:]

        router_extras = ()
        if isinstance(self.feed_forward, GPTSanJapaneseLayerSparseFF):
            ff_result = self.feed_forward(attn_hidden, output_router_tuple)
            if output_router_tuple:
                # The sparse feed-forward returns (hidden, router_tuple) when the router output is requested.
                hidden, router_tuple = ff_result
                router_extras = (router_tuple,)
            else:
                hidden = ff_result
        else:
            hidden = self.feed_forward(attn_hidden)

        return (hidden,) + attn_extras + router_extras
|
| 668 |
+
|
| 669 |
+
|
| 670 |
+
class GPTSanJapanesePreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = GPTSanJapaneseConfig
    base_model_prefix = "gptsan_japanese"
    supports_gradient_checkpointing = False
    _no_split_modules = ["GPTSanJapaneseBlock"]
    _skip_keys_device_placement = "past_key_values"

    @property
    def dummy_inputs(self):
        """Return a dict with `input_ids` and `attention_mask` tensors built from the module-level dummy constants."""
        input_ids = torch.tensor(DUMMY_INPUTS)
        input_mask = torch.tensor(DUMMY_MASK)
        dummy_inputs = {
            "input_ids": input_ids,
            "attention_mask": input_mask,
        }
        return dummy_inputs

    def _init_weights(self, module):
        """
        Initialize the weights of `module` in place.

        The first matching branch wins, so the order of the `isinstance` checks below matters. Every standard
        deviation is scaled by `config.initializer_factor`.
        """
        factor = self.config.initializer_factor  # Used for testing weights initialization
        if isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(factor * 1.0)
            module.bias.data.zero_()
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module, "bias") and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=factor * 1.0)
        elif isinstance(module, GPTSanJapaneseModel):
            # Mesh TensorFlow embeddings initialization
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
            module.embed_tokens.weight.data.normal_(mean=0.0, std=factor * 1.0)
            module.position_embeddings.weight.data.normal_(mean=0.0, std=factor * 1.0)
            if hasattr(module, "extra_position_embeddings") and module.extra_position_embeddings is not None:
                module.extra_position_embeddings.weight.data.normal_(mean=0.0, std=factor * 1.0)
        elif isinstance(module, GPTSanJapaneseForConditionalGeneration):
            # `GPTSanJapaneseModel` instances are already handled by the branch above, so only the
            # conditional-generation wrapper can reach this point.
            # Mesh TensorFlow embeddings initialization
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
            module.final_logits_bias.data.normal_(mean=0.0, std=factor * 1.0)
            if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
                module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
        elif isinstance(module, GPTSanJapaneseDenseActDense):
            # Mesh TensorFlow FF initialization
            # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
            # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
            module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.wi, "bias") and module.wi.bias is not None:
                module.wi.bias.data.zero_()
            module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
            if hasattr(module.wo, "bias") and module.wo.bias is not None:
                module.wo.bias.data.zero_()
        elif isinstance(module, GPTSanJapaneseAttention):
            # Multi-headed attention
            # Mesh TensorFlow attention initialization to avoid scaling before softmax
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
            d_model = self.config.d_model
            key_value_proj_dim = self.config.d_model
            n_heads = self.config.num_heads
            module.k_proj.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
            module.v_proj.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
            module.q_proj.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
            module.out_proj.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
        elif isinstance(module, GPTSanJapaneseSparseMLP):
            # Sparse mixture-of-experts layer: initialize the router classifier and every expert's projections.
            d_model = self.config.d_model
            module.router.classifier.weight.data.normal_(mean=0.0, std=factor * 1)
            for idx in range(self.config.num_experts):
                module.experts[f"expert_{idx}"].wi.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
                module.experts[f"expert_{idx}"].wo.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))

    def _shift_right(self, input_ids):
        """
        Shift `input_ids` one position to the right along the last dimension.

        The first position is filled with `config.decoder_start_token_id`, the last token is dropped, and any
        `-100` value in the shifted result is replaced by `config.pad_token_id`.

        Raises:
            ValueError: If `config.decoder_start_token_id` or `config.pad_token_id` is `None`.
        """
        decoder_start_token_id = self.config.decoder_start_token_id
        pad_token_id = self.config.pad_token_id

        if decoder_start_token_id is None:
            raise ValueError(
                "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. "
                "See T5 docs for more information."
            )

        # shift inputs to the right
        if is_torch_fx_proxy(input_ids):
            # Item assignment is not supported natively for proxies.
            shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
            shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
        else:
            shifted_input_ids = input_ids.new_zeros(input_ids.shape)
            shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
            shifted_input_ids[..., 0] = decoder_start_token_id

        if pad_token_id is None:
            raise ValueError("self.model.config.pad_token_id has to be defined.")
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        return shifted_input_ids
|
| 773 |
+
|
| 774 |
+
|
| 775 |
+
# Class-level documentation shared by the GPTSAN-japanese model classes (attached via `add_start_docstrings`).
GPTSAN_JAPANESE_START_DOCSTRING = r"""

    The [GPTSAN-japanese](https://github.com/tanreinama/GPTSAN) model was proposed in General-purpose Switch transformer
    based Japanese language model

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`GPTSanJapaneseConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
|
| 789 |
+
|
| 790 |
+
GPTSAN_JAPANESE_INPUTS_DOCSTRING = r"""
|
| 791 |
+
Args:
|
| 792 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 793 |
+
Indices of input sequence tokens in the vocabulary. GPTSAN-japanese is a model that generates sentence
|
| 794 |
+
continuations or predicts tokens at mask positions. Special tokens required for inputs to the model are
|
| 795 |
+
automatically appended.
|
| 796 |
+
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 797 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 798 |
+
|
| 799 |
+
- 1 for tokens that are **not masked**,
|
| 800 |
+
- 0 for tokens that are **masked**.
|
| 801 |
+
|
| 802 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 803 |
+
token_type_ids (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 804 |
+
An input that masks the Prefix part in the Prefix-LM input. Mask values selected in `[0, 1]`:
|
| 805 |
+
|
| 806 |
+
- 1 for tokens that are **prefix** input,
|
| 807 |
+
- 0 for tokens that are **not-prefix** input.
|
| 808 |
+
spout (`torch.Tensor` of shape `(batch_size, config.d_spout)`):
|
| 809 |
+
This vector is transformed through an 8-layer FFN and can be used instead of `past_key_values`.
|
| 810 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
| 811 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
| 812 |
+
|
| 813 |
+
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
| 814 |
+
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
| 815 |
+
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
| 816 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
| 817 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
| 818 |
+
use_cache (`bool`, *optional*):
|
| 819 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
| 820 |
+
`past_key_values`).
|
| 821 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 822 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 823 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
| 824 |
+
model's internal embedding lookup matrix.
|
| 825 |
+
decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
|
| 826 |
+
Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
|
| 827 |
+
representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
|
| 828 |
+
input (see `past_key_values`). This is useful if you want more control over how to convert
|
| 829 |
+
`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
|
| 830 |
+
output_attentions (`bool`, *optional*):
|
| 831 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 832 |
+
tensors for more detail.
|
| 833 |
+
output_hidden_states (`bool`, *optional*):
|
| 834 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 835 |
+
more detail.
|
| 836 |
+
return_dict (`bool`, *optional*):
|
| 837 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 838 |
+
router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`):
|
| 839 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
|
| 840 |
+
Router logits of the decoder model, useful to compute the auxiliary loss for Mixture of Experts models.
|
| 841 |
+
"""
|
| 842 |
+
|
| 843 |
+
|
| 844 |
+
@add_start_docstrings(
    "The bare GPTSAN-japanese Model transformer outputting raw hidden-states without any specific head on top.",
    GPTSAN_JAPANESE_START_DOCSTRING,
)
class GPTSanJapaneseModel(GPTSanJapanesePreTrainedModel):
    def __init__(self, config: GPTSanJapaneseConfig):
        super().__init__(config)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.d_model)
        self.config = copy.deepcopy(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        self.last_project = nn.Linear(config.d_model, config.d_model, bias=True)
        self.act = ACT2FN["swish"]

        # Switch (sparse feed-forward) blocks come first, followed by the optional extra (dense) blocks.
        self.blocks = torch.nn.ModuleList([])
        for _ in range(config.num_switch_layers):
            self.blocks.append(GPTSanJapaneseBlock(config))
        for _ in range(config.num_ext_layers):
            self.blocks.append(GPTSanJapaneseBlock(config, ext_layer=True))

        if config.num_ext_layers > 0:
            self.extra_position_embeddings = nn.Embedding(config.max_position_embeddings, config.d_model)

        if config.d_spout:
            # Eight (Linear, Tanh) pairs followed by a projection to `num_layers * 2 * d_model`, which `forward`
            # reshapes into one key/value pair per layer.
            spouts = []
            for _ in range(8):
                spouts.append(nn.Linear(config.d_spout, config.d_spout, bias=False))
                spouts.append(nn.Tanh())
            spouts.append(nn.Linear(config.d_spout, config.num_layers * 2 * config.d_model, bias=False))
            self.spout = nn.Sequential(*spouts)

        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    def set_input_embeddings(self, new_embeddings):
        """Replace the token embedding module with `new_embeddings`."""
        self.embed_tokens = new_embeddings

    @add_start_docstrings_to_model_forward(GPTSAN_JAPANESE_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.FloatTensor] = None,
        spout: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        num_precontext: Optional[torch.LongTensor] = None,
    ) -> Union[MoEModelOutputWithPastAndCrossAttentions, Tuple[torch.FloatTensor]]:
        r"""
        num_precontext (`torch.LongTensor` of shape `(batch_size,1)`):
            length of `hybrid` input tokens in the input. Tokens up to this length refer to both front and back like
            BERT, tokens after that refer only to front like GPT. see also:
            https://github.com/tanreinama/GPTSAN/blob/main/report/model.md

        Returns:
            `MoEModelOutputWithPastAndCrossAttentions` or `tuple` if `return_dict` returns
            MoEModelOutputWithPastAndCrossAttentions instead of tuple
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        device = self.position_embeddings.weight.device
        if input_ids is None:
            input_ids = torch.zeros([1, 1]).int().to(device)  # dummy for input_ids was None
        if inputs_embeds is not None:
            raise NotImplementedError(
                "GPTSanJapaneseModel does not use `inputs_embeds`. Make sure to pass in `input_ids` instead."
            )
        num_pasts_contexts = 0
        num_batch = input_ids.shape[0]
        pasts_or_spout_value = None
        if past_key_values is not None:
            num_pasts_contexts = past_key_values[0][0].shape[2]
        elif self.config.d_spout and spout is not None:
            # `spout` is a special input vector specific to GPTSAN
            # This controls the output by projecting embedded information such as the class of sentences during learning.
            # It should be passed instead of the first past_key_value.
            # See the original GPTSAN repository for details
            num_pasts_contexts += 1

        # If there is an attention_mask, increase first one for spout
        if self.config.d_spout and spout is not None and attention_mask is not None:
            attention_mask_with_spout = torch.ones(num_batch, attention_mask.shape[1] + 1, device=device)
            attention_mask_with_spout[:, 1:] -= 1 - attention_mask  # 1st token should be spout
            attention_mask = attention_mask_with_spout  # update attention_mask

        if num_precontext is not None:
            # `num_precontext` is the number of tokens that refer to each other in prefix-lm
            # created per batch, so dimension of num_precontext should be [batch, 1]
            if not (
                len(num_precontext.shape) == 2 and num_precontext.shape[1] == 1
            ):  # num_precontext Should be [batch,1]
                raise ValueError("num_precontext should be [batch, 1] size.")
            num_precontext = torch.reshape(num_precontext, [-1])
        else:
            num_precontext = torch.zeros([num_batch]).int().to(device)

        num_input_contexts = input_ids.shape[1]
        num_output_contexts = num_input_contexts + num_pasts_contexts

        hidden_states = self.embed_tokens(input_ids)

        if past_key_values is not None:
            pasts_or_spout_value = past_key_values
        elif self.config.d_spout and spout is not None:
            # Make vector from `spout` of GPTSAN to the same shape as past_key_values
            pasts_or_spout_value = self.spout(spout)  # projecting `spout` vector
            pasts_or_spout_value = torch.reshape(
                pasts_or_spout_value,
                [
                    num_batch,
                    self.config.num_layers,
                    2,
                    self.config.num_heads,
                    num_pasts_contexts,
                    self.config.d_model // self.config.num_heads,
                ],
            )
            pasts_or_spout_value = torch.split(pasts_or_spout_value, [1] * self.config.num_layers, dim=1)
            # make same shape as past_key_values
            pasts_or_spout_value = tuple(
                tuple([b.squeeze(1) for b in torch.split(a.squeeze(1), [1, 1], dim=1)]) for a in pasts_or_spout_value
            )
        else:
            pasts_or_spout_value = [None] * self.config.num_layers

        # Token position considering spout and pasts
        token_position = torch.arange(num_input_contexts).to(device) + num_pasts_contexts

        if attention_mask is None:
            attention_mask = torch.ones(num_batch, num_input_contexts, device=device)

        # positions for get position_embeddings
        gather_position = (
            (
                torch.zeros((num_batch, self.config.d_model, num_input_contexts)).to(device)
                + token_position.unsqueeze(0)
            )
            .transpose(1, 2)
            .long()
        )
        # When padding with padding_side="left", zeros line up on the left side of attention_mask, so position_embeddings is shifted accordingly
        gather_position -= (1 - attention_mask).argmin(dim=-1).unsqueeze(1).unsqueeze(2)
        gather_position = torch.clip(gather_position, num_pasts_contexts, self.config.max_position_embeddings - 1)

        # attention_mask is applied per batch
        for i in range(num_batch):
            hidden_states[i] += torch.gather(self.position_embeddings.weight, dim=0, index=gather_position[i])

        # Create a mask to be used when making the prefix Input length of Prefix-LM variable
        causal_mask = (
            torch.tril(torch.ones((num_output_contexts, num_output_contexts), dtype=torch.uint8))
            .view(1, 1, num_output_contexts, num_output_contexts)
            .to(device)
        )
        prefix_lm_mask = causal_mask[:, :, -num_input_contexts:, :]
        if token_type_ids is not None:
            token_type_ids = token_type_ids.unsqueeze(1).unsqueeze(2)
            prefix_lm_mask = ((prefix_lm_mask + token_type_ids) > 0).float()
        # Merge prefix_lm_mask and attention_mask
        extended_attention_mask = prefix_lm_mask * attention_mask.unsqueeze(1).unsqueeze(2)

        # Prepare head mask if needed
        if head_mask is not None:
            head_mask = self.get_head_mask(
                head_mask, self.config.num_switch_layers + self.config.num_ext_layers
            )  # n_layer x batch x n_heads x N x N

        # outputs: every accumulator is a tuple, extended by concatenation inside the layer loop
        present_key_value_states = () if self.config.use_cache or use_cache else None
        all_hidden_states = () if self.config.output_hidden_states or output_hidden_states else None
        all_attentions = () if self.config.output_attentions or output_attentions else None
        all_router_probs = () if self.config.output_router_logits or output_router_logits else None

        for layer, past in enumerate(pasts_or_spout_value):
            if layer == self.config.num_switch_layers:
                if self.config.num_ext_layers > 0:
                    # extra_position_embeddings are extra position embeddings that are only created when extending the model with code from the original GPTSAN repository. Not used in the default model.
                    # However, it is created when you create an additional layer and partially train only that location.
                    # Therefore, convert_gptsan_tf_checkpoint_to_pytorch.py is used when converting and loading models created in the original GPTSAN repository.
                    for i in range(num_batch):
                        hidden_states[i] += torch.gather(
                            self.extra_position_embeddings.weight, dim=0, index=gather_position[i]
                        )

            # Only the switch layers have a router, so the router tuple is requested from them alone.
            output_router_tuple = (
                self.config.output_router_logits or output_router_logits
            ) and layer < self.config.num_switch_layers
            # NOTE(review): the complete `head_mask` is handed to every block rather than a per-layer slice;
            # confirm against the attention implementation that this is intended.
            block_output = self.blocks[layer](
                hidden_states=hidden_states,
                past_key_value=past,
                attention_mask=extended_attention_mask,
                head_mask=head_mask,
                use_cache=self.config.use_cache or use_cache,
                output_attentions=self.config.output_attentions or output_attentions,
                output_router_tuple=output_router_tuple,
            )

            # `block_output` is positional: hidden, then (present), (attentions), (router tuple) when requested.
            outpos = 0
            hidden_states = block_output[outpos]
            if self.config.output_hidden_states or output_hidden_states:
                all_hidden_states += (hidden_states,)
            if self.config.use_cache or use_cache:
                outpos += 1
                present = block_output[outpos]
                present_key_value_states += (present,)
            if self.config.output_attentions or output_attentions:
                outpos += 1
                attention_probs = block_output[outpos]
                all_attentions += (attention_probs,)
            if output_router_tuple:
                outpos += 1
                router_tuple = block_output[outpos]
                # `all_router_probs` is a tuple, which has no `append`; extend it by concatenation.
                all_router_probs += (router_tuple[0],)

        hidden_states = self.last_project(hidden_states)
        hidden_states = self.act(hidden_states)

        if self.config.output_hidden_states or output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    present_key_value_states,
                    all_hidden_states,
                    all_attentions,
                    all_router_probs,
                ]
                if v is not None
            )

        return MoEModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=present_key_value_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            router_probs=all_router_probs,
        )
|
| 1091 |
+
|
| 1092 |
+
|
| 1093 |
+
@add_start_docstrings(
|
| 1094 |
+
"The bare GPTSAN-japanese Model with a language modeling head.",
|
| 1095 |
+
GPTSAN_JAPANESE_START_DOCSTRING,
|
| 1096 |
+
)
|
| 1097 |
+
class GPTSanJapaneseForConditionalGeneration(GPTSanJapanesePreTrainedModel):
|
| 1098 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 1099 |
+
|
| 1100 |
+
def __init__(self, config: GPTSanJapaneseConfig):
|
| 1101 |
+
super().__init__(config)
|
| 1102 |
+
self.model = GPTSanJapaneseModel(config)
|
| 1103 |
+
self.register_buffer("final_logits_bias", torch.zeros([1, config.vocab_size]))
|
| 1104 |
+
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
|
| 1105 |
+
if not self.config.torchscript:
|
| 1106 |
+
self.lm_head.weight = self.model.embed_tokens.weight
|
| 1107 |
+
|
| 1108 |
+
@add_start_docstrings_to_model_forward(GPTSAN_JAPANESE_INPUTS_DOCSTRING)
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    token_type_ids: Optional[torch.FloatTensor] = None,
    spout: Optional[torch.FloatTensor] = None,
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = False,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    output_router_logits: Optional[bool] = None,
    labels: Optional[torch.LongTensor] = None,
) -> Union[Tuple[torch.FloatTensor], MoECausalLMOutputWithPast]:
    r"""
    labels (`torch.LongTensor`, *optional*):
        Labels for computing the language modeling cross-entropy loss. They are flattened with `labels.view(-1)`
        and compared against the flattened logits, so one label per predicted position is expected. Indices
        should be in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the
        loss is only computed for the remaining labels.

    Returns:
        `MoECausalLMOutputWithPast` if `return_dict` is truthy, otherwise a `tuple` holding the non-`None` values.

    Example:

    Text Generation with regular LM Model
    ```python
    >>> from transformers import AutoModel, AutoTokenizer, trainer_utils

    >>> device = "cuda"
    >>> model = AutoModel.from_pretrained("Tanrei/GPTSAN-japanese").to(device)
    >>> tokenizer = AutoTokenizer.from_pretrained("Tanrei/GPTSAN-japanese")
    >>> x_token = tokenizer("織田信長は、", return_tensors="pt")
    >>> trainer_utils.set_seed(30)
    >>> input_ids = x_token.input_ids.to(device)
    >>> gen_token = model.generate(input_ids, max_new_tokens=50)
    >>> tokenizer.decode(gen_token[0])
    "織田信長は、政治・軍事の中枢まで掌握した政治家であり、日本史上類を見ない驚異的な軍事侵攻を続け..."
    ```

    Text Generation with Prefix-LM Model
    ```python
    >>> from transformers import AutoModel, AutoTokenizer, trainer_utils

    >>> device = "cuda"
    >>> model = AutoModel.from_pretrained("Tanrei/GPTSAN-japanese").to(device)
    >>> tokenizer = AutoTokenizer.from_pretrained("Tanrei/GPTSAN-japanese")
    >>> x_token = tokenizer("", prefix_text="織田信長は、", return_tensors="pt")
    >>> trainer_utils.set_seed(30)
    >>> input_ids = x_token.input_ids.to(device)
    >>> token_type_ids = x_token.token_type_ids.to(device)
    >>> gen_token = model.generate(input_ids, token_type_ids=token_type_ids, max_new_tokens=50)
    >>> tokenizer.decode(gen_token[0])
    "織田信長は、政治・外交で数々の戦果を上げるが、1568年からは、いわゆる本能寺の変で細川晴元に暗殺される..."
    ```

    Simultaneously Text Generation And Masked Language Model
    ```python
    >>> from transformers import AutoModel, AutoTokenizer, trainer_utils

    >>> device = "cuda"
    >>> model = AutoModel.from_pretrained("Tanrei/GPTSAN-japanese").to(device)
    >>> tokenizer = AutoTokenizer.from_pretrained("Tanrei/GPTSAN-japanese")
    >>> masked_sentence = "武田信玄は、<|inputmask|>時代ファンならぜひ押さえ<|inputmask|>きたい名将の一人。"
    >>> x_token = tokenizer("", prefix_text=masked_sentence, return_tensors="pt")
    >>> trainer_utils.set_seed(30)
    >>> input_ids = x_token.input_ids.to(device)
    >>> token_type_ids = x_token.token_type_ids.to(device)
    >>> out_lm_token = model.generate(input_ids, token_type_ids=token_type_ids, max_new_tokens=50)
    >>> out_mlm_token = model(input_ids, token_type_ids=token_type_ids).logits.argmax(axis=-1)
    >>> tokenizer.decode(out_mlm_token[0])
    "武田信玄は、戦国時代ファンならぜひ押さえておきたい名将の一人。"

    >>> tokenizer.decode(out_lm_token[0][input_ids.shape[1] :])
    "武田氏の三代に渡った武田家のひとり\n甲斐市に住む、日本史上最大の戦国大名。..."
    ```"""
    SEG_TOKEN = self.config.separator_token_id
    # Caching is enabled if either the caller or the config asks for it.
    use_cache = use_cache or self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    # The inner model is always asked for a dict-style output because its fields are read by
    # attribute name below (`outputs.past_key_values`, `outputs.router_probs`, ...).
    model_return_dict = True
    num_precontext = None
    if input_ids is not None:
        # For every batch row, record the column index at which the separator token occurs.
        # Rows without a separator keep the initial value 0.
        num_batch = input_ids.shape[0]
        num_precontext = torch.zeros([num_batch]).int().to(input_ids.device)
        where_separators = torch.where(input_ids == SEG_TOKEN)
        num_precontext[where_separators[0]] += where_separators[1]
        num_precontext = num_precontext.unsqueeze(1)

    # Arguments are passed positionally, so their order must match the inner model's signature.
    outputs = self.model(
        input_ids,
        attention_mask,
        token_type_ids,
        spout,
        past_key_values,
        head_mask,
        use_cache,
        inputs_embeds,
        decoder_inputs_embeds,
        output_attentions,
        output_hidden_states,
        model_return_dict,
        output_router_logits,
        num_precontext,
    )

    lm_logits = self.lm_head(outputs[0])
    # The bias is only added when its width matches the logits' last dimension.
    if lm_logits.shape[-1] == self.final_logits_bias.shape[-1]:
        lm_logits = lm_logits + self.final_logits_bias

    loss = None
    z_loss = None
    router_probs = None
    aux_loss = None
    if labels is not None:
        # move labels to correct device to enable model parallelism
        labels = labels.to(lm_logits.device)

        loss_fct = nn.CrossEntropyLoss(ignore_index=-100)

        if output_router_logits:
            # Compute the router loss (z_loss + auxiliary loss) for each router in the encoder and decoder
            # These values are returned separately; they are not added to `loss` in this method.
            router_logits, expert_indexes = self._unpack_router_logits(outputs.router_probs)
            z_loss = router_z_loss_func(router_logits)
            router_probs = nn.Softmax(dim=-1)(router_logits)
            aux_loss = load_balancing_loss_func(router_probs, expert_indexes)

        loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))

    if not return_dict:
        # Only non-None entries are kept, so positions in the tuple shift with what was computed.
        # NOTE(review): unlike the dict output below, `outputs.attentions` is not included here.
        return tuple(
            v
            for v in [
                loss,
                lm_logits,
                outputs.past_key_values,
                outputs.hidden_states,
                outputs.router_probs,
                z_loss,
                aux_loss,
            ]
            if v is not None
        )

    return MoECausalLMOutputWithPast(
        loss=loss,
        logits=lm_logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        router_logits=outputs.router_probs,
        z_loss=z_loss,
        aux_loss=aux_loss,
    )
|
| 1265 |
+
|
| 1266 |
+
def prepare_inputs_for_generation(
    self,
    input_ids: torch.LongTensor,
    attention_mask: torch.FloatTensor,
    token_type_ids: Optional[torch.FloatTensor] = None,
    spout: Optional[Union[List, torch.FloatTensor]] = None,
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    **kwargs,
):
    """Assemble the keyword arguments for one generation step.

    A `spout` given as a Python list is converted to a float tensor and, when `input_ids` is available, moved to
    the device of `input_ids`. When `past_key_values` is provided, only the last column of `input_ids` and
    `token_type_ids` is forwarded; otherwise the full tensors are forwarded and `past_key_values` is `None`.

    Returns:
        A dict with the keys `input_ids`, `attention_mask`, `token_type_ids`, `spout` and `past_key_values`.
    """
    if isinstance(spout, list):
        spout = torch.tensor(spout).float()
        if input_ids is not None:
            spout = spout.to(input_ids.device)

    model_inputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "token_type_ids": token_type_ids,
        "spout": spout,
        "past_key_values": None,
    }
    if past_key_values is None:
        return model_inputs

    def last_column(tensor):
        # Keep the sequence dimension so the result stays two-dimensional.
        return None if tensor is None else tensor[:, -1:]

    model_inputs.update(
        input_ids=last_column(input_ids),
        token_type_ids=last_column(token_type_ids),
        past_key_values=past_key_values,
    )
    return model_inputs
|
| 1294 |
+
|
| 1295 |
+
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
    """Return the result of `self._shift_right` applied to `labels`."""
    shifted_labels = self._shift_right(labels)
    return shifted_labels
|
| 1297 |
+
|
| 1298 |
+
def resize_token_embeddings(
    self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
) -> nn.Embedding:
    """Resize the token embeddings via the parent class and resize `final_logits_bias` to match.

    The bias is resized to the number of rows of the embedding matrix returned by the parent implementation,
    not to `new_num_tokens` directly.

    Returns:
        The embedding module returned by the parent implementation.
    """
    resized_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
    resized_vocab_size = resized_embeddings.weight.shape[0]
    self._resize_final_logits_bias(resized_vocab_size)
    return resized_embeddings
|
| 1304 |
+
|
| 1305 |
+
def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
|
| 1306 |
+
old_num_tokens = self.final_logits_bias.shape[-1]
|
| 1307 |
+
if new_num_tokens <= old_num_tokens:
|
| 1308 |
+
new_bias = self.final_logits_bias[:, :new_num_tokens]
|
| 1309 |
+
else:
|
| 1310 |
+
extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
|
| 1311 |
+
new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
|
| 1312 |
+
self.register_buffer("final_logits_bias", new_bias)
|
| 1313 |
+
|
| 1314 |
+
def get_input_embeddings(self):
    """Return the input embeddings reported by the wrapped `self.model`."""
    backbone = self.model
    return backbone.get_input_embeddings()
|
| 1316 |
+
|
| 1317 |
+
def set_input_embeddings(self, new_embeddings):
    """Forward `new_embeddings` to the wrapped `self.model`."""
    backbone = self.model
    backbone.set_input_embeddings(new_embeddings)
|
| 1319 |
+
|
| 1320 |
+
def set_output_embeddings(self, new_embeddings):
    """Replace the language modeling head (`self.lm_head`) with `new_embeddings`."""
    setattr(self, "lm_head", new_embeddings)
|
| 1322 |
+
|
| 1323 |
+
def get_output_embeddings(self):
    """Return the language modeling head (`self.lm_head`)."""
    head = self.lm_head
    return head
|
| 1325 |
+
|
| 1326 |
+
def _unpack_router_logits(self, router_outputs):
|
| 1327 |
+
total_router_logits = []
|
| 1328 |
+
total_expert_indexes = []
|
| 1329 |
+
for router_output in router_outputs:
|
| 1330 |
+
if len(router_output[0].shape) > 1:
|
| 1331 |
+
router_logits, expert_indexes = router_output
|
| 1332 |
+
total_router_logits.append(router_logits)
|
| 1333 |
+
total_expert_indexes.append(expert_indexes)
|
| 1334 |
+
return torch.cat(total_router_logits, dim=1), torch.cat(total_expert_indexes, dim=1)
|
| 1335 |
+
|
| 1336 |
+
|
| 1337 |
+
# Names exported when this module is star-imported.
__all__ = ["GPTSanJapaneseForConditionalGeneration", "GPTSanJapaneseModel", "GPTSanJapanesePreTrainedModel"]
|
docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/tokenization_gptsan_japanese.py
ADDED
|
@@ -0,0 +1,518 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Tokenization classes for GPTSANJapanese."""
|
| 16 |
+
|
| 17 |
+
import collections
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
import re
|
| 21 |
+
import sys
|
| 22 |
+
from typing import List, Optional, Tuple, Union
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
|
| 26 |
+
from ....tokenization_utils import PreTrainedTokenizer
|
| 27 |
+
from ....tokenization_utils_base import (
|
| 28 |
+
BatchEncoding,
|
| 29 |
+
PreTokenizedInput,
|
| 30 |
+
PreTokenizedInputPair,
|
| 31 |
+
TextInput,
|
| 32 |
+
TextInputPair,
|
| 33 |
+
TruncationStrategy,
|
| 34 |
+
)
|
| 35 |
+
from ....utils import PaddingStrategy, logging
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# Module-level logger named after this module.
logger = logging.get_logger(__name__)

# File names used for the two tokenizer resources; `save_vocabulary` writes files with these names.
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "emoji_file": "emoji.json"}
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def load_vocab_and_emoji(vocab_file, emoji_file):
    """Load a vocabulary file and an emoji file.

    Each line of `vocab_file` defines one token id (the zero-based line number). A line may list several
    comma-separated surface forms that all map to the same id. A line consisting of a single comma denotes the
    comma token itself; this also holds for a final line that has no trailing newline.

    Args:
        vocab_file: Path to the UTF-8 encoded vocabulary text file.
        emoji_file: Path to the UTF-8 encoded emoji JSON file.

    Returns:
        A tuple `(vocab, raw_vocab, ids_to_tokens, emoji)` where `vocab` maps every surface form to its id,
        `raw_vocab` maps each comma-joined vocabulary entry to its id, `ids_to_tokens` maps each id to its list
        of surface forms, and `emoji` is the parsed JSON content.
    """
    with open(emoji_file, "r", encoding="utf-8") as f:
        emoji = json.load(f)

    vocab = collections.OrderedDict()
    raw_vocab = collections.OrderedDict()
    ids_to_tokens = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as f:
        for idx, line in enumerate(f):
            entry = line.rstrip("\n")
            # Compare after stripping the newline so a lone "," on the last line is not split into two
            # empty strings when the file lacks a trailing newline.
            if entry == "," or "," not in entry:
                surface_forms = [entry]
            else:
                surface_forms = entry.split(",")
            ids_to_tokens[idx] = surface_forms
            raw_vocab[",".join(surface_forms)] = idx
            for form in surface_forms:
                vocab[form] = idx

    return vocab, raw_vocab, ids_to_tokens, emoji
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class GPTSanJapaneseTokenizer(PreTrainedTokenizer):
|
| 64 |
+
"""
|
| 65 |
+
This tokenizer is based on GPTNeoXJapaneseTokenizer and has the following modifications
|
| 66 |
+
- Decoding byte0~byte255 tokens correctly
|
| 67 |
+
- Added bagofword token handling
|
| 68 |
+
- Return token_type_ids for Prefix-LM model
|
| 69 |
+
The bagofword token represents a repetition of the previous token and is converted to 3 consecutive tokens when
|
| 70 |
+
decoding. In addition, the original Japanese special Sub-Word-Encoding has been released in this repository
|
| 71 |
+
(https://github.com/tanreinama/Japanese-BPEEncoder_V2). The token_type_ids is a mask indicating the prefix input
|
| 72 |
+
position of the Prefix-LM model. To specify a prefix position, specify a prefix input for prefix_text, or specify a
|
| 73 |
+
sentence of the prefix part and the part after it as a text pair of batch input.
|
| 74 |
+
|
| 75 |
+
Example:
|
| 76 |
+
|
| 77 |
+
```python
|
| 78 |
+
>>> from transformers import GPTSanJapaneseTokenizer
|
| 79 |
+
|
| 80 |
+
>>> tokenizer = GPTSanJapaneseTokenizer.from_pretrained("Tanrei/GPTSAN-japanese")
|
| 81 |
+
>>> # You can confirm both 慶応 and 慶應 are encoded to 17750
|
| 82 |
+
>>> tokenizer("吾輩は猫である🐯。実は慶応(慶應)大学出身")["input_ids"]
|
| 83 |
+
[35993, 35998, 34347, 31459, 30647, 31448, 25, 30659, 35729, 35676, 32417, 30647, 17750, 35589, 17750, 35590, 321, 1281]
|
| 84 |
+
|
| 85 |
+
>>> # Both 慶応 and 慶應 are decoded to 慶応
|
| 86 |
+
>>> tokenizer.decode(tokenizer("吾輩は猫である🐯。実は慶応(慶應)大学出身")["input_ids"])
|
| 87 |
+
'吾輩は猫である🐯。実は慶応(慶応)大学出身'
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
Example for Prefix-LM:
|
| 91 |
+
|
| 92 |
+
```python
|
| 93 |
+
>>> from transformers import GPTSanJapaneseTokenizer
|
| 94 |
+
|
| 95 |
+
>>> tokenizer = GPTSanJapaneseTokenizer.from_pretrained("Tanrei/GPTSAN-japanese")
|
| 96 |
+
>>> tokenizer("実は慶応(慶應)大学出身", prefix_text="吾輩は猫である🐯。")["input_ids"]
|
| 97 |
+
[35993, 34347, 31459, 30647, 31448, 25, 30659, 35729, 35676, 35998, 32417, 30647, 17750, 35589, 17750, 35590, 321, 1281]
|
| 98 |
+
|
| 99 |
+
>>> # Mask for Prefix-LM inputs
|
| 100 |
+
>>> tokenizer("実は慶応(慶應)大学出身", prefix_text="吾輩は猫である🐯。")["token_type_ids"]
|
| 101 |
+
[1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
Example for batch encode:
|
| 105 |
+
|
| 106 |
+
```python
|
| 107 |
+
>>> from transformers import GPTSanJapaneseTokenizer
|
| 108 |
+
|
| 109 |
+
>>> tokenizer = GPTSanJapaneseTokenizer.from_pretrained("Tanrei/GPTSAN-japanese")
|
| 110 |
+
>>> tokenizer([["武田信玄", "は、"], ["織田信長", "の配下の、"]], padding=True)["input_ids"]
|
| 111 |
+
[[35993, 35998, 8640, 25948, 35993, 35998, 30647, 35675, 35999, 35999], [35993, 35998, 10382, 9868, 35993, 35998, 30646, 9459, 30646, 35675]]
|
| 112 |
+
|
| 113 |
+
>>> # Mask for Prefix-LM inputs
|
| 114 |
+
>>> tokenizer([["武田信玄", "は、"], ["織田信長", "の配下の、"]], padding=True)["token_type_ids"]
|
| 115 |
+
[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
|
| 116 |
+
|
| 117 |
+
>>> # Mask for padding
|
| 118 |
+
>>> tokenizer([["武田信玄", "は、"], ["織田信長", "の配下の、"]], padding=True)["attention_mask"]
|
| 119 |
+
[[1, 1, 1, 1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
Args:
|
| 123 |
+
vocab_file (`str`):
|
| 124 |
+
File containing the vocabulary.
|
| 125 |
+
emoji_file (`str`):
|
| 126 |
+
File containing the emoji.
|
| 127 |
+
unk_token (`str`, *optional*, defaults to `"<|nottoken|>"`):
|
| 128 |
+
The token used for unknown characters.
|
| 129 |
+
pad_token (`str`, *optional*, defaults to `"<|separator|>"`):
|
| 130 |
+
The token used for padding
|
| 131 |
+
bos_token (`str`, *optional*, defaults to `"<|startoftext|>"`):
|
| 132 |
+
The beginning of sequence token.
|
| 133 |
+
eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
| 134 |
+
The end of sequence token.
|
| 135 |
+
sep_token (`str`, *optional*, defaults to `"<|segmenter|>"`):
|
| 136 |
+
A special token to separate token to prefix part and general input part.
|
| 137 |
+
do_clean_text (`bool`, *optional*, defaults to `False`):
|
| 138 |
+
Whether or not to clean text for URL, EMAIL, TEL, Japanese DATE and Japanese PRICE.
|
| 139 |
+
"""
|
| 140 |
+
|
| 141 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
| 142 |
+
model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
|
| 143 |
+
|
| 144 |
+
def __init__(
    self,
    vocab_file,
    emoji_file,
    unk_token="<|nottoken|>",
    pad_token="<|separator|>",
    bos_token="<|startoftext|>",
    eos_token="<|endoftext|>",
    sep_token="<|segmenter|>",
    do_clean_text=False,
    **kwargs,
):
    """Load the vocabulary and emoji tables and set up the sub-word tokenizer.

    Raises:
        ValueError: If `vocab_file` or `emoji_file` is not an existing file.
    """
    vocab_missing = not os.path.isfile(vocab_file)
    if vocab_missing:
        raise ValueError(
            f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
            " model use `tokenizer = GPTSanJapaneseTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
        )
    emoji_missing = not os.path.isfile(emoji_file)
    if emoji_missing:
        raise ValueError(
            f"Can't find a emoji file at path '{emoji_file}'. To load the emoji information from a Google"
            " pretrained model use `tokenizer = GPTSanJapaneseTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
        )

    self.do_clean_text = do_clean_text
    loaded_tables = load_vocab_and_emoji(vocab_file, emoji_file)
    self.vocab, self.raw_vocab, self.ids_to_tokens, self.emoji = loaded_tables
    self.subword_tokenizer = SubWordJapaneseTokenizer(
        vocab=self.vocab,
        ids_to_tokens=self.ids_to_tokens,
        emoji=self.emoji,
    )

    # The tables above are populated before the parent constructor runs.
    base_options = {
        "unk_token": unk_token,
        "pad_token": pad_token,
        "bos_token": bos_token,
        "eos_token": eos_token,
        "sep_token": sep_token,
        "do_clean_text": do_clean_text,
    }
    super().__init__(**base_options, **kwargs)
|
| 181 |
+
|
| 182 |
+
@property
def vocab_size(self):
    """Number of entries in `raw_vocab`."""
    # self.vocab contains support for character fluctuation unique to Japanese, and has a large number of vocab
    entry_count = len(self.raw_vocab)
    return entry_count
|
| 186 |
+
|
| 187 |
+
def get_vocab(self):
    """Return a new dict combining `raw_vocab` with `added_tokens_encoder` (added tokens win on conflicts)."""
    added_tokens = self.added_tokens_encoder
    return dict(self.raw_vocab, **added_tokens)
|
| 189 |
+
|
| 190 |
+
def _tokenize(self, text):
|
| 191 |
+
return self.subword_tokenizer.tokenize(text, clean=self.do_clean_text)
|
| 192 |
+
|
| 193 |
+
def _convert_token_to_id(self, token):
|
| 194 |
+
"""Converts a token (str) in an id using the vocab."""
|
| 195 |
+
return self.vocab.get(token, self.vocab.get(self.unk_token))
|
| 196 |
+
|
| 197 |
+
def _convert_id_to_token(self, index):
|
| 198 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
| 199 |
+
return self.subword_tokenizer.convert_id_to_token(index)
|
| 200 |
+
|
| 201 |
+
def convert_tokens_to_string(self, tokens):
    """Join a sequence of tokens into a single string.

    Runs of `<|byteN|>` tokens are collected and decoded together as UTF-8 (invalid sequences are replaced).
    `<|emoji...|>` tokens are mapped through `self.emoji["emoji_inv"]`, a few placeholder tokens are mapped to
    fixed characters, `<|bagoftoken|>` repeats the previous piece three more times, and any other `<|...|>`
    token contributes an empty string.
    """
    placeholder_chars = {
        "<SP>": " ",
        "<BR>": "\n",
        "<TAB>": "\t",
        "<BLOCK>": "▀",
        "<KIGOU>": "ǀ",
        "<U2000U2BFF>": "‖",
    }
    pieces = []
    pending_bytes = []

    def flush_pending_bytes():
        # Decode the byte run gathered so far as one unit, then start a new run.
        if pending_bytes:
            pieces.append(bytearray(pending_bytes).decode("utf-8", errors="replace"))
            pending_bytes.clear()

    for token in tokens:
        if token.startswith("<|byte") and token.endswith("|>"):
            pending_bytes.append(int(token[6:-2]))
            continue

        flush_pending_bytes()
        if token.startswith("<|emoji") and token.endswith("|>"):
            pieces.append(self.emoji["emoji_inv"][token])
        elif token in placeholder_chars:
            pieces.append(placeholder_chars[token])
        elif token == "<|bagoftoken|>":
            # Nothing to repeat when no piece has been produced yet.
            if pieces:
                pieces.extend([pieces[-1]] * 3)
        elif token.startswith("<|") and token.endswith("|>"):
            pieces.append("")
        else:
            pieces.append(token)

    flush_pending_bytes()
    return "".join(pieces)
|
| 239 |
+
|
| 240 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
    """Write the vocabulary and emoji tables to disk.

    When `save_directory` is an existing directory, the files are created inside it, with the optional prefix
    prepended to each file name. Otherwise the paths are built by plain string concatenation of the prefix,
    `save_directory` and the file name.

    Returns:
        The paths of the written vocabulary file and emoji file.
    """
    expected_index = 0
    prefix = filename_prefix + "-" if filename_prefix else ""
    target_is_directory = os.path.isdir(save_directory)

    def build_path(resource_key):
        file_name = VOCAB_FILES_NAMES[resource_key]
        if target_is_directory:
            return os.path.join(save_directory, prefix + file_name)
        return prefix + save_directory + file_name

    vocab_file = build_path("vocab_file")
    emoji_file = build_path("emoji_file")

    with open(vocab_file, "w", encoding="utf-8") as writer:
        for token_index, surface_forms in self.ids_to_tokens.items():
            if token_index != expected_index:
                logger.warning(
                    f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
                    " Please check that the vocabulary is not corrupted!"
                )
                expected_index = token_index
            # One line per id; alternative surface forms are comma-separated.
            writer.write(",".join(surface_forms) + "\n")
            expected_index += 1

    with open(emoji_file, "w", encoding="utf-8") as writer:
        json.dump(self.emoji, writer)

    return vocab_file, emoji_file
|
| 269 |
+
|
| 270 |
+
def create_token_type_ids_from_sequences(
    self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
    # docstyle-ignore
    """
    Build the Prefix-LM mask: 1 for every position before the first separator token found in `token_ids_0`,
    0 for all remaining positions (including those of `token_ids_1`, if given). When the separator token is
    not in the vocab or not in `token_ids_0`, the mask is all zeros.

    Example:
    ```python
    >>> from transformers import GPTSanJapaneseTokenizer

    >>> tokenizer = GPTSanJapaneseTokenizer.from_pretrained("Tanrei/GPTSAN-japanese")
    >>> x_token = tokenizer("アイウエ")
    >>> # input_ids:      | SOT | SEG | ア | イ | ウ | エ |
    >>> # token_type_ids: | 1   | 0   | 0 | 0 | 0 | 0 |

    >>> x_token = tokenizer("", prefix_text="アイウエ")
    >>> # input_ids:      | SOT | ア | イ | ウ | エ | SEG |
    >>> # token_type_ids: | 1   | 1 | 1 | 1 | 1 | 0  |

    >>> x_token = tokenizer("ウエ", prefix_text="アイ")
    >>> # input_ids:      | SOT | ア | イ | SEG | ウ | エ |
    >>> # token_type_ids: | 1   | 1 | 1 | 0   | 0 | 0 |
    ```"""
    prefix_len = 0
    if self.sep_token in self.vocab:
        separator_id = self.vocab[self.sep_token]
        try:
            prefix_len = token_ids_0.index(separator_id)
        except ValueError:
            # No separator in the first sequence: nothing is marked as prefix.
            pass

    all_ids = token_ids_0 if token_ids_1 is None else token_ids_0 + token_ids_1
    suffix_len = len(all_ids) - prefix_len
    return [1] * prefix_len + [0] * suffix_len
|
| 305 |
+
|
| 306 |
+
def prepare_for_tokenization(self, text, prefix_text=None, add_sep_token=None, **kwargs):
    """Prepend the start token, the optional prefix text and the separator token to `text`.

    The start and separator tokens are only inserted when they are present in the vocab. When `add_sep_token`
    is `None`, a separator is added only if `text` does not already contain one.

    Returns:
        A `(prepared_text, kwargs)` tuple; `kwargs` is passed through unchanged.
    """
    # GPTSAN inserts extra SEP tokens in Prefix-LM in addition to SOT for text generation.
    # SOT at the beginning of the text, and SEP at the separator between the Prefix part and the rest.
    if add_sep_token is None:
        # A separator already present in the text marks the un-prefix position explicitly.
        add_sep_token = self.sep_token not in text

    segments = []
    if self.bos_token in self.vocab:
        segments.append(self.bos_token)
    if prefix_text is not None:
        segments.append(prefix_text)
    if add_sep_token and self.sep_token in self.vocab:
        segments.append(self.sep_token)
    segments.append(text)
    return ("".join(segments), kwargs)
|
| 317 |
+
|
| 318 |
+
def _batch_encode_plus(
    self,
    batch_text_or_text_pairs: Union[
        List[TextInput], List[TextInputPair], List[PreTokenizedInput], List[PreTokenizedInputPair]
    ],
    add_special_tokens: bool = True,
    padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
    truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
    max_length: Optional[int] = None,
    stride: int = 0,
    is_split_into_words: bool = False,
    pad_to_multiple_of: Optional[int] = None,
    return_tensors: Optional[str] = None,
    return_token_type_ids: Optional[bool] = None,
    return_attention_mask: Optional[bool] = None,
    return_overflowing_tokens: bool = False,
    return_special_tokens_mask: bool = False,
    return_offsets_mapping: bool = False,
    return_length: bool = False,
    verbose: bool = True,
    **kwargs,
) -> BatchEncoding:
    """Encode a batch, merging tuple text pairs into single prefix/separator/text strings first.

    When the first batch element is a tuple, every element is unpacked as `(prefix, text)` and replaced by
    `prefix + sep_token + text`. The (possibly rewritten) batch and all remaining arguments are then forwarded
    unchanged to the parent implementation.
    """
    # This tokenizer converts input text pairs into Prefix input and subsequent input
    # NOTE(review): the second check wraps the element in `tuple(...)` before testing for `list`, so it is
    # always False and only tuple pairs take this branch. Confirm whether list pairs were meant to be
    # merged here too before changing it; inputs that are lists currently fall through to the parent.
    if isinstance(batch_text_or_text_pairs[0], tuple) or isinstance(tuple(batch_text_or_text_pairs[0]), list):
        # As a single text with an explicit un-prefix position
        batch_prefix_texts = []
        for pref, txt in batch_text_or_text_pairs:
            batch_prefix_texts.append(pref + self.sep_token + txt)
        batch_text_or_text_pairs = batch_prefix_texts

    # Arguments are passed positionally, so their order must match the parent signature.
    return super()._batch_encode_plus(
        batch_text_or_text_pairs,
        add_special_tokens,
        padding_strategy,
        truncation_strategy,
        max_length,
        stride,
        is_split_into_words,
        pad_to_multiple_of,
        return_tensors,
        return_token_type_ids,
        return_attention_mask,
        return_overflowing_tokens,
        return_special_tokens_mask,
        return_offsets_mapping,
        return_length,
        verbose,
        **kwargs,
    )
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
class SubWordJapaneseTokenizer:
|
| 370 |
+
"""
|
| 371 |
+
This tokenizer is based on GPTNeoXJapaneseTokenizer and has the following modifications
|
| 372 |
+
- Decoding byte0~byte255 tokens correctly
|
| 373 |
+
- Added bagofword token handling
|
| 374 |
+
|
| 375 |
+
https://github.com/tanreinama/Japanese-BPEEncoder_V2 This tokenizer class is under MIT License according to the
|
| 376 |
+
original repository.
|
| 377 |
+
|
| 378 |
+
MIT License
|
| 379 |
+
|
| 380 |
+
Copyright (c) 2020 tanreinama
|
| 381 |
+
|
| 382 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
|
| 383 |
+
documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
|
| 384 |
+
rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
|
| 385 |
+
permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
| 386 |
+
|
| 387 |
+
The above copyright notice and this permission notice shall be included in all copies or substantial portions of
|
| 388 |
+
the Software.
|
| 389 |
+
|
| 390 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
|
| 391 |
+
THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 392 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
| 393 |
+
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 394 |
+
SOFTWARE.
|
| 395 |
+
"""
|
| 396 |
+
|
| 397 |
+
def __init__(self, vocab, ids_to_tokens, emoji):
|
| 398 |
+
self.vocab = vocab # same as swe
|
| 399 |
+
self.ids_to_tokens = ids_to_tokens # same as bpe
|
| 400 |
+
self.emoji = emoji
|
| 401 |
+
self.maxlen = np.max([len(w) for w in self.vocab.keys()])
|
| 402 |
+
self.content_repatter1 = re.compile(r"(https?|ftp)(:\/\/[-_\.!~*\'()a-zA-Z0-9;\/?:\@&=\+$,%#]+)")
|
| 403 |
+
self.content_repatter2 = re.compile(r"[A-Za-z0-9\._+]*@[\-_0-9A-Za-z]+(\.[A-Za-z]+)*")
|
| 404 |
+
self.content_repatter3 = re.compile(r"[\(]{0,1}[0-9]{2,4}[\)\-\(]{0,1}[0-9]{2,4}[\)\-]{0,1}[0-9]{3,4}")
|
| 405 |
+
self.content_repatter4 = re.compile(
|
| 406 |
+
r"([12]\d{3}[/\-年])*(0?[1-9]|1[0-2])[/\-月]((0?[1-9]|[12][0-9]|3[01])日?)*(\d{1,2}|:|\d{1,2}時|\d{1,2}分|\(日\)|\(月\)|\(火\)|\(水\)|\(木\)|\(金\)|\(土\)|㈰|㈪|㈫|㈬|㈭|㈮|㈯)*"
|
| 407 |
+
)
|
| 408 |
+
self.content_repatter5 = re.compile(
|
| 409 |
+
r"(明治|大正|昭和|平成|令和|㍾|㍽|㍼|㍻|\u32ff)\d{1,2}年(0?[1-9]|1[0-2])月(0?[1-9]|[12][0-9]|3[01])日(\d{1,2}|:|\d{1,2}時|\d{1,2}分|\(日\)|\(月\)|\(火\)|\(水\)|\(木\)|\(金\)|\(土\)|㈰|㈪|㈫|㈬|㈭|㈮|㈯)*"
|
| 410 |
+
)
|
| 411 |
+
# The original version of this regex displays catastrophic backtracking behaviour. We avoid this using
|
| 412 |
+
# possessive quantifiers in Py >= 3.11. In versions below this, we avoid the vulnerability using a slightly
|
| 413 |
+
# different regex that should generally have the same behaviour in most non-pathological cases.
|
| 414 |
+
if sys.version_info >= (3, 11):
|
| 415 |
+
self.content_repatter6 = re.compile(
|
| 416 |
+
r"(?:\d,\d{3}|[\d億])*+"
|
| 417 |
+
r"(?:\d,\d{3}|[\d万])*+"
|
| 418 |
+
r"(?:\d,\d{3}|[\d千])*+"
|
| 419 |
+
r"(?:千円|万円|千万円|円|千ドル|万ドル|千万ドル|ドル|千ユーロ|万ユーロ|千万ユーロ|ユーロ)+"
|
| 420 |
+
r"(?:\(税込\)|\(税抜\)|\+tax)*"
|
| 421 |
+
)
|
| 422 |
+
else:
|
| 423 |
+
self.content_repatter6 = re.compile(
|
| 424 |
+
r"(?:\d,\d{3}|[\d億万千])*"
|
| 425 |
+
r"(?:千円|万円|千万円|円|千ドル|万ドル|千万ドル|ドル|千ユーロ|万ユーロ|千万ユーロ|ユーロ)+"
|
| 426 |
+
r"(?:\(税込\)|\(税抜\)|\+tax)*"
|
| 427 |
+
)
|
| 428 |
+
keisen = "─━│┃┄┅┆┇┈┉┊┋┌┍┎┏┐┑┒┓└┕┖┗┘┙┚┛├┝┞┟┠┡┢┣┤┥┦┧┨┩┪┫┬┭┮┯┰┱┲┳┴┵┶┷┸┹┺┻┼┽┾┿╀╁╂╃╄╅╆╇╈╉╊╋╌╍╎╏═║╒╓╔╕╖╗╘╙╚╛╜╝╞╟╠╡╢╣╤╥╦╧╨╩╪╫╬╭╮╯╰╱╲╳╴╵╶╷╸╹╺╻╼╽╾╿"
|
| 429 |
+
blocks = "▀▁▂▃▄▅▆▇█▉▊▋▌▍▎▏▐░▒▓▔▕▖▗▘▙▚▛▜▝▞▟"
|
| 430 |
+
self.content_trans1 = str.maketrans(dict.fromkeys(keisen + blocks, "<BLOCK>"))
|
| 431 |
+
|
| 432 |
+
def __len__(self):
|
| 433 |
+
return len(self.ids_to_tokens)
|
| 434 |
+
|
| 435 |
+
def clean_text(self, content):
|
| 436 |
+
content = self.content_repatter1.sub("<URL>", content)
|
| 437 |
+
content = self.content_repatter2.sub("<EMAIL>", content)
|
| 438 |
+
content = self.content_repatter3.sub("<TEL>", content)
|
| 439 |
+
content = self.content_repatter4.sub("<DATE>", content)
|
| 440 |
+
content = self.content_repatter5.sub("<DATE>", content)
|
| 441 |
+
content = self.content_repatter6.sub("<PRICE>", content)
|
| 442 |
+
content = content.translate(self.content_trans1)
|
| 443 |
+
while "<BLOCK><BLOCK>" in content:
|
| 444 |
+
content = content.replace("<BLOCK><BLOCK>", "<BLOCK>")
|
| 445 |
+
return content
|
| 446 |
+
|
| 447 |
+
def tokenize(self, text, clean=False):
|
| 448 |
+
text = text.replace(" ", "<SP>")
|
| 449 |
+
text = text.replace(" ", "<SP>")
|
| 450 |
+
text = text.replace("\r\n", "<BR>")
|
| 451 |
+
text = text.replace("\n", "<BR>")
|
| 452 |
+
text = text.replace("\r", "<BR>")
|
| 453 |
+
text = text.replace("\t", "<TAB>")
|
| 454 |
+
text = text.replace("—", "ー")
|
| 455 |
+
text = text.replace("−", "ー")
|
| 456 |
+
for k, v in self.emoji["emoji"].items():
|
| 457 |
+
if k in text:
|
| 458 |
+
text = text.replace(k, v)
|
| 459 |
+
if clean:
|
| 460 |
+
text = self.clean_text(text)
|
| 461 |
+
|
| 462 |
+
def check_simbol(x):
|
| 463 |
+
e = x.encode()
|
| 464 |
+
if len(x) == 1 and len(e) == 2:
|
| 465 |
+
c = (int(e[0]) << 8) + int(e[1])
|
| 466 |
+
if (
|
| 467 |
+
(c >= 0xC2A1 and c <= 0xC2BF)
|
| 468 |
+
or (c >= 0xC780 and c <= 0xC783)
|
| 469 |
+
or (c >= 0xCAB9 and c <= 0xCBBF)
|
| 470 |
+
or (c >= 0xCC80 and c <= 0xCDA2)
|
| 471 |
+
):
|
| 472 |
+
return True
|
| 473 |
+
return False
|
| 474 |
+
|
| 475 |
+
def checku2e(x):
|
| 476 |
+
e = x.encode()
|
| 477 |
+
if len(x) == 1 and len(e) == 3:
|
| 478 |
+
c = (int(e[0]) << 16) + (int(e[1]) << 8) + int(e[2])
|
| 479 |
+
if c >= 0xE28080 and c <= 0xE2B07F:
|
| 480 |
+
return True
|
| 481 |
+
return False
|
| 482 |
+
|
| 483 |
+
pos = 0
|
| 484 |
+
result = []
|
| 485 |
+
while pos < len(text):
|
| 486 |
+
end = min(len(text), pos + self.maxlen + 1) if text[pos] == "<" else pos + 3
|
| 487 |
+
candidates = [] # (token_id, token, pos)
|
| 488 |
+
for e in range(end, pos, -1):
|
| 489 |
+
wd = text[pos:e]
|
| 490 |
+
if wd in self.vocab:
|
| 491 |
+
if wd[0] == "<" and len(wd) > 2:
|
| 492 |
+
candidates = [(self.vocab[wd], wd, e)]
|
| 493 |
+
break
|
| 494 |
+
else:
|
| 495 |
+
candidates.append((self.vocab[wd], wd, e))
|
| 496 |
+
if len(candidates) > 0:
|
| 497 |
+
# the smallest token_id is adopted
|
| 498 |
+
_, wd, e = sorted(candidates, key=lambda x: x[0])[0]
|
| 499 |
+
result.append(wd)
|
| 500 |
+
pos = e
|
| 501 |
+
else:
|
| 502 |
+
end = pos + 1
|
| 503 |
+
wd = text[pos:end]
|
| 504 |
+
if check_simbol(wd):
|
| 505 |
+
result.append("<KIGOU>")
|
| 506 |
+
elif checku2e(wd):
|
| 507 |
+
result.append("<U2000U2BFF>")
|
| 508 |
+
else:
|
| 509 |
+
for i in wd.encode("utf-8"):
|
| 510 |
+
result.append("<|byte%d|>" % i)
|
| 511 |
+
pos = end
|
| 512 |
+
return result
|
| 513 |
+
|
| 514 |
+
def convert_id_to_token(self, index):
|
| 515 |
+
return self.ids_to_tokens[index][0]
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
# Public API of this module: only the tokenizer class named here is exported by a star-import.
__all__ = ["GPTSanJapaneseTokenizer"]
|
docs/transformers/build/lib/transformers/models/deprecated/graphormer/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ....utils import _LazyModule
|
| 17 |
+
from ....utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_graphormer import *
|
| 22 |
+
from .modeling_graphormer import *
|
| 23 |
+
else:
|
| 24 |
+
import sys
|
| 25 |
+
|
| 26 |
+
_file = globals()["__file__"]
|
| 27 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
docs/transformers/build/lib/transformers/models/deprecated/graphormer/algos_graphormer.pyx
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft Corporation and HuggingFace
|
| 2 |
+
# Licensed under the MIT License.
|
| 3 |
+
|
| 4 |
+
import cython
|
| 5 |
+
|
| 6 |
+
cimport numpy
|
| 7 |
+
from cython.parallel cimport parallel, prange
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# Reduce this number if matrices are too big for large graphs
|
| 13 |
+
UNREACHABLE_NODE_DISTANCE = 510
|
| 14 |
+
|
| 15 |
+
def floyd_warshall(adjacency_matrix):
    """
    Applies the Floyd-Warshall algorithm to the adjacency matrix, to compute the
    shortest paths distance between all nodes, up to UNREACHABLE_NODE_DISTANCE.

    Args:
        adjacency_matrix: square 2-D array that can be safely cast to int32. Off-diagonal zeros are
            treated as "no edge".

    Returns:
        A tuple `(M, path)` of int32 arrays with the same shape as the input. `M[i, j]` is the shortest
        distance from `i` to `j`. `path[i, j]` is the last intermediate node `k` that improved that
        distance, or -1 if none did. Pairs whose distance is at least UNREACHABLE_NODE_DISTANCE have
        both entries set to UNREACHABLE_NODE_DISTANCE.
    """
    (nrows, ncols) = adjacency_matrix.shape
    assert nrows == ncols
    cdef unsigned int n = nrows

    # Work on a private, C-contiguous int32 copy so raw pointer arithmetic below is valid.
    adj_mat_copy = adjacency_matrix.astype(np.int32, order='C', casting='safe', copy=True)
    assert adj_mat_copy.flags['C_CONTIGUOUS']
    cdef numpy.ndarray[numpy.int32_t, ndim=2, mode='c'] M = adj_mat_copy
    cdef numpy.ndarray[numpy.int32_t, ndim=2, mode='c'] path = -1 * np.ones([n, n], dtype=np.int32)

    cdef unsigned int i, j, k
    cdef numpy.int32_t M_ij, M_ik, cost_ikkj
    cdef numpy.int32_t* M_ptr = &M[0,0]
    cdef numpy.int32_t* M_i_ptr
    cdef numpy.int32_t* M_k_ptr

    # set unreachable nodes distance to UNREACHABLE_NODE_DISTANCE
    # (tuple indexing keeps the access on the typed-buffer fast path instead of building a row object)
    for i in range(n):
        for j in range(n):
            if i == j:
                M[i, j] = 0
            elif M[i, j] == 0:
                M[i, j] = UNREACHABLE_NODE_DISTANCE

    # Floyd-Warshall relaxation: try every node k as an intermediate hop between i and j
    for k in range(n):
        M_k_ptr = M_ptr + n*k
        for i in range(n):
            M_i_ptr = M_ptr + n*i
            M_ik = M_i_ptr[k]
            for j in range(n):
                cost_ikkj = M_ik + M_k_ptr[j]
                M_ij = M_i_ptr[j]
                if M_ij > cost_ikkj:
                    M_i_ptr[j] = cost_ikkj
                    path[i, j] = k

    # set unreachable path to UNREACHABLE_NODE_DISTANCE
    for i in range(n):
        for j in range(n):
            if M[i, j] >= UNREACHABLE_NODE_DISTANCE:
                path[i, j] = UNREACHABLE_NODE_DISTANCE
                M[i, j] = UNREACHABLE_NODE_DISTANCE

    return M, path
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def get_all_edges(path, i, j):
    """
    Recursively expand the intermediate nodes recorded between `i` and `j`.

    `path[i][j]` is read as the intermediate node `k` lying between `i` and `j`. A value of -1 ends the
    recursion and contributes no nodes. Otherwise the result is the expansion of `(i, k)`, then `k`
    itself, then the expansion of `(k, j)`. The end points `i` and `j` are not included.
    """
    cdef int mid = path[i][j]
    if mid == -1:
        return []
    before = get_all_edges(path, i, mid)
    after = get_all_edges(path, mid, j)
    return before + [mid] + after
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def gen_edge_input(max_dist, path, edge_feat):
    """
    Generates the full edge feature and adjacency matrix.
    Shape: num_nodes * num_nodes * max_distance_between_nodes * num_edge_features
    Dim 1 is the input node, dim 2 the output node of the edge, dim 3 the depth of the edge, dim 4 the feature

    Args:
        max_dist: size of the third (depth) dimension of the output.
        path: square 2-D array of intermediate nodes, in the format consumed by `get_all_edges`;
            entries equal to UNREACHABLE_NODE_DISTANCE mark pairs that are skipped.
        edge_feat: array indexed as `edge_feat[src, dst, :]` holding the features of each edge.

    Returns:
        An int32 array initialised to -1 and filled, for every reachable ordered pair `(i, j)` with
        `i != j`, with the features of each consecutive edge along the reconstructed node sequence.
    """
    (nrows, ncols) = path.shape
    assert nrows == ncols
    cdef unsigned int n = nrows
    cdef unsigned int max_dist_copy = max_dist

    # Use an explicit 64-bit dtype: a platform-sized integer can be 32 bits wide, in which case the
    # 'safe' cast of int64 input would be rejected.
    path_copy = path.astype(np.int64, order='C', casting='safe', copy=True)
    edge_feat_copy = edge_feat.astype(np.int64, order='C', casting='safe', copy=True)
    assert path_copy.flags['C_CONTIGUOUS']
    assert edge_feat_copy.flags['C_CONTIGUOUS']

    cdef numpy.ndarray[numpy.int32_t, ndim=4, mode='c'] edge_fea_all = -1 * np.ones([n, n, max_dist_copy, edge_feat.shape[-1]], dtype=np.int32)
    cdef unsigned int i, j, k, num_path

    for i in range(n):
        for j in range(n):
            if i == j:
                continue
            if path_copy[i][j] == UNREACHABLE_NODE_DISTANCE:
                continue
            # Full node sequence from i to j; kept in its own local so the `path` argument is not shadowed.
            node_path = [i] + get_all_edges(path_copy, i, j) + [j]
            num_path = len(node_path) - 1
            for k in range(num_path):
                edge_fea_all[i, j, k, :] = edge_feat_copy[node_path[k], node_path[k+1], :]

    return edge_fea_all
|
docs/transformers/build/lib/transformers/models/deprecated/graphormer/collating_graphormer.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft Corporation and HuggingFace
|
| 2 |
+
# Licensed under the MIT License.
|
| 3 |
+
|
| 4 |
+
from typing import Any, Dict, List, Mapping
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from ....utils import is_cython_available, requires_backends
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
if is_cython_available():
|
| 13 |
+
import pyximport
|
| 14 |
+
|
| 15 |
+
pyximport.install(setup_args={"include_dirs": np.get_include()})
|
| 16 |
+
from . import algos_graphormer # noqa E402
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def convert_to_single_emb(x, offset: int = 512):
    """
    Shift each feature column of `x` into its own index range.

    Column `c` is increased by `1 + c * offset`; a 1-D input is treated as having a single feature and
    is therefore increased by 1. The input array is not modified; a new array is returned.
    """
    num_features = 1 if len(x.shape) <= 1 else x.shape[1]
    column_shifts = np.arange(0, num_features * offset, offset, dtype=np.int64) + 1
    return x + column_shifts
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def preprocess_item(item, keep_features=True):
    """
    Add the Graphormer model inputs to a single graph `item` and return it.

    Reads `edge_index` and `num_nodes` (plus `edge_attr` / `node_feat` when present and `keep_features`
    is true) and writes `input_nodes`, `attn_bias`, `attn_edge_type`, `spatial_pos`, `in_degree`,
    `out_degree` and `input_edges` into the same mapping. If `labels` is missing it is copied from `y`.
    Requires the cython backend.
    """
    requires_backends(preprocess_item, ["cython"])

    # Edge features: use the provided ones, otherwise give every edge the same single feature.
    has_edge_attr = keep_features and "edge_attr" in item.keys()
    if has_edge_attr:
        edge_attr = np.asarray(item["edge_attr"], dtype=np.int64)
    else:
        edge_attr = np.ones((len(item["edge_index"][0]), 1), dtype=np.int64)  # same embedding for all

    # Node features: same fallback strategy as for edges.
    has_node_feat = keep_features and "node_feat" in item.keys()
    if has_node_feat:
        node_feature = np.asarray(item["node_feat"], dtype=np.int64)
    else:
        node_feature = np.ones((item["num_nodes"], 1), dtype=np.int64)  # same embedding for all

    edge_index = np.asarray(item["edge_index"], dtype=np.int64)

    input_nodes = convert_to_single_emb(node_feature) + 1
    num_nodes = item["num_nodes"]

    if len(edge_attr.shape) == 1:
        edge_attr = edge_attr[:, None]
    sources, targets = edge_index[0], edge_index[1]
    attn_edge_type = np.zeros([num_nodes, num_nodes, edge_attr.shape[-1]], dtype=np.int64)
    attn_edge_type[sources, targets] = convert_to_single_emb(edge_attr) + 1

    # node adj matrix [num_nodes, num_nodes] bool
    adj = np.zeros([num_nodes, num_nodes], dtype=bool)
    adj[sources, targets] = True

    shortest_path_result, path = algos_graphormer.floyd_warshall(adj)
    max_dist = np.amax(shortest_path_result)

    input_edges = algos_graphormer.gen_edge_input(max_dist, path, attn_edge_type)
    attn_bias = np.zeros([num_nodes + 1, num_nodes + 1], dtype=np.single)  # with graph token

    # combine
    degree = np.sum(adj, axis=1).reshape(-1) + 1  # we shift all indices by one for padding
    item["input_nodes"] = input_nodes + 1  # we shift all indices by one for padding
    item["attn_bias"] = attn_bias
    item["attn_edge_type"] = attn_edge_type
    item["spatial_pos"] = shortest_path_result.astype(np.int64) + 1  # we shift all indices by one for padding
    item["in_degree"] = degree
    item["out_degree"] = item["in_degree"]  # for undirected graph
    item["input_edges"] = input_edges + 1  # we shift all indices by one for padding
    if "labels" not in item:
        item["labels"] = item["y"]

    return item
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class GraphormerDataCollator:
    """
    Collate preprocessed graph items into zero-padded batch tensors.

    Args:
        spatial_pos_max: node pairs whose `spatial_pos` is at least this value get an attention bias of `-inf`.
        on_the_fly_processing: when true, every item is first passed through `preprocess_item`.

    Raises:
        ImportError: at construction time if Cython is not available.
    """

    # Per-item entries that are converted to tensors and copied into the padded batch.
    _TENSOR_KEYS = ("attn_bias", "attn_edge_type", "spatial_pos", "in_degree", "input_nodes", "input_edges")

    def __init__(self, spatial_pos_max=20, on_the_fly_processing=False):
        if not is_cython_available():
            raise ImportError("Graphormer preprocessing needs Cython (pyximport)")

        self.spatial_pos_max = spatial_pos_max
        self.on_the_fly_processing = on_the_fly_processing

    @staticmethod
    def _as_private_tensor(value):
        """Return a tensor copy of `value` so later in-place edits never touch the caller's data."""
        if isinstance(value, torch.Tensor):
            return value.clone().detach()
        return torch.tensor(value)

    def __call__(self, features: List[dict]) -> Dict[str, Any]:
        """
        Build one batch from a list of items.

        Each item must provide the keys in `_TENSOR_KEYS` plus `labels`. Items that are not mappings are
        read through `vars()`. The returned dict holds the padded tensors, `out_degree` (the same tensor
        object as `in_degree`) and `labels`. The input items are left unmodified.
        """
        if self.on_the_fly_processing:
            features = [preprocess_item(i) for i in features]

        if not isinstance(features[0], Mapping):
            features = [vars(f) for f in features]
        batch = {}

        # Padded sizes: node count and path depth are maxima over the batch,
        # feature widths are taken from the first item.
        max_node_num = max(len(i["input_nodes"]) for i in features)
        node_feat_size = len(features[0]["input_nodes"][0])
        edge_feat_size = len(features[0]["attn_edge_type"][0][0])
        max_dist = max(len(i["input_edges"][0][0]) for i in features)
        edge_input_size = len(features[0]["input_edges"][0][0][0])
        batch_size = len(features)

        batch["attn_bias"] = torch.zeros(batch_size, max_node_num + 1, max_node_num + 1, dtype=torch.float)
        batch["attn_edge_type"] = torch.zeros(batch_size, max_node_num, max_node_num, edge_feat_size, dtype=torch.long)
        batch["spatial_pos"] = torch.zeros(batch_size, max_node_num, max_node_num, dtype=torch.long)
        batch["in_degree"] = torch.zeros(batch_size, max_node_num, dtype=torch.long)
        batch["input_nodes"] = torch.zeros(batch_size, max_node_num, node_feat_size, dtype=torch.long)
        batch["input_edges"] = torch.zeros(
            batch_size, max_node_num, max_node_num, max_dist, edge_input_size, dtype=torch.long
        )

        for ix, f in enumerate(features):
            # Work on private copies instead of overwriting the entries of the caller's item.
            t = {k: self._as_private_tensor(f[k]) for k in self._TENSOR_KEYS}

            # Row/column 0 of attn_bias belongs to the graph token, hence the [1:, 1:] view.
            # Assigning through an all-False mask is a no-op, so no emptiness check is needed.
            t["attn_bias"][1:, 1:][t["spatial_pos"] >= self.spatial_pos_max] = float("-inf")

            attn_bias = t["attn_bias"]
            batch["attn_bias"][ix, : attn_bias.shape[0], : attn_bias.shape[1]] = attn_bias

            attn_edge_type = t["attn_edge_type"]
            batch["attn_edge_type"][ix, : attn_edge_type.shape[0], : attn_edge_type.shape[1], :] = attn_edge_type

            spatial_pos = t["spatial_pos"]
            batch["spatial_pos"][ix, : spatial_pos.shape[0], : spatial_pos.shape[1]] = spatial_pos

            in_degree = t["in_degree"]
            batch["in_degree"][ix, : in_degree.shape[0]] = in_degree

            input_nodes = t["input_nodes"]
            batch["input_nodes"][ix, : input_nodes.shape[0], :] = input_nodes

            input_edges = t["input_edges"]
            batch["input_edges"][
                ix, : input_edges.shape[0], : input_edges.shape[1], : input_edges.shape[2], :
            ] = input_edges

        batch["out_degree"] = batch["in_degree"]

        sample = features[0]["labels"]
        if len(sample) == 1:
            # one task (regression and binary classification are collated identically)
            batch["labels"] = torch.from_numpy(np.concatenate([i["labels"] for i in features]))
        else:  # multi task classification, left to float to keep the NaNs
            batch["labels"] = torch.from_numpy(np.stack([i["labels"] for i in features], axis=0))

        return batch
|
docs/transformers/build/lib/transformers/models/deprecated/graphormer/configuration_graphormer.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 Microsoft, clefourrier and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Graphormer model configuration"""
|
| 16 |
+
|
| 17 |
+
from typing import Optional
|
| 18 |
+
|
| 19 |
+
from ....configuration_utils import PretrainedConfig
|
| 20 |
+
from ....utils import logging
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
logger = logging.get_logger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class GraphormerConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`~GraphormerModel`]. It is used to instantiate an
    Graphormer model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the Graphormer
    [graphormer-base-pcqm4mv1](https://huggingface.co/graphormer-base-pcqm4mv1) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        num_classes (`int`, *optional*, defaults to 1):
            Number of target classes or labels, set to n for binary classification of n tasks.
        num_atoms (`int`, *optional*, defaults to 512*9):
            Number of node types in the graphs.
        num_edges (`int`, *optional*, defaults to 512*3):
            Number of edges types in the graph.
        num_in_degree (`int`, *optional*, defaults to 512):
            Number of in degrees types in the input graphs.
        num_out_degree (`int`, *optional*, defaults to 512):
            Number of out degrees types in the input graphs.
        num_spatial (`int`, *optional*, defaults to 512):
            Number of spatial position types in the input graphs (presumably the number of distinct shortest-path
            distances that can be embedded).
        num_edge_dis (`int`, *optional*, defaults to 128):
            Number of edge dis in the input graphs.
        multi_hop_max_dist (`int`, *optional*, defaults to 5):
            Maximum distance of multi hop edges between two nodes (some setups use 20).
        spatial_pos_max (`int`, *optional*, defaults to 1024):
            Maximum distance between nodes in the graph attention bias matrices, used during preprocessing and
            collation.
        edge_type (`str`, *optional*, defaults to `"multi_hop"`):
            Type of edge relation chosen.
        max_nodes (`int`, *optional*, defaults to 512):
            Maximum number of nodes which can be parsed for the input graphs.
        share_input_output_embed (`bool`, *optional*, defaults to `False`):
            Shares the embedding layer between encoder and decoder - careful, True is not implemented.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of layers.
        embedding_dim (`int`, *optional*, defaults to 768):
            Dimension of the embedding layer in encoder. Also stored as `hidden_size`.
        ffn_embedding_dim (`int`, *optional*, defaults to 768):
            Dimension of the "intermediate" (often named feed-forward) layer in encoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads in the encoder.
        dropout (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_dropout (`float`, *optional*, defaults to 0.1):
            The dropout probability for the attention weights.
        activation_dropout (`float`, *optional*, defaults to 0.1):
            The dropout probability for the activation of the linear transformer layer.
        layerdrop (`float`, *optional*, defaults to 0.0):
            The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
            for more details.
        encoder_normalize_before (`bool`, *optional*, defaults to `False`):
            Normalize features before encoding the graph (apply the layer norm before each encoder block).
        pre_layernorm (`bool`, *optional*, defaults to `False`):
            Apply layernorm before self attention and the feed forward network. Without this, post layernorm will be
            used.
        apply_graphormer_init (`bool`, *optional*, defaults to `False`):
            Apply a custom graphormer initialisation to the model before training.
        activation_fn (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        embed_scale (`float`, *optional*, defaults to None):
            Scaling factor for the node embeddings.
        freeze_embeddings (`bool`, *optional*, defaults to `False`):
            Freeze the embedding layer, or train it along the model.
        num_trans_layers_to_freeze (`int`, *optional*, defaults to 0):
            Number of transformer layers to freeze.
        traceable (`bool`, *optional*, defaults to `False`):
            Changes return value of the encoder's inner_state to stacked tensors.
        q_noise (`float`, *optional*, defaults to 0.0):
            Amount of quantization noise (see "Training with Quantization Noise for Extreme Model Compression"). (For
            more detail, see fairseq's documentation on quant_noise).
        qn_block_size (`int`, *optional*, defaults to 8):
            Size of the blocks for subsequent quantization with iPQ (see q_noise).
        kdim (`int`, *optional*, defaults to None):
            Dimension of the key in the attention, if different from the other values.
        vdim (`int`, *optional*, defaults to None):
            Dimension of the value in the attention, if different from the other values.
        bias (`bool`, *optional*, defaults to `True`):
            Uses bias in the attention module - unsupported at the moment.
        self_attention (`bool`, *optional*, defaults to `True`):
            Model is self attentive (False not implemented).
        pad_token_id (`int`, *optional*, defaults to 0):
            Forwarded to [`PretrainedConfig`].
        bos_token_id (`int`, *optional*, defaults to 1):
            Forwarded to [`PretrainedConfig`].
        eos_token_id (`int`, *optional*, defaults to 2):
            Forwarded to [`PretrainedConfig`].
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). This is
            not an explicit parameter of this class; if given, it is forwarded to [`PretrainedConfig`] through
            `**kwargs`.

    Example:
        ```python
        >>> from transformers import GraphormerForGraphClassification, GraphormerConfig

        >>> # Initializing a Graphormer graphormer-base-pcqm4mv1 style configuration
        >>> configuration = GraphormerConfig()

        >>> # Initializing a model from the graphormer-base-pcqm4mv1 style configuration
        >>> model = GraphormerForGraphClassification(configuration)

        >>> # Accessing the model configuration
        >>> configuration = model.config
        ```
    """

    model_type = "graphormer"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        num_classes: int = 1,
        num_atoms: int = 512 * 9,
        num_edges: int = 512 * 3,
        num_in_degree: int = 512,
        num_out_degree: int = 512,
        num_spatial: int = 512,
        num_edge_dis: int = 128,
        multi_hop_max_dist: int = 5,  # sometimes is 20
        spatial_pos_max: int = 1024,
        edge_type: str = "multi_hop",
        max_nodes: int = 512,
        share_input_output_embed: bool = False,
        num_hidden_layers: int = 12,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 768,
        num_attention_heads: int = 32,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        layerdrop: float = 0.0,
        encoder_normalize_before: bool = False,
        pre_layernorm: bool = False,
        apply_graphormer_init: bool = False,
        activation_fn: str = "gelu",
        embed_scale: Optional[float] = None,
        freeze_embeddings: bool = False,
        num_trans_layers_to_freeze: int = 0,
        traceable: bool = False,
        q_noise: float = 0.0,
        qn_block_size: int = 8,
        kdim: Optional[int] = None,
        vdim: Optional[int] = None,
        bias: bool = True,
        self_attention: bool = True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        **kwargs,
    ):
        self.num_classes = num_classes
        self.num_atoms = num_atoms
        self.num_in_degree = num_in_degree
        self.num_out_degree = num_out_degree
        self.num_edges = num_edges
        self.num_spatial = num_spatial
        self.num_edge_dis = num_edge_dis
        self.edge_type = edge_type
        self.multi_hop_max_dist = multi_hop_max_dist
        self.spatial_pos_max = spatial_pos_max
        self.max_nodes = max_nodes
        self.num_hidden_layers = num_hidden_layers
        self.embedding_dim = embedding_dim
        # `hidden_size` is kept as an alias of `embedding_dim`.
        self.hidden_size = embedding_dim
        self.ffn_embedding_dim = ffn_embedding_dim
        self.num_attention_heads = num_attention_heads
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.layerdrop = layerdrop
        self.encoder_normalize_before = encoder_normalize_before
        self.pre_layernorm = pre_layernorm
        self.apply_graphormer_init = apply_graphormer_init
        self.activation_fn = activation_fn
        self.embed_scale = embed_scale
        self.freeze_embeddings = freeze_embeddings
        self.num_trans_layers_to_freeze = num_trans_layers_to_freeze
        self.share_input_output_embed = share_input_output_embed
        self.traceable = traceable
        self.q_noise = q_noise
        self.qn_block_size = qn_block_size

        # These parameters are here for future extensions
        # atm, the model only supports self attention
        self.kdim = kdim
        self.vdim = vdim
        self.self_attention = self_attention
        self.bias = bias

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
# Public API of this module: only the configuration class named here is exported by a star-import.
__all__ = ["GraphormerConfig"]
|
docs/transformers/build/lib/transformers/models/deprecated/graphormer/modeling_graphormer.py
ADDED
|
@@ -0,0 +1,911 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 Microsoft, clefourrier The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch Graphormer model."""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from typing import Iterable, Iterator, List, Optional, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 23 |
+
|
| 24 |
+
from ....activations import ACT2FN
|
| 25 |
+
from ....modeling_outputs import (
|
| 26 |
+
BaseModelOutputWithNoAttention,
|
| 27 |
+
SequenceClassifierOutput,
|
| 28 |
+
)
|
| 29 |
+
from ....modeling_utils import PreTrainedModel
|
| 30 |
+
from ....utils import logging
|
| 31 |
+
from .configuration_graphormer import GraphormerConfig
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
logger = logging.get_logger(__name__)
|
| 35 |
+
|
| 36 |
+
_CHECKPOINT_FOR_DOC = "graphormer-base-pcqm4mv1"
|
| 37 |
+
_CONFIG_FOR_DOC = "GraphormerConfig"
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def quant_noise(module: nn.Module, p: float, block_size: int):
    """
    From:
    https://github.com/facebookresearch/fairseq/blob/dd0079bde7f678b0cd0715cbd0ae68d661b7226d/fairseq/modules/quant_noise.py

    Wraps modules and applies quantization noise to the weights for subsequent quantization with Iterative Product
    Quantization as described in "Training with Quantization Noise for Extreme Model Compression"

    Args:
        - module: nn.Module
        - p: amount of Quantization Noise
        - block_size: size of the blocks for subsequent quantization with iPQ

    Returns:
        The very same `module` object. When `p <= 0` it is returned untouched; otherwise a forward pre-hook has been
        registered on it first.

    Raises:
        NotImplementedError: if `module` is not an `nn.Linear`, `nn.Embedding` or `nn.Conv2d`.
        AssertionError: if the weight dimensions are not a multiple of `block_size`.

    Remarks:
        - Module weights must have the right sizes wrt the block size
        - Only Linear, Embedding and Conv2d modules are supported for the moment
        - For more detail on how to quantize by blocks with convolutional weights, see "And the Bit Goes Down:
          Revisiting the Quantization of Neural Networks"
        - We implement the simplest form of noise here as stated in the paper which consists in randomly dropping
          blocks
        - The hook overwrites `module.weight.data` on every forward pass made in training mode; nothing happens in
          evaluation mode.
    """

    # if no quantization noise, don't register hook
    if p <= 0:
        return module

    # supported modules
    if not isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)):
        raise NotImplementedError("Module unsupported for quant_noise.")

    # test whether module.weight has the right sizes wrt block_size
    is_conv = module.weight.ndim == 4

    # 2D matrix
    if not is_conv:
        if module.weight.size(1) % block_size != 0:
            raise AssertionError("Input features must be a multiple of block sizes")

    # 4D matrix
    else:
        # 1x1 convolutions
        if module.kernel_size == (1, 1):
            if module.in_channels % block_size != 0:
                raise AssertionError("Input channels must be a multiple of block sizes")
        # regular convolutions
        else:
            k = module.kernel_size[0] * module.kernel_size[1]
            if k % block_size != 0:
                raise AssertionError("Kernel size must be a multiple of block size")

    def _forward_pre_hook(mod, inputs):
        # no noise for evaluation
        if not mod.training:
            return

        weight = mod.weight

        if not is_conv:
            in_features = weight.size(1)
            out_features = weight.size(0)

            # split weight matrix into blocks and randomly drop selected blocks
            mask = torch.zeros(in_features // block_size * out_features, device=weight.device)
            mask.bernoulli_(p)
            mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)

        else:
            in_channels = mod.in_channels
            out_channels = mod.out_channels

            # split weight matrix into blocks and randomly drop selected blocks
            if mod.kernel_size == (1, 1):
                mask = torch.zeros(
                    int(in_channels // block_size * out_channels),
                    device=weight.device,
                )
                mask.bernoulli_(p)
                mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
                # The weight is 4D (out, in, 1, 1). A 2D mask would be broadcast against the two trailing
                # singleton dimensions by the out-of-place masked_fill below and change the shape of the tensor
                # written back into the weight, so give the mask the same 4D layout.
                mask = mask.unsqueeze(2).unsqueeze(3)
            else:
                mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device)
                mask.bernoulli_(p)
                mask = mask.unsqueeze(2).unsqueeze(3).repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])

        # scale weights and apply mask
        mask = mask.to(torch.bool)  # x.bool() is not currently supported in TorchScript
        s = 1 / (1 - p)
        mod.weight.data = s * weight.masked_fill(mask, 0)

    module.register_forward_pre_hook(_forward_pre_hook)
    return module
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class LayerDropModuleList(nn.ModuleList):
    """
    From:
    https://github.com/facebookresearch/fairseq/blob/dd0079bde7f678b0cd0715cbd0ae68d661b7226d/fairseq/modules/layer_drop.py
    A [`torch.nn.ModuleList`] that implements LayerDrop (https://arxiv.org/abs/1909.11556).

    A new random selection of layers is made each time the list is iterated in training mode. In evaluation mode
    every layer is always yielded.

    Usage:

    ```python
    layers = LayerDropModuleList(p=0.5, modules=[layer1, layer2, layer3])
    for layer in layers:  # this might iterate over layers 1 and 3
        x = layer(x)
    for layer in layers:  # this might iterate over all layers
        x = layer(x)
    for layer in layers:  # this might not iterate over any layers
        x = layer(x)
    ```

    Args:
        p (float): probability of dropping out each layer
        modules (iterable, optional): an iterable of modules to add
    """

    def __init__(self, p: float, modules: Optional[Iterable[nn.Module]] = None):
        super().__init__(modules)
        self.p = p

    def __iter__(self) -> Iterator[nn.Module]:
        # One uniform draw per layer, sampled up front for the whole pass.
        draws = torch.empty(len(self)).uniform_()
        for index, layer in enumerate(super().__iter__()):
            # A layer is skipped only while training, and only when its draw does not exceed `p`.
            dropped = self.training and not (draws[index] > self.p)
            if dropped:
                continue
            yield layer
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class GraphormerGraphNodeFeature(nn.Module):
    """
    Build the per-node input features of a graph and prepend the graph token embedding.
    """

    def __init__(self, config: GraphormerConfig):
        super().__init__()
        hidden_size = config.hidden_size
        pad_id = config.pad_token_id

        self.num_heads = config.num_attention_heads
        self.num_atoms = config.num_atoms

        # One extra row on the atom table compared to `config.num_atoms`.
        self.atom_encoder = nn.Embedding(config.num_atoms + 1, hidden_size, padding_idx=pad_id)
        self.in_degree_encoder = nn.Embedding(config.num_in_degree, hidden_size, padding_idx=pad_id)
        self.out_degree_encoder = nn.Embedding(config.num_out_degree, hidden_size, padding_idx=pad_id)

        # Single learned embedding standing in for the graph token.
        self.graph_token = nn.Embedding(1, hidden_size)

    def forward(
        self,
        input_nodes: torch.LongTensor,
        in_degree: torch.LongTensor,
        out_degree: torch.LongTensor,
    ) -> torch.Tensor:
        """
        Embed the nodes and concatenate the graph token in front of them.

        The atom embeddings are summed over their second-to-last dimension and added to the in-degree and out-degree
        embeddings. The graph token embedding is repeated once per graph and concatenated along dimension 1.
        """
        n_graph, _ = input_nodes.size()[:2]

        # [n_graph, n_node, n_hidden]
        atom_embedding = self.atom_encoder(input_nodes).sum(dim=-2)
        node_feature = atom_embedding + self.in_degree_encoder(in_degree)
        node_feature = node_feature + self.out_degree_encoder(out_degree)

        graph_token_feature = self.graph_token.weight.unsqueeze(0).repeat(n_graph, 1, 1)

        return torch.cat([graph_token_feature, node_feature], dim=1)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class GraphormerGraphAttnBias(nn.Module):
    """
    Build the per-head attention bias from spatial positions, the graph token distance and edge features.
    """

    def __init__(self, config: GraphormerConfig):
        super().__init__()
        n_heads = config.num_attention_heads

        self.num_heads = n_heads
        self.multi_hop_max_dist = config.multi_hop_max_dist

        # We do not change edge feature embedding learning, as edge embeddings are represented as a combination of
        # the original features + shortest path
        self.edge_encoder = nn.Embedding(config.num_edges + 1, n_heads, padding_idx=0)

        self.edge_type = config.edge_type
        if self.edge_type == "multi_hop":
            # Stored as a single column; reshaped into per-distance (n_head x n_head) matrices in `forward`.
            self.edge_dis_encoder = nn.Embedding(config.num_edge_dis * n_heads * n_heads, 1)

        self.spatial_pos_encoder = nn.Embedding(config.num_spatial, n_heads, padding_idx=0)

        self.graph_token_virtual_distance = nn.Embedding(1, n_heads)

    def forward(
        self,
        input_nodes: torch.LongTensor,
        attn_bias: torch.Tensor,
        spatial_pos: torch.LongTensor,
        input_edges: torch.LongTensor,
        attn_edge_type: torch.LongTensor,
    ) -> torch.Tensor:
        """
        Return `attn_bias` expanded to one copy per head and augmented with the learned biases.

        The spatial-position and edge biases are added to the node-to-node part (`[:, :, 1:, 1:]`), the graph token
        virtual distance is added to row 0 and to column 0, and the original `attn_bias` is added once more at the
        end.
        """
        n_graph, n_node = input_nodes.size()[:2]

        # [n_graph, n_head, n_node+1, n_node+1]
        bias = attn_bias.clone().unsqueeze(1).repeat(1, self.num_heads, 1, 1)

        # spatial pos
        # [n_graph, n_node, n_node, n_head] -> [n_graph, n_head, n_node, n_node]
        spatial_bias = self.spatial_pos_encoder(spatial_pos).permute(0, 3, 1, 2)
        bias[:, :, 1:, 1:] = bias[:, :, 1:, 1:] + spatial_bias

        # reset spatial pos here
        virtual_distance = self.graph_token_virtual_distance.weight.view(1, self.num_heads, 1)
        bias[:, :, 1:, 0] = bias[:, :, 1:, 0] + virtual_distance
        bias[:, :, 0, :] = bias[:, :, 0, :] + virtual_distance

        # edge feature
        if self.edge_type == "multi_hop":
            hop_count = spatial_pos.clone()

            hop_count[hop_count == 0] = 1  # set pad to 1
            # set 1 to 1, input_nodes > 1 to input_nodes - 1
            hop_count = torch.where(hop_count > 1, hop_count - 1, hop_count)
            if self.multi_hop_max_dist > 0:
                hop_count = hop_count.clamp(0, self.multi_hop_max_dist)
                input_edges = input_edges[:, :, :, : self.multi_hop_max_dist, :]

            # [n_graph, n_node, n_node, max_dist, n_head]
            edge_embedding = self.edge_encoder(input_edges).mean(-2)
            max_dist = edge_embedding.size(-2)

            # Put the distance axis first so each distance gets its own (n_head x n_head) matrix in the bmm.
            flat_edges = edge_embedding.permute(3, 0, 1, 2, 4).reshape(max_dist, -1, self.num_heads)
            distance_weights = self.edge_dis_encoder.weight.reshape(-1, self.num_heads, self.num_heads)
            flat_edges = torch.bmm(flat_edges, distance_weights[:max_dist, :, :])

            edge_embedding = flat_edges.reshape(max_dist, n_graph, n_node, n_node, self.num_heads).permute(
                1, 2, 3, 0, 4
            )
            # Sum over the distance axis, divide by the hop count, then move heads to dimension 1.
            edge_bias = (edge_embedding.sum(-2) / (hop_count.float().unsqueeze(-1))).permute(0, 3, 1, 2)
        else:
            # [n_graph, n_node, n_node, n_head] -> [n_graph, n_head, n_node, n_node]
            edge_bias = self.edge_encoder(attn_edge_type).mean(-2).permute(0, 3, 1, 2)

        bias[:, :, 1:, 1:] = bias[:, :, 1:, 1:] + edge_bias
        bias = bias + attn_bias.unsqueeze(1)  # reset

        return bias
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
class GraphormerMultiheadAttention(nn.Module):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.

    Only self-attention is implemented: the query, key and value projections are all computed from `query`, and the
    `key` / `value` arguments of `forward` are used for shape validation only.
    """

    def __init__(self, config: GraphormerConfig):
        super().__init__()
        self.embedding_dim = config.embedding_dim
        self.kdim = config.kdim if config.kdim is not None else config.embedding_dim
        self.vdim = config.vdim if config.vdim is not None else config.embedding_dim
        self.qkv_same_dim = self.kdim == config.embedding_dim and self.vdim == config.embedding_dim

        self.num_heads = config.num_attention_heads
        self.attention_dropout_module = torch.nn.Dropout(p=config.attention_dropout, inplace=False)

        self.head_dim = config.embedding_dim // config.num_attention_heads
        if not (self.head_dim * config.num_attention_heads == self.embedding_dim):
            raise AssertionError("The embedding_dim must be divisible by num_heads.")
        self.scaling = self.head_dim**-0.5

        # `config.self_attention` is deliberately not read: the flag is forced on.
        self.self_attention = True  # config.self_attention
        if not (self.self_attention):
            raise NotImplementedError("The Graphormer model only supports self attention for now.")
        if self.self_attention and not self.qkv_same_dim:
            raise AssertionError("Self-attention requires query, key and value to be of the same size.")

        self.k_proj = quant_noise(
            nn.Linear(self.kdim, config.embedding_dim, bias=config.bias),
            config.q_noise,
            config.qn_block_size,
        )
        self.v_proj = quant_noise(
            nn.Linear(self.vdim, config.embedding_dim, bias=config.bias),
            config.q_noise,
            config.qn_block_size,
        )
        self.q_proj = quant_noise(
            nn.Linear(config.embedding_dim, config.embedding_dim, bias=config.bias),
            config.q_noise,
            config.qn_block_size,
        )

        self.out_proj = quant_noise(
            nn.Linear(config.embedding_dim, config.embedding_dim, bias=config.bias),
            config.q_noise,
            config.qn_block_size,
        )

        self.onnx_trace = False

    def reset_parameters(self):
        """Re-initialise the four projections with Xavier-uniform weights and a zero output bias."""
        if self.qkv_same_dim:
            # Empirically observed the convergence to be much better with
            # the scaled initialization
            nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        else:
            nn.init.xavier_uniform_(self.k_proj.weight)
            nn.init.xavier_uniform_(self.v_proj.weight)
            nn.init.xavier_uniform_(self.q_proj.weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)

    def forward(
        self,
        query: torch.Tensor,
        key: Optional[torch.Tensor],
        value: Optional[torch.Tensor],
        attn_bias: Optional[torch.Tensor],
        key_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[torch.Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            query (torch.Tensor): input of shape `(tgt_len, batch, embedding_dim)`. The query, key and value
                projections are all computed from it.
            key (torch.Tensor, optional): only used to validate shapes; when given, its first dimension is taken as
                the source length.
            value (torch.Tensor, optional): only used to validate shapes; required whenever `key` is given.
            attn_bias (torch.Tensor, optional): bias added to the attention scores before the softmax; it is viewed
                as `(batch * num_heads, tgt_len, src_len)`.
            key_padding_mask (Bytetorch.Tensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: True).
            attn_mask (Bytetorch.Tensor, optional): typically used to
                implement causal attention, where the mask prevents the attention from looking forward in time
                (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default: return the average attention weights over all
                heads.

        Returns:
            `(attn, attn_weights)`, where `attn` has shape `(tgt_len, batch, embedding_dim)` and `attn_weights` is
            `None` unless `need_weights` is set. With `before_softmax=True` the pair `(attn_weights, v)` is returned
            instead.

        Raises:
            AssertionError: if any of the input or intermediate shapes is inconsistent.
        """
        if need_head_weights:
            need_weights = True

        tgt_len, bsz, embedding_dim = query.size()
        src_len = tgt_len
        if not (embedding_dim == self.embedding_dim):
            raise AssertionError(
                f"The query embedding dimension {embedding_dim} is not equal to the expected embedding_dim"
                f" {self.embedding_dim}."
            )
        if not (list(query.size()) == [tgt_len, bsz, embedding_dim]):
            raise AssertionError("Query size incorrect in Graphormer, compared to model dimensions.")

        if key is not None:
            src_len, key_bsz, _ = key.size()
            if not torch.jit.is_scripting():
                # Compare the two leading dimensions of `value` as a pair. Negating the tuple
                # `(src_len, bsz == value.shape[:2])` is always False, so that form never detects a mismatch.
                if (key_bsz != bsz) or (value is None) or (src_len, bsz) != tuple(value.shape[:2]):
                    raise AssertionError(
                        "The batch shape does not match the key or value shapes provided to the attention."
                    )

        # Self-attention: all three projections read from `query`.
        q = self.q_proj(query)
        k = self.k_proj(query)
        v = self.v_proj(query)

        q *= self.scaling

        # (len, batch, embed) -> (batch * num_heads, len, head_dim)
        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        if (k is None) or not (k.size(1) == src_len):
            raise AssertionError("The shape of the key generated in the attention is incorrect")

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            if key_padding_mask.size(0) != bsz or key_padding_mask.size(1) != src_len:
                raise AssertionError(
                    "The shape of the generated padding mask for the key does not match expected dimensions."
                )
        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        if list(attn_weights.size()) != [bsz * self.num_heads, tgt_len, src_len]:
            raise AssertionError("The attention weights generated do not match the expected dimensions.")

        if attn_bias is not None:
            attn_weights += attn_bias.view(bsz * self.num_heads, tgt_len, src_len)

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v

        attn_weights_float = torch.nn.functional.softmax(attn_weights, dim=-1)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = self.attention_dropout_module(attn_weights)

        if v is None:
            raise AssertionError("No value generated")
        attn = torch.bmm(attn_probs, v)
        if list(attn.size()) != [bsz * self.num_heads, tgt_len, self.head_dim]:
            raise AssertionError("The attention generated do not match the expected dimensions.")

        # (batch * num_heads, tgt_len, head_dim) -> (tgt_len, batch, embed)
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embedding_dim)
        attn: torch.Tensor = self.out_proj(attn)

        attn_weights = None
        if need_weights:
            # The returned weights are the softmax output taken before dropout, with heads on dimension 0.
            attn_weights = attn_weights_float.contiguous().view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights

    def apply_sparse_mask(self, attn_weights: torch.Tensor, tgt_len: int, src_len: int, bsz: int) -> torch.Tensor:
        """Return `attn_weights` unchanged; the other arguments are not used by this implementation."""
        return attn_weights
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
class GraphormerGraphEncoderLayer(nn.Module):
    """
    One encoder layer: a self-attention sub-block followed by a two-layer feed-forward sub-block, each wrapped in a
    residual connection with layer normalisation placed before or after it depending on `config.pre_layernorm`.
    """

    def __init__(self, config: GraphormerConfig) -> None:
        super().__init__()

        embed_dim = config.embedding_dim
        ffn_dim = config.ffn_embedding_dim
        noise_kwargs = {"q_noise": config.q_noise, "qn_block_size": config.qn_block_size}

        # Hyper-parameters kept on the layer
        self.embedding_dim = embed_dim
        self.num_attention_heads = config.num_attention_heads
        self.q_noise = config.q_noise
        self.qn_block_size = config.qn_block_size
        self.pre_layernorm = config.pre_layernorm

        self.dropout_module = torch.nn.Dropout(p=config.dropout, inplace=False)

        self.activation_dropout_module = torch.nn.Dropout(p=config.activation_dropout, inplace=False)

        # Sub-modules
        self.activation_fn = ACT2FN[config.activation_fn]
        self.self_attn = GraphormerMultiheadAttention(config)

        # layer norm tied to the self-attention sub-block
        self.self_attn_layer_norm = nn.LayerNorm(embed_dim)

        self.fc1 = self.build_fc(embed_dim, ffn_dim, **noise_kwargs)
        self.fc2 = self.build_fc(ffn_dim, embed_dim, **noise_kwargs)

        # layer norm tied to the feed-forward sub-block
        self.final_layer_norm = nn.LayerNorm(embed_dim)

    def build_fc(
        self, input_dim: int, output_dim: int, q_noise: float, qn_block_size: int
    ) -> Union[nn.Module, nn.Linear, nn.Embedding, nn.Conv2d]:
        """Create a linear layer and pass it through `quant_noise` with the given settings."""
        linear = nn.Linear(input_dim, output_dim)
        return quant_noise(linear, q_noise, qn_block_size)

    def forward(
        self,
        input_nodes: torch.Tensor,
        self_attn_bias: Optional[torch.Tensor] = None,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Run the attention and feed-forward sub-blocks. nn.LayerNorm is applied before each sub-block when
        `pre_layernorm` is set and after its residual addition otherwise. Returns the new hidden states together with
        whatever the attention module returned as attention weights (it is called with `need_weights=False`).
        """
        # --- self-attention sub-block ---
        attn_input = self.self_attn_layer_norm(input_nodes) if self.pre_layernorm else input_nodes
        attn_output, attn = self.self_attn(
            query=attn_input,
            key=attn_input,
            value=attn_input,
            attn_bias=self_attn_bias,
            key_padding_mask=self_attn_padding_mask,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        hidden = input_nodes + self.dropout_module(attn_output)
        if not self.pre_layernorm:
            hidden = self.self_attn_layer_norm(hidden)

        # --- feed-forward sub-block ---
        ffn_input = self.final_layer_norm(hidden) if self.pre_layernorm else hidden
        ffn_output = self.activation_fn(self.fc1(ffn_input))
        ffn_output = self.activation_dropout_module(ffn_output)
        ffn_output = self.fc2(ffn_output)
        ffn_output = self.dropout_module(ffn_output)
        hidden = hidden + ffn_output
        if not self.pre_layernorm:
            hidden = self.final_layer_norm(hidden)

        return hidden, attn
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
class GraphormerGraphEncoder(nn.Module):
    """
    Stack of `GraphormerGraphEncoderLayer` modules applied on top of the node-feature embedding and the graph
    attention bias built from the graph inputs.
    """

    def __init__(self, config: GraphormerConfig):
        super().__init__()

        self.dropout_module = torch.nn.Dropout(p=config.dropout, inplace=False)
        self.layerdrop = config.layerdrop
        self.embedding_dim = config.embedding_dim
        self.apply_graphormer_init = config.apply_graphormer_init
        self.traceable = config.traceable

        self.graph_node_feature = GraphormerGraphNodeFeature(config)
        self.graph_attn_bias = GraphormerGraphAttnBias(config)

        self.embed_scale = config.embed_scale

        dim = self.embedding_dim

        # Optional bias-free projection wrapped in quantization noise.
        self.quant_noise = (
            quant_noise(nn.Linear(dim, dim, bias=False), config.q_noise, config.qn_block_size)
            if config.q_noise > 0
            else None
        )

        # Optional normalization of the embeddings before the first layer.
        self.emb_layer_norm = nn.LayerNorm(dim) if config.encoder_normalize_before else None

        # Only registered for pre-layernorm configurations.
        if config.pre_layernorm:
            self.final_layer_norm = nn.LayerNorm(dim)

        self.layers = LayerDropModuleList(p=self.layerdrop) if self.layerdrop > 0.0 else nn.ModuleList([])
        self.layers.extend([GraphormerGraphEncoderLayer(config) for _ in range(config.num_hidden_layers)])

        # Apply initialization of model params after building the model
        if config.freeze_embeddings:
            raise NotImplementedError("Freezing embeddings is not implemented yet.")

        # Freeze the first `num_trans_layers_to_freeze` layers.
        for index in range(config.num_trans_layers_to_freeze):
            frozen = self.layers[index]
            if frozen is None:
                continue
            for param in frozen.parameters():
                param.requires_grad = False

    def forward(
        self,
        input_nodes: torch.LongTensor,
        input_edges: torch.LongTensor,
        attn_bias: torch.Tensor,
        in_degree: torch.LongTensor,
        out_degree: torch.LongTensor,
        spatial_pos: torch.LongTensor,
        attn_edge_type: torch.LongTensor,
        perturb=None,
        last_state_only: bool = False,
        token_embeddings: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[Union[torch.Tensor, List[torch.LongTensor]], torch.Tensor]:
        """
        Encode a batch of graphs.

        Returns:
            A tuple `(inner_states, graph_rep)`. `inner_states` holds the hidden states collected along the way
            (only the final one when `last_state_only` is set), stacked into a single tensor when
            `self.traceable` is true and returned as a list otherwise. `graph_rep` is the slice at index 0 of the
            first dimension of the final hidden state.
        """
        # Padding mask for multi-head attention: a node counts as padding when its first feature equals 0.
        n_graph, _n_node = input_nodes.size()[:2]
        node_is_padding = input_nodes[:, :, 0].eq(0)
        # One extra leading position per graph that is never masked.
        leading_column = torch.zeros(n_graph, 1, device=node_is_padding.device, dtype=node_is_padding.dtype)
        padding_mask = torch.cat((leading_column, node_is_padding), dim=1)

        attn_bias = self.graph_attn_bias(input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type)

        # Caller-supplied embeddings take precedence over the learned node features.
        if token_embeddings is None:
            hidden = self.graph_node_feature(input_nodes, in_degree, out_degree)
        else:
            hidden = token_embeddings

        if perturb is not None:
            # In-place add on every position but the first one.
            hidden[:, 1:, :] += perturb

        if self.embed_scale is not None:
            hidden = hidden * self.embed_scale

        if self.quant_noise is not None:
            hidden = self.quant_noise(hidden)

        if self.emb_layer_norm is not None:
            hidden = self.emb_layer_norm(hidden)

        hidden = self.dropout_module(hidden)

        # Swap the first two dimensions before running the layers.
        hidden = hidden.transpose(0, 1)

        collect_all = not last_state_only
        inner_states = [hidden] if collect_all else []

        for encoder_layer in self.layers:
            hidden, _ = encoder_layer(
                hidden,
                self_attn_padding_mask=padding_mask,
                self_attn_mask=attn_mask,
                self_attn_bias=attn_bias,
            )
            if collect_all:
                inner_states.append(hidden)

        graph_rep = hidden[0, :, :]

        if last_state_only:
            inner_states = [hidden]

        if self.traceable:
            return torch.stack(inner_states), graph_rep
        return inner_states, graph_rep
|
| 684 |
+
|
| 685 |
+
|
| 686 |
+
class GraphormerDecoderHead(nn.Module):
    """
    Bias-free linear projection followed by a single learned scalar bias.

    Args:
        embedding_dim (`int`): Size of the last dimension of the input.
        num_classes (`int`): Should be 1 for regression, or the number of classes for classification.
    """

    def __init__(self, embedding_dim: int, num_classes: int):
        super().__init__()
        # The bias is one scalar shared by every output class, so the linear layer itself has no bias.
        self.lm_output_learned_bias = nn.Parameter(torch.zeros(1))
        self.classifier = nn.Linear(embedding_dim, num_classes, bias=False)
        self.num_classes = num_classes

    def forward(self, input_nodes: torch.Tensor, **unused) -> torch.Tensor:
        """Project `input_nodes` to `num_classes` outputs and add the shared bias; extra kwargs are ignored."""
        input_nodes = self.classifier(input_nodes)
        input_nodes = input_nodes + self.lm_output_learned_bias
        return input_nodes
|
| 698 |
+
|
| 699 |
+
|
| 700 |
+
class GraphormerPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = GraphormerConfig
    base_model_prefix = "graphormer"
    main_input_name_nodes = "input_nodes"
    main_input_name_edges = "input_edges"

    def normal_(self, data: torch.Tensor):
        """Fill `data` in place with samples from N(0, 0.02**2) drawn on CPU."""
        # with FSDP, module params will be on CUDA, so we cast them back to CPU
        # so that the RNG is consistent with and without FSDP
        data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))

    def init_graphormer_params(self, module: Union[nn.Linear, nn.Embedding, GraphormerMultiheadAttention]):
        """
        Initialize the weights specific to the Graphormer Model.

        Meant to be passed to `nn.Module.apply`, so it is called once per submodule. The checks are independent
        `if` statements, not an `elif` chain.
        """
        if isinstance(module, nn.Linear):
            self.normal_(module.weight.data)
            if module.bias is not None:
                module.bias.data.zero_()
        if isinstance(module, nn.Embedding):
            self.normal_(module.weight.data)
            if module.padding_idx is not None:
                # Keep the padding row at zero.
                module.weight.data[module.padding_idx].zero_()
        if isinstance(module, GraphormerMultiheadAttention):
            self.normal_(module.q_proj.weight.data)
            self.normal_(module.k_proj.weight.data)
            self.normal_(module.v_proj.weight.data)

    def _init_weights(
        self,
        module: Union[
            nn.Linear, nn.Conv2d, nn.Embedding, nn.LayerNorm, GraphormerMultiheadAttention, GraphormerGraphEncoder
        ],
    ):
        """
        Initialize the weights of `module` according to its type. Modules of any other type are left untouched.
        """
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # We might be missing part of the Linear init, dependent on the layer num
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, GraphormerMultiheadAttention):
            module.q_proj.weight.data.normal_(mean=0.0, std=0.02)
            module.k_proj.weight.data.normal_(mean=0.0, std=0.02)
            module.v_proj.weight.data.normal_(mean=0.0, std=0.02)
            module.reset_parameters()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, GraphormerGraphEncoder):
            # Re-run the Graphormer-specific init over every submodule of the encoder when requested.
            if module.apply_graphormer_init:
                module.apply(self.init_graphormer_params)
|
| 766 |
+
|
| 767 |
+
|
| 768 |
+
class GraphormerModel(GraphormerPreTrainedModel):
    """The Graphormer model is a graph-encoder model.

    It goes from a graph to its representation. If you want to use the model for a downstream classification task, use
    GraphormerForGraphClassification instead. For any other downstream task, feel free to add a new class, or combine
    this model with a downstream model of your choice, following the example in GraphormerForGraphClassification.
    """

    def __init__(self, config: GraphormerConfig):
        super().__init__(config)
        # NOTE(review): this instance attribute shadows the `max_nodes` method defined at the bottom of the class,
        # so `model.max_nodes` is the configured integer and `model.max_nodes()` is not callable.
        self.max_nodes = config.max_nodes

        self.graph_encoder = GraphormerGraphEncoder(config)

        self.share_input_output_embed = config.share_input_output_embed
        self.lm_output_learned_bias = None

        # Remove head is set to true during fine-tuning
        self.load_softmax = not getattr(config, "remove_head", False)

        self.lm_head_transform_weight = nn.Linear(config.embedding_dim, config.embedding_dim)
        self.activation_fn = ACT2FN[config.activation_fn]
        self.layer_norm = nn.LayerNorm(config.embedding_dim)

        self.post_init()

    def reset_output_layer_parameters(self):
        """Replace `lm_output_learned_bias` with a fresh zero-initialized scalar parameter."""
        self.lm_output_learned_bias = nn.Parameter(torch.zeros(1))

    def forward(
        self,
        input_nodes: torch.LongTensor,
        input_edges: torch.LongTensor,
        attn_bias: torch.Tensor,
        in_degree: torch.LongTensor,
        out_degree: torch.LongTensor,
        spatial_pos: torch.LongTensor,
        attn_edge_type: torch.LongTensor,
        perturb: Optional[torch.FloatTensor] = None,
        masked_tokens: None = None,
        return_dict: Optional[bool] = None,
        **unused,
    ) -> Union[Tuple[torch.LongTensor], BaseModelOutputWithNoAttention]:
        """
        Encode the graph inputs and transform the last hidden state.

        Args:
            masked_tokens: Must be `None`; projecting only masked tokens is not implemented.
            return_dict: Falls back to `self.config.use_return_dict` when `None`.

        Returns:
            A `BaseModelOutputWithNoAttention` (or the tuple `(last_hidden_state, hidden_states)` when
            `return_dict` is false).

        Raises:
            NotImplementedError: If `masked_tokens` is not `None`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        inner_states, graph_rep = self.graph_encoder(
            input_nodes, input_edges, attn_bias, in_degree, out_degree, spatial_pos, attn_edge_type, perturb=perturb
        )

        # last inner state, then revert Batch and Graph len
        input_nodes = inner_states[-1].transpose(0, 1)

        # project masked tokens only
        if masked_tokens is not None:
            raise NotImplementedError

        input_nodes = self.layer_norm(self.activation_fn(self.lm_head_transform_weight(input_nodes)))

        # project back to size of vocabulary
        # The encoder does not necessarily expose `embed_tokens`; look it up defensively so that enabling
        # `share_input_output_embed` skips the projection instead of raising AttributeError.
        embed_tokens = getattr(self.graph_encoder, "embed_tokens", None)
        if self.share_input_output_embed and hasattr(embed_tokens, "weight"):
            input_nodes = torch.nn.functional.linear(input_nodes, embed_tokens.weight)

        if not return_dict:
            return tuple(x for x in [input_nodes, inner_states] if x is not None)
        return BaseModelOutputWithNoAttention(last_hidden_state=input_nodes, hidden_states=inner_states)

    def max_nodes(self):
        """Maximum output length supported by the encoder."""
        # NOTE(review): unreachable through an instance because `__init__` assigns an attribute of the same name;
        # kept so the class-level name stays available to existing callers.
        return self.max_nodes
|
| 837 |
+
|
| 838 |
+
|
| 839 |
+
class GraphormerForGraphClassification(GraphormerPreTrainedModel):
    """
    This model can be used for graph-level classification or regression tasks.

    It can be trained on
    - regression (by setting config.num_classes to 1); there should be one float-type label per graph
    - one task classification (by setting config.num_classes to the number of classes); there should be one integer
    label per graph
    - binary multi-task classification (by setting config.num_classes to the number of labels); there should be a list
    of integer labels for each graph.
    """

    def __init__(self, config: GraphormerConfig):
        super().__init__(config)
        self.encoder = GraphormerModel(config)
        self.embedding_dim = config.embedding_dim
        self.num_classes = config.num_classes
        self.classifier = GraphormerDecoderHead(self.embedding_dim, self.num_classes)
        self.is_encoder_decoder = True

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_nodes: torch.LongTensor,
        input_edges: torch.LongTensor,
        attn_bias: torch.Tensor,
        in_degree: torch.LongTensor,
        out_degree: torch.LongTensor,
        spatial_pos: torch.LongTensor,
        attn_edge_type: torch.LongTensor,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        **unused,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        """
        Run the encoder and the decoder head, and compute a loss when `labels` is given.

        Entries of `labels` that are NaN are excluded from the loss. The loss is MSE when `num_classes == 1`,
        cross-entropy when `num_classes > 1` and `labels` is one-dimensional, and summed binary cross-entropy with
        logits otherwise.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # The encoder is always asked for a dict so that its fields can be read by name.
        encoder_outputs = self.encoder(
            input_nodes,
            input_edges,
            attn_bias,
            in_degree,
            out_degree,
            spatial_pos,
            attn_edge_type,
            return_dict=True,
        )
        outputs = encoder_outputs["last_hidden_state"]
        hidden_states = encoder_outputs["hidden_states"]

        # Only position 0 of the second dimension is kept as the prediction.
        logits = self.classifier(outputs)[:, 0, :].contiguous()

        loss = None
        if labels is not None:
            valid = ~torch.isnan(labels)
            kept_logits = logits[valid]
            kept_labels = labels[valid]

            if self.num_classes == 1:  # regression
                criterion = MSELoss()
                loss = criterion(kept_logits.squeeze(), kept_labels.squeeze().float())
            elif self.num_classes > 1 and len(labels.shape) == 1:  # One task classification
                criterion = CrossEntropyLoss()
                loss = criterion(kept_logits.view(-1, self.num_classes), kept_labels.view(-1))
            else:  # Binary multi-task classification
                criterion = BCEWithLogitsLoss(reduction="sum")
                loss = criterion(kept_logits, kept_labels)

        if return_dict:
            return SequenceClassifierOutput(loss=loss, logits=logits, hidden_states=hidden_states, attentions=None)
        return tuple([item for item in (loss, logits, hidden_states) if item is not None])
|
| 909 |
+
|
| 910 |
+
|
| 911 |
+
__all__ = ["GraphormerForGraphClassification", "GraphormerModel", "GraphormerPreTrainedModel"]
|
docs/transformers/build/lib/transformers/models/deprecated/jukebox/__init__.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ....utils import _LazyModule
|
| 17 |
+
from ....utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
    # Only evaluated by static type checkers: expose the submodules' names through regular imports.
    from .configuration_jukebox import *
    from .modeling_jukebox import *
    from .tokenization_jukebox import *
else:
    import sys

    # At runtime, replace this package's entry in `sys.modules` with a `_LazyModule` built from the import
    # structure derived from this file (presumably deferring the submodule imports until first use - TODO confirm).
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
docs/transformers/build/lib/transformers/models/deprecated/jukebox/configuration_jukebox.py
ADDED
|
@@ -0,0 +1,613 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 The OpenAI Team Authors and HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Jukebox configuration"""
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
from typing import List, Union
|
| 19 |
+
|
| 20 |
+
from ....configuration_utils import PretrainedConfig
|
| 21 |
+
from ....utils import logging
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
logger = logging.get_logger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# One cycle of the three sparse attention kinds.
_RawColumnPreviousRowAttention = ["block_attn", "transpose_block_attn", "prev_block_attn"]
_FullDenseAttention = ["dense_attention"]
_PrimePrimeDenseAttention = ["prime_attn", "prime_attn", "dense_attn"]

# 79-entry schedule: six cycles followed by one "cross_attention", then six groups made of three cycles followed
# by one "cross_attention" (18 + 1 + 6 * 10 entries).
_LARGE_ATTENTION = (
    _RawColumnPreviousRowAttention * 6
    + ["cross_attention"]
    + (_RawColumnPreviousRowAttention * 3 + ["cross_attention"]) * 6
)


def full_dense_attention(layer):
    """Return the dense attention kind regardless of `layer`."""
    (kind,) = _FullDenseAttention
    return kind


def raw_column_previous_row_attention(layer):
    """Return the attention kind for `layer`, cycling through block, transposed-block and previous-block."""
    position = layer % 3
    return _RawColumnPreviousRowAttention[position]


def large_separated_enc_dec_w_lyrics(layer):
    """Return the attention kind for `layer` from the 79-entry schedule, wrapping around after 79 layers."""
    position = layer % 79
    return _LARGE_ATTENTION[position]


def enc_dec_with_lyrics(layer):
    """Return the attention kind for `layer`; every 16th layer (index 15, 31, ...) uses the prime/dense table."""
    table = _PrimePrimeDenseAttention if layer % 16 == 15 else _RawColumnPreviousRowAttention
    return table[layer % 3]


# Maps the name of an attention pattern to the function selecting the attention kind of a given layer index.
ATTENTION_PATTERNS = {
    "full_dense_attention": full_dense_attention,
    "raw_column_previous_row_attention": raw_column_previous_row_attention,  # Alternate row, column and previous row attn
    "large_separated_enc_dec_w_lyrics": large_separated_enc_dec_w_lyrics,  # Used by large separated_enc_dec model with lyrics
    "enc_dec_with_lyrics": enc_dec_with_lyrics,  # Used by encoder_decoder model with lyrics
}
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class JukeboxPriorConfig(PretrainedConfig):
|
| 140 |
+
"""
|
| 141 |
+
This is the configuration class to store the configuration of a [`JukeboxPrior`]. It is used to instantiate a
|
| 142 |
+
`JukeboxPrior` according to the specified arguments, defining the model architecture. Instantiating a
|
| 143 |
+
configuration with the defaults will yield a similar configuration to that of the top level prior from the
|
| 144 |
+
[openai/jukebox-1b-lyrics](https://huggingface.co/openai/jukebox
|
| 145 |
+
-1b-lyrics) architecture.
|
| 146 |
+
|
| 147 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 148 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
Args:
|
| 153 |
+
act_fn (`str`, *optional*, defaults to `"quick_gelu"`):
|
| 154 |
+
Activation function.
|
| 155 |
+
alignment_head (`int`, *optional*, defaults to 2):
|
| 156 |
+
Head that is responsible of the alignment between lyrics and music. Only used to compute the lyric to audio
|
| 157 |
+
alignment
|
| 158 |
+
alignment_layer (`int`, *optional*, defaults to 68):
|
| 159 |
+
Index of the layer that is responsible of the alignment between lyrics and music. Only used to compute the
|
| 160 |
+
lyric to audio alignment
|
| 161 |
+
attention_multiplier (`float`, *optional*, defaults to 0.25):
|
| 162 |
+
Multiplier coefficient used to define the hidden dimension of the attention layers. 0.25 means that
|
| 163 |
+
0.25*width of the model will be used.
|
| 164 |
+
attention_pattern (`str`, *optional*, defaults to `"enc_dec_with_lyrics"`):
|
| 165 |
+
Which attention pattern to use for the decoder/
|
| 166 |
+
attn_dropout (`int`, *optional*, defaults to 0):
|
| 167 |
+
Dropout probability for the post-attention layer dropout in the decoder.
|
| 168 |
+
attn_res_scale (`bool`, *optional*, defaults to `False`):
|
| 169 |
+
Whether or not to scale the residuals in the attention conditioner block.
|
| 170 |
+
blocks (`int`, *optional*, defaults to 64):
|
| 171 |
+
Number of blocks used in the `block_attn`. A sequence of length seq_len is factored as `[blocks, seq_len //
|
| 172 |
+
blocks]` in the `JukeboxAttention` layer.
|
| 173 |
+
conv_res_scale (`int`, *optional*):
|
| 174 |
+
Whether or not to scale the residuals in the conditioner block. Since the top level prior does not have a
|
| 175 |
+
conditioner, the default value is to None and should not be modified.
|
| 176 |
+
num_layers (`int`, *optional*, defaults to 72):
|
| 177 |
+
Number of layers of the transformer architecture.
|
| 178 |
+
emb_dropout (`int`, *optional*, defaults to 0):
|
| 179 |
+
Embedding dropout used in the lyric decoder.
|
| 180 |
+
encoder_config (`JukeboxPriorConfig`, *optional*) :
|
| 181 |
+
Configuration of the encoder which models the prior on the lyrics.
|
| 182 |
+
encoder_loss_fraction (`float`, *optional*, defaults to 0.4):
|
| 183 |
+
Multiplication factor used in front of the lyric encoder loss.
|
| 184 |
+
hidden_size (`int`, *optional*, defaults to 2048):
|
| 185 |
+
Hidden dimension of the attention layers.
|
| 186 |
+
init_scale (`float`, *optional*, defaults to 0.2):
|
| 187 |
+
Initialization scales for the prior modules.
|
| 188 |
+
is_encoder_decoder (`bool`, *optional*, defaults to `True`):
|
| 189 |
+
Whether or not the prior is an encoder-decoder model. In case it is not, and `nb_relevant_lyric_tokens` is
|
| 190 |
+
greater than 0, the `encoder` args should be specified for the lyric encoding.
|
| 191 |
+
mask (`bool`, *optional*, defaults to `False`):
|
| 192 |
+
Whether or not to mask the previous positions in the attention.
|
| 193 |
+
max_duration (`int`, *optional*, defaults to 600):
|
| 194 |
+
Maximum supported duration of the generated song in seconds.
|
| 195 |
+
max_nb_genres (`int`, *optional*, defaults to 1):
|
| 196 |
+
Maximum number of genres that can be used to condition the model.
|
| 197 |
+
merged_decoder (`bool`, *optional*, defaults to `True`):
|
| 198 |
+
Whether or not the decoder and the encoder inputs are merged. This is used for the separated
|
| 199 |
+
encoder-decoder architecture
|
| 200 |
+
metadata_conditioning (`bool`, *optional*, defaults to `True)`:
|
| 201 |
+
Whether or not to condition on the artist and genre metadata.
|
| 202 |
+
metadata_dims (`List[int]`, *optional*, defaults to `[604, 7898]`):
|
| 203 |
+
Number of genres and the number of artists that were used to train the embedding layers of the prior
|
| 204 |
+
models.
|
| 205 |
+
min_duration (`int`, *optional*, defaults to 0):
|
| 206 |
+
Minimum duration of the generated audio on which the model was trained.
|
| 207 |
+
mlp_multiplier (`float`, *optional*, defaults to 1.0):
|
| 208 |
+
Multiplier coefficient used to define the hidden dimension of the MLP layers. 0.25 means that 0.25*width of
|
| 209 |
+
the model will be used.
|
| 210 |
+
music_vocab_size (`int`, *optional*, defaults to 2048):
|
| 211 |
+
Number of different music tokens. Should be similar to the `JukeboxVQVAEConfig.nb_discrete_codes`.
|
| 212 |
+
n_ctx (`int`, *optional*, defaults to 6144):
|
| 213 |
+
Number of context tokens for each prior. The context tokens are the music tokens that are attended to when
|
| 214 |
+
generating music tokens.
|
| 215 |
+
n_heads (`int`, *optional*, defaults to 2):
|
| 216 |
+
Number of attention heads.
|
| 217 |
+
nb_relevant_lyric_tokens (`int`, *optional*, defaults to 384):
|
| 218 |
+
Number of lyric tokens that are used when sampling a single window of length `n_ctx`
|
| 219 |
+
res_conv_depth (`int`, *optional*, defaults to 3):
|
| 220 |
+
Depth of the `JukeboxDecoderConvBock` used to upsample the previously sampled audio in the
|
| 221 |
+
`JukeboxMusicTokenConditioner`.
|
| 222 |
+
res_conv_width (`int`, *optional*, defaults to 128):
|
| 223 |
+
Width of the `JukeboxDecoderConvBock` used to upsample the previously sampled audio in the
|
| 224 |
+
`JukeboxMusicTokenConditioner`.
|
| 225 |
+
res_convolution_multiplier (`int`, *optional*, defaults to 1):
|
| 226 |
+
Multiplier used to scale the `hidden_dim` of the `JukeboxResConv1DBlock`.
|
| 227 |
+
res_dilation_cycle (`int`, *optional*):
|
| 228 |
+
Dilation cycle used to define the `JukeboxMusicTokenConditioner`. Usually similar to the ones used in the
|
| 229 |
+
corresponding level of the VQVAE. The first prior does not use it as it is not conditioned on upper level
|
| 230 |
+
tokens.
|
| 231 |
+
res_dilation_growth_rate (`int`, *optional*, defaults to 1):
|
| 232 |
+
Dilation grow rate used between each convolutionnal block of the `JukeboxMusicTokenConditioner`
|
| 233 |
+
res_downs_t (`List[int]`, *optional*, defaults to `[3, 2, 2]`):
|
| 234 |
+
Downsampling rates used in the audio conditioning network
|
| 235 |
+
res_strides_t (`List[int]`, *optional*, defaults to `[2, 2, 2]`):
|
| 236 |
+
Striding used in the audio conditioning network
|
| 237 |
+
resid_dropout (`int`, *optional*, defaults to 0):
|
| 238 |
+
Residual dropout used in the attention pattern.
|
| 239 |
+
sampling_rate (`int`, *optional*, defaults to 44100):
|
| 240 |
+
Sampling rate used for training.
|
| 241 |
+
spread (`int`, *optional*):
|
| 242 |
+
Spread used in the `summary_spread_attention` pattern
|
| 243 |
+
timing_dims (`int`, *optional*, defaults to 64):
|
| 244 |
+
Dimension of the timing embedding.
|
| 245 |
+
zero_out (`bool`, *optional*, defaults to `False`):
|
| 246 |
+
Whether or not to zero out convolution weights when initializing.
|
| 247 |
+
"""
|
| 248 |
+
|
| 249 |
+
model_type = "jukebox_prior"
|
| 250 |
+
attribute_map = {
|
| 251 |
+
"max_position_embeddings": "n_positions",
|
| 252 |
+
"num_attention_heads": "n_head",
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
def __init__(
|
| 256 |
+
self,
|
| 257 |
+
act_fn="quick_gelu",
|
| 258 |
+
level=0,
|
| 259 |
+
alignment_head=2,
|
| 260 |
+
alignment_layer=68,
|
| 261 |
+
attention_multiplier=0.25,
|
| 262 |
+
attention_pattern="enc_dec_with_lyrics",
|
| 263 |
+
attn_dropout=0,
|
| 264 |
+
attn_res_scale=False,
|
| 265 |
+
blocks=64,
|
| 266 |
+
conv_res_scale=None,
|
| 267 |
+
num_layers=72,
|
| 268 |
+
emb_dropout=0,
|
| 269 |
+
encoder_config=None,
|
| 270 |
+
encoder_loss_fraction=0.4,
|
| 271 |
+
hidden_size=2048,
|
| 272 |
+
init_scale=0.2,
|
| 273 |
+
is_encoder_decoder=True,
|
| 274 |
+
lyric_vocab_size=80,
|
| 275 |
+
mask=False,
|
| 276 |
+
max_duration=600,
|
| 277 |
+
max_nb_genres=1,
|
| 278 |
+
merged_decoder=True,
|
| 279 |
+
metadata_conditioning=True,
|
| 280 |
+
metadata_dims=[604, 7898],
|
| 281 |
+
min_duration=0,
|
| 282 |
+
mlp_multiplier=1.0,
|
| 283 |
+
music_vocab_size=2048,
|
| 284 |
+
n_ctx=6144,
|
| 285 |
+
n_heads=2,
|
| 286 |
+
nb_relevant_lyric_tokens=384,
|
| 287 |
+
res_conv_depth=3,
|
| 288 |
+
res_conv_width=128,
|
| 289 |
+
res_convolution_multiplier=1,
|
| 290 |
+
res_dilation_cycle=None,
|
| 291 |
+
res_dilation_growth_rate=1,
|
| 292 |
+
res_downs_t=[3, 2, 2],
|
| 293 |
+
res_strides_t=[2, 2, 2],
|
| 294 |
+
resid_dropout=0,
|
| 295 |
+
sampling_rate=44100,
|
| 296 |
+
spread=None,
|
| 297 |
+
timing_dims=64,
|
| 298 |
+
zero_out=False,
|
| 299 |
+
**kwargs,
|
| 300 |
+
):
|
| 301 |
+
self.act_fn = act_fn
|
| 302 |
+
self.alignment_head = alignment_head
|
| 303 |
+
self.alignment_layer = alignment_layer
|
| 304 |
+
self.attention_multiplier = attention_multiplier
|
| 305 |
+
self.attention_pattern = attention_pattern
|
| 306 |
+
self.attn_dropout = attn_dropout
|
| 307 |
+
self.attn_res_scale = attn_res_scale
|
| 308 |
+
self.blocks = blocks
|
| 309 |
+
self.conv_res_scale = conv_res_scale
|
| 310 |
+
self.num_layers = num_layers
|
| 311 |
+
self.emb_dropout = emb_dropout
|
| 312 |
+
self.music_vocab_size = music_vocab_size
|
| 313 |
+
if encoder_config is not None:
|
| 314 |
+
self.encoder_config = JukeboxPriorConfig(**encoder_config)
|
| 315 |
+
else:
|
| 316 |
+
self.encoder_config = None
|
| 317 |
+
self.encoder_loss_fraction = encoder_loss_fraction
|
| 318 |
+
self.init_scale = init_scale
|
| 319 |
+
self.is_encoder_decoder = is_encoder_decoder
|
| 320 |
+
self.lyric_vocab_size = lyric_vocab_size
|
| 321 |
+
self.level = level
|
| 322 |
+
self.mask = mask
|
| 323 |
+
self.max_duration = max_duration
|
| 324 |
+
self.max_nb_genres = max_nb_genres
|
| 325 |
+
self.merged_decoder = merged_decoder
|
| 326 |
+
self.metadata_conditioning = metadata_conditioning
|
| 327 |
+
self.metadata_dims = metadata_dims
|
| 328 |
+
self.min_duration = min_duration
|
| 329 |
+
self.mlp_multiplier = mlp_multiplier
|
| 330 |
+
self.n_ctx = n_ctx
|
| 331 |
+
self.n_heads = n_heads
|
| 332 |
+
self.nb_relevant_lyric_tokens = nb_relevant_lyric_tokens
|
| 333 |
+
self.res_conv_depth = res_conv_depth
|
| 334 |
+
self.res_conv_width = res_conv_width
|
| 335 |
+
self.res_convolution_multiplier = res_convolution_multiplier
|
| 336 |
+
self.res_dilation_cycle = res_dilation_cycle
|
| 337 |
+
self.res_dilation_growth_rate = res_dilation_growth_rate
|
| 338 |
+
self.res_downs_t = res_downs_t
|
| 339 |
+
self.res_strides_t = res_strides_t
|
| 340 |
+
self.resid_dropout = resid_dropout
|
| 341 |
+
self.sampling_rate = sampling_rate
|
| 342 |
+
self.spread = spread
|
| 343 |
+
self.timing_dims = timing_dims
|
| 344 |
+
self.hidden_size = hidden_size
|
| 345 |
+
self.zero_out = zero_out
|
| 346 |
+
|
| 347 |
+
@classmethod
|
| 348 |
+
def from_pretrained(
|
| 349 |
+
cls, pretrained_model_name_or_path: Union[str, os.PathLike], level=0, **kwargs
|
| 350 |
+
) -> "PretrainedConfig":
|
| 351 |
+
cls._set_token_in_kwargs(kwargs)
|
| 352 |
+
|
| 353 |
+
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
| 354 |
+
|
| 355 |
+
# get the prior config dict if we are loading from JukeboxConfig
|
| 356 |
+
if config_dict.get("model_type") == "jukebox":
|
| 357 |
+
config_dict = config_dict[f"prior_{level}"]
|
| 358 |
+
|
| 359 |
+
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
| 360 |
+
logger.warning(
|
| 361 |
+
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
| 362 |
+
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
return cls.from_dict(config_dict, **kwargs)
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
class JukeboxVQVAEConfig(PretrainedConfig):
    """
    Configuration holder for a [`JukeboxVQVAE`]. The arguments below define the architecture of the model that is
    instantiated from this configuration. The default values give a configuration similar to the VQVAE of the
    [openai/jukebox-1b-lyrics](https://huggingface.co/openai/jukebox-1b-lyrics) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        act_fn (`str`, *optional*, defaults to `"relu"`):
            Activation function of the model.
        nb_discrete_codes (`int`, *optional*, defaults to 2048):
            Number of codes of the VQVAE.
        commit (`float`, *optional*, defaults to 0.02):
            Commit loss multiplier.
        conv_input_shape (`int`, *optional*, defaults to 1):
            Number of audio channels.
        conv_res_scale (`bool`, *optional*, defaults to `False`):
            Whether or not to scale the residuals of the `JukeboxResConv1DBlock`.
        embed_dim (`int`, *optional*, defaults to 64):
            Embedding dimension of the codebook vectors.
        hop_fraction (`List[float]`, *optional*, defaults to `[0.125, 0.5, 0.5]`):
            Fraction of non-intersecting window used when continuing the sampling process.
        levels (`int`, *optional*, defaults to 3):
            Number of hierarchical levels used in the VQVAE.
        lmu (`float`, *optional*, defaults to 0.99):
            Exponential moving average coefficient used in the codebook update. For more detail refer to Appendix
            A.1 of the original [VQVAE paper](https://arxiv.org/pdf/1711.00937v2.pdf)
        multipliers (`List[int]`, *optional*, defaults to `[2, 1, 1]`):
            Depth and width multipliers used for each level. Used on the `res_conv_width` and `res_conv_depth`
        res_conv_depth (`int`, *optional*, defaults to 4):
            Depth of the encoder and decoder block. If no `multipliers` are used, this is the same for each level.
        res_conv_width (`int`, *optional*, defaults to 32):
            Width of the encoder and decoder block. If no `multipliers` are used, this is the same for each level.
        res_convolution_multiplier (`int`, *optional*, defaults to 1):
            Scaling factor of the hidden dimension used in the `JukeboxResConv1DBlock`.
        res_dilation_cycle (`int`, *optional*):
            Dilation cycle value used in the `JukeboxResnet`. If an int is used, each new Conv1 block will have a
            depth reduced by a power of `res_dilation_cycle`.
        res_dilation_growth_rate (`int`, *optional*, defaults to 3):
            Resnet dilation growth rate used in the VQVAE (dilation_growth_rate ** depth)
        res_downs_t (`List[int]`, *optional*, defaults to `[3, 2, 2]`):
            Downsampling rate for each level of the hierarchical VQ-VAE.
        res_strides_t (`List[int]`, *optional*, defaults to `[2, 2, 2]`):
            Stride used for each level of the hierarchical VQ-VAE.
        sample_length (`int`, *optional*, defaults to 1058304):
            Provides the max input shape of the VQVAE. Is used to compute the input shape of each level.
        init_scale (`float`, *optional*, defaults to 0.2):
            Initialization scale.
        zero_out (`bool`, *optional*, defaults to `False`):
            Whether or not to zero out convolution weights when initializing.
    """

    model_type = "jukebox_vqvae"

    def __init__(
        self,
        act_fn="relu",
        nb_discrete_codes=2048,
        commit=0.02,
        conv_input_shape=1,
        conv_res_scale=False,
        embed_dim=64,
        hop_fraction=[0.125, 0.5, 0.5],
        levels=3,
        lmu=0.99,
        multipliers=[2, 1, 1],
        res_conv_depth=4,
        res_conv_width=32,
        res_convolution_multiplier=1,
        res_dilation_cycle=None,
        res_dilation_growth_rate=3,
        res_downs_t=[3, 2, 2],
        res_strides_t=[2, 2, 2],
        sample_length=1058304,
        init_scale=0.2,
        zero_out=False,
        **kwargs,
    ):
        # Every argument is stored verbatim under its own name, in signature order.
        # Extra keyword arguments are accepted but not used here.
        self.act_fn = act_fn
        self.nb_discrete_codes = nb_discrete_codes
        self.commit = commit
        self.conv_input_shape = conv_input_shape
        self.conv_res_scale = conv_res_scale
        self.embed_dim = embed_dim
        self.hop_fraction = hop_fraction
        self.levels = levels
        self.lmu = lmu
        self.multipliers = multipliers
        self.res_conv_depth = res_conv_depth
        self.res_conv_width = res_conv_width
        self.res_convolution_multiplier = res_convolution_multiplier
        self.res_dilation_cycle = res_dilation_cycle
        self.res_dilation_growth_rate = res_dilation_growth_rate
        self.res_downs_t = res_downs_t
        self.res_strides_t = res_strides_t
        self.sample_length = sample_length
        self.init_scale = init_scale
        self.zero_out = zero_out

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        """
        Load a VQ-VAE configuration from `pretrained_model_name_or_path`.

        When the loaded configuration is a full `"jukebox"` configuration, only its `"vqvae_config"` entry is used.
        A warning is logged if the resulting dictionary declares a `model_type` different from this class's.
        """
        cls._set_token_in_kwargs(kwargs)

        loaded, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # A composite JukeboxConfig nests the VQ-VAE settings under "vqvae_config".
        if loaded.get("model_type") == "jukebox":
            loaded = loaded["vqvae_config"]

        declares_type = "model_type" in loaded and hasattr(cls, "model_type")
        if declares_type and loaded["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {loaded['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(loaded, **kwargs)
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
class JukeboxConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`JukeboxModel`].

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information. Instantiating a configuration with the defaults will
    yield a similar configuration to that of
    [openai/jukebox-1b-lyrics](https://huggingface.co/openai/jukebox-1b-lyrics) architecture.


    The downsampling and stride are used to determine downsampling of the input sequence. For example, downsampling =
    (5,3), and strides = (2, 2) will downsample the audio by 2^5 = 32 to get the first level of codes, and 2**8 = 256
    to get the second level codes. This is mostly true for training the top level prior and the upsamplers.

    Args:
        vqvae_config (`JukeboxVQVAEConfig`, *optional*):
            Configuration for the `JukeboxVQVAE` model.
        prior_config_list (`List[JukeboxPriorConfig]`, *optional*):
            List of the configs for each of the `JukeboxPrior` of the model. The original architecture uses 3 priors.
        nb_priors (`int`, *optional*, defaults to 3):
            Number of prior models that will sequentially sample tokens. Each prior is conditional auto regressive
            (decoder) model, apart from the top prior, which can include a lyric encoder. The available models were
            trained using a top prior and 2 upsampler priors.
        sampling_rate (`int`, *optional*, defaults to 44100):
            Sampling rate of the raw audio.
        timing_dims (`int`, *optional*, defaults to 64):
            Dimensions of the JukeboxRangeEmbedding layer which is equivalent to traditional positional embedding
            layer. The timing embedding layer converts the absolute and relative position in the currently sampled
            audio to a tensor of length `timing_dims` that will be added to the music tokens.
        min_duration (`int`, *optional*, defaults to 0):
            Minimum duration of the audios to generate
        max_duration (`float`, *optional*, defaults to 600.0):
            Maximum duration of the audios to generate
        max_nb_genres (`int`, *optional*, defaults to 5):
            Maximum number of genres that can be used to condition a single sample.
        metadata_conditioning (`bool`, *optional*, defaults to `True`):
            Whether or not to use metadata conditioning, corresponding to the artist, the genre and the min/maximum
            duration.

    Example:

    ```python
    >>> from transformers import JukeboxModel, JukeboxConfig

    >>> # Initializing a Jukebox configuration
    >>> configuration = JukeboxConfig()

    >>> # Initializing a model from the configuration
    >>> model = JukeboxModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "jukebox"

    def __init__(
        self,
        vqvae_config=None,
        prior_config_list=None,
        nb_priors=3,
        sampling_rate=44100,
        timing_dims=64,
        min_duration=0,
        max_duration=600.0,
        max_nb_genres=5,
        metadata_conditioning=True,
        **kwargs,
    ):
        # `vqvae_config` and each prior config are expected as keyword dictionaries;
        # they are expanded into the corresponding sub-configuration classes.
        if vqvae_config is None:
            vqvae_config = {}
            logger.info("vqvae_config is None. initializing the JukeboxVQVAE with default values.")

        self.vqvae_config = JukeboxVQVAEConfig(**vqvae_config)
        if prior_config_list is not None:
            self.prior_configs = [JukeboxPriorConfig(**prior_config) for prior_config in prior_config_list]
        else:
            # Fall back to per-prior dictionaries passed as `prior_0`, `prior_1`, ... keyword arguments.
            self.prior_configs = []
            for prior_idx in range(nb_priors):
                prior_config = kwargs.pop(f"prior_{prior_idx}", None)
                if prior_config is None:
                    prior_config = {}
                    logger.info(
                        f"prior_{prior_idx}'s config is None. Initializing the JukeboxPriorConfig list with default"
                        " values."
                    )
                self.prior_configs.append(JukeboxPriorConfig(**prior_config))

        self.hop_fraction = self.vqvae_config.hop_fraction

        self.nb_priors = nb_priors

        # Metadata conditioning
        self.max_nb_genres = max_nb_genres
        self.sampling_rate = sampling_rate
        self.timing_dims = timing_dims
        self.min_duration = min_duration
        self.max_duration = max_duration
        self.metadata_conditioning = metadata_conditioning

        super().__init__(**kwargs)

    @classmethod
    def from_configs(cls, prior_configs: List[JukeboxPriorConfig], vqvae_config: JukeboxVQVAEConfig, **kwargs):
        r"""
        Instantiate a [`JukeboxConfig`] (or a derived class) from a list of prior configurations and a VQ-VAE
        configuration.

        Args:
            prior_configs (`List[JukeboxPriorConfig]`):
                Configurations of the priors; each one is converted with `to_dict()`.
            vqvae_config (`JukeboxVQVAEConfig`):
                Configuration of the VQ-VAE; converted with `to_dict()`.
            kwargs:
                Additional keyword arguments forwarded to the constructor.

        Returns:
            [`JukeboxConfig`]: An instance of a configuration object
        """
        prior_config_list = [config.to_dict() for config in prior_configs]
        # The constructor parameter is named `vqvae_config`. Passing it under any other name would land in
        # `**kwargs` and the provided VQ-VAE settings would be replaced by defaults.
        return cls(prior_config_list=prior_config_list, vqvae_config=vqvae_config.to_dict(), **kwargs)

    def to_dict(self):
        """
        Serialize this configuration to a dictionary.

        Overrides the default `to_dict` so that the list of prior configs is itself converted to dictionaries and
        stored under `"prior_config_list"`, the name the constructor accepts.
        """
        result = super().to_dict()
        result["prior_config_list"] = [config.to_dict() for config in result.pop("prior_configs")]
        return result
|
| 611 |
+
|
| 612 |
+
|
| 613 |
+
__all__ = ["JukeboxConfig", "JukeboxPriorConfig", "JukeboxVQVAEConfig"]
|
docs/transformers/build/lib/transformers/models/deprecated/jukebox/convert_jukebox.py
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Convert Jukebox checkpoints"""
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
|
| 22 |
+
import requests
|
| 23 |
+
import torch
|
| 24 |
+
|
| 25 |
+
from transformers import JukeboxConfig, JukeboxModel
|
| 26 |
+
from transformers.utils import logging
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
logging.set_verbosity_info()
logger = logging.get_logger(__name__)


# Base URL under which the original OpenAI checkpoint files are published.
PREFIX = "https://openaipublic.azureedge.net/jukebox/models/"

# Checkpoint files common to every supported model: the VQ-VAE and the two upsampler priors.
_SHARED_CHECKPOINTS = (
    "5b/vqvae.pth.tar",
    "5b/prior_level_0.pth.tar",
    "5b/prior_level_1.pth.tar",
)

# Model name -> checkpoint files relative to PREFIX, VQ-VAE first and priors after it.
MODEL_MAPPING = {
    "jukebox-1b-lyrics": [*_SHARED_CHECKPOINTS, "1b_lyrics/prior_level_2.pth.tar"],
    "jukebox-5b-lyrics": [*_SHARED_CHECKPOINTS, "5b_lyrics/prior_level_2.pth.tar"],
}
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def replace_key(key):
    """
    Translate a parameter name from the original Jukebox checkpoint naming to the naming used by this port.

    The rules are applied in order. The first group rewrites `key` cumulatively; the second group returns as soon as
    one rule applies.

    Args:
        key (`str`): Parameter name, possibly already partially renamed.

    Returns:
        `str`: The renamed key, or `key` unchanged when no rule applies.
    """
    # Resnet convolutions are only renamed on deeply nested paths (more than 10 dotted components).
    if len(key.split(".")) > 10:
        for old_suffix, new_suffix in (
            (".model.1.bias", ".conv1d_1.bias"),
            (".model.1.weight", ".conv1d_1.weight"),
            (".model.3.bias", ".conv1d_2.bias"),
            (".model.3.weight", ".conv1d_2.weight"),
        ):
            if key.endswith(old_suffix):
                key = key.replace(old_suffix, new_suffix)
                break

    if "conditioner_blocks.0." in key:
        key = key.replace("conditioner_blocks.0", "conditioner_blocks")

    if "prime_prior" in key:
        key = key.replace("prime_prior", "encoder")

    if ".emb." in key and "total" not in key and "absolute" not in key and "relative" not in key:
        key = key.replace(".emb.", ".")

    # Replace vqvae.X.k with vqvae.X.codebook. Only a trailing ".k" component is a codebook: matching any key that
    # merely ends in "k" and substituting every ".k" would corrupt unrelated names such as "a.key_mask".
    if key.endswith(".k"):
        return key[: -len(".k")] + ".codebook"
    if "y_emb." in key:
        return key.replace("y_emb.", "metadata_embedding.")

    if "x_emb.emb." in key:
        key = key.replace("0.x_emb.emb", "embed_tokens")

    if "prime_state_ln" in key:
        return key.replace("prime_state_ln", "encoder.final_layer_norm")
    if ".ln" in key:
        return key.replace(".ln", ".layer_norm")
    if "_ln" in key:
        return key.replace("_ln", "_layer_norm")

    if "prime_state_proj" in key:
        return key.replace("prime_state_proj", "encoder.proj_in")
    if "prime_x_out" in key:
        return key.replace("prime_x_out", "encoder.lm_head")
    if "prior.x_out" in key:
        return key.replace("x_out", "fc_proj_out")
    if "x_emb" in key:
        return key.replace("x_emb", "embed_tokens")

    return key
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def fix_jukebox_keys(state_dict, model_state_dict, key_prefix, mapping):
    """
    Rename the keys of an original Jukebox checkpoint so they line up with the target model's state dict.

    Each key is first rewritten by the structural patterns below (VQ-VAE encoder/decoder blocks and prior conditioner
    blocks) and then passed through `replace_key`. A message is printed when the converted key is absent from
    `model_state_dict`, or when the tensor shapes differ (in which case the original key is kept).

    Args:
        state_dict (`dict`): Original checkpoint, mapping parameter names to tensors.
        model_state_dict (`dict`): State dict of the target model, used to validate names and shapes.
        key_prefix (`str`): Prefix of this sub-module inside `model_state_dict` (joined to the key with a dot).
        mapping (`dict`): Updated in place with `converted_key -> original_key` for every entry.

    Returns:
        `dict`: A new dictionary holding the same values under the converted keys.
    """
    new_dict = {}
    import re  # kept local, as in the original helper

    # Dots are escaped and indices need at least one digit: an unescaped "." matches any character and an empty
    # index would crash the `int()` conversions below.
    re_encoder_block_conv_in = re.compile(r"encoders\.(\d+)\.level_blocks\.(\d+)\.model\.(\d+)\.(\d)\.(bias|weight)")
    re_encoder_block_resnet = re.compile(
        r"encoders\.(\d+)\.level_blocks\.(\d+)\.model\.(\d+)\.(\d)\.model\.(\d+)\.model\.(\d+)\.(bias|weight)"
    )
    re_encoder_block_proj_out = re.compile(r"encoders\.(\d+)\.level_blocks\.(\d+)\.model\.(\d+)\.(bias|weight)")

    re_decoder_block_conv_out = re.compile(r"decoders\.(\d+)\.level_blocks\.(\d+)\.model\.(\d+)\.(\d)\.(bias|weight)")
    re_decoder_block_resnet = re.compile(
        r"decoders\.(\d+)\.level_blocks\.(\d+)\.model\.(\d+)\.(\d)\.model\.(\d+)\.model\.(\d+)\.(bias|weight)"
    )
    re_decoder_block_proj_in = re.compile(r"decoders\.(\d+)\.level_blocks\.(\d+)\.model\.(\d+)\.(bias|weight)")

    re_prior_cond_conv_out = re.compile(r"conditioner_blocks\.(\d+)\.cond\.model\.(\d+)\.(\d)\.(bias|weight)")
    re_prior_cond_resnet = re.compile(
        r"conditioner_blocks\.(\d+)\.cond\.model\.(\d+)\.(\d)\.model\.(\d+)\.model\.(\d+)\.(bias|weight)"
    )
    re_prior_cond_proj_in = re.compile(r"conditioner_blocks\.(\d+)\.cond\.model\.(\d+)\.(bias|weight)")

    # Position of a convolution inside the original resnet `model` list -> index in `conv1d_{index}`.
    conv_position_to_index = {"1": 1, "3": 2}

    def rename_structural(original_key):
        """Apply the first fully matching structural pattern; return the key unchanged when none matches."""
        # rename vqvae.encoder keys
        match = re_encoder_block_conv_in.fullmatch(original_key)
        if match:
            encoder, level, outer, inner, param = match.groups()
            block_index = int(outer) * 2 + int(inner)
            return f"encoders.{encoder}.level_blocks.{level}.downsample_block.{block_index}.{param}"

        match = re_encoder_block_resnet.fullmatch(original_key)
        if match:
            encoder, level, outer, inner, resnet_index, conv_position, param = match.groups()
            block_index = int(outer) * 2 + int(inner)
            conv_index = conv_position_to_index[conv_position]
            return (
                f"encoders.{encoder}.level_blocks.{level}.downsample_block.{block_index}."
                f"resnet_block.{resnet_index}.conv1d_{conv_index}.{param}"
            )

        match = re_encoder_block_proj_out.fullmatch(original_key)
        if match:
            groups = match.groups()
            return f"encoders.{groups[0]}.level_blocks.{groups[1]}.proj_out.{groups[-1]}"

        # rename vqvae.decoder keys (block indices are shifted by 2 relative to the encoder)
        match = re_decoder_block_conv_out.fullmatch(original_key)
        if match:
            decoder, level, outer, inner, param = match.groups()
            block_index = int(outer) * 2 + int(inner) - 2
            return f"decoders.{decoder}.level_blocks.{level}.upsample_block.{block_index}.{param}"

        match = re_decoder_block_resnet.fullmatch(original_key)
        if match:
            decoder, level, outer, inner, resnet_index, conv_position, param = match.groups()
            block_index = int(outer) * 2 + int(inner) - 2
            conv_index = conv_position_to_index[conv_position]
            return (
                f"decoders.{decoder}.level_blocks.{level}.upsample_block.{block_index}."
                f"resnet_block.{resnet_index}.conv1d_{conv_index}.{param}"
            )

        match = re_decoder_block_proj_in.fullmatch(original_key)
        if match:
            groups = match.groups()
            return f"decoders.{groups[0]}.level_blocks.{groups[1]}.proj_in.{groups[-1]}"

        # rename prior cond.model to upsampler.upsample_block and resnet
        match = re_prior_cond_conv_out.fullmatch(original_key)
        if match:
            _, outer, inner, param = match.groups()
            block_index = int(outer) * 2 + int(inner) - 2
            return f"conditioner_blocks.upsampler.upsample_block.{block_index}.{param}"

        match = re_prior_cond_resnet.fullmatch(original_key)
        if match:
            _, outer, inner, resnet_index, conv_position, param = match.groups()
            block_index = int(outer) * 2 + int(inner) - 2
            conv_index = conv_position_to_index[conv_position]
            return (
                f"conditioner_blocks.upsampler.upsample_block.{block_index}."
                f"resnet_block.{resnet_index}.conv1d_{conv_index}.{param}"
            )

        match = re_prior_cond_proj_in.fullmatch(original_key)
        if match:
            return f"conditioner_blocks.upsampler.proj_in.{match.groups()[-1]}"

        # keep original key
        return original_key

    for original_key, value in state_dict.items():
        key = replace_key(rename_structural(original_key))
        target_key = f"{key_prefix}.{key}"

        if target_key not in model_state_dict:
            print(f"failed converting {original_key} to {key}, does not match")

        # handle mismatched shape: fall back to the original name for this entry
        elif value.shape != model_state_dict[target_key].shape:
            val = model_state_dict[target_key]
            print(f"{original_key}-> {key} : \nshape {val.shape} and {value.shape}, do not match")
            key = original_key

        mapping[key] = original_key
        new_dict[key] = value

    return new_dict
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
@torch.no_grad()
def convert_openai_checkpoint(model_name=None, pytorch_dump_folder_path=None):
    """
    Copy/paste/tweak model's weights to our Jukebox structure.

    Downloads the checkpoint files listed in `MODEL_MAPPING` for `model_name` into `pytorch_dump_folder_path` (files
    already present are reused), renames their parameters, loads them into a `JukeboxModel` built from
    `JukeboxConfig.from_pretrained(model_name)`, and saves the model together with a `mapping.json` file recording
    `converted_key -> original_key`.

    Args:
        model_name (`str`): Model identifier; only the part after the last "/" is used to look up `MODEL_MAPPING`.
        pytorch_dump_folder_path (`str`): Directory used both for the downloaded files and for the converted model.

    Returns:
        `list`: The converted state dicts of the priors, in checkpoint-file order.

    Raises:
        KeyError: If the model name is not a key of `MODEL_MAPPING`.
        requests.HTTPError: If a checkpoint download returns an error status.
    """
    # Resolve the file list once, from the short name, so "org/name" identifiers work for the download step too.
    model_to_convert = MODEL_MAPPING[model_name.split("/")[-1]]

    os.makedirs(f"{pytorch_dump_folder_path}/", exist_ok=True)
    for file in model_to_convert:
        local_path = f"{pytorch_dump_folder_path}/{file.split('/')[-1]}"
        if not os.path.isfile(local_path):
            r = requests.get(f"{PREFIX}{file}", allow_redirects=True)
            # Fail here rather than saving an error page that would later be fed to `torch.load`.
            r.raise_for_status()
            with open(local_path, "wb") as checkpoint_file:
                checkpoint_file.write(r.content)

    config = JukeboxConfig.from_pretrained(model_name)
    model = JukeboxModel(config)

    # The first file is the VQ-VAE; every remaining file is one prior.
    nb_priors = len(model_to_convert) - 1

    weight_dict = []
    mapping = {}
    for i, dict_name in enumerate(model_to_convert):
        old_dic = torch.load(f"{pytorch_dump_folder_path}/{dict_name.split('/')[-1]}", weights_only=True)["model"]

        new_dic = {}
        for k, tensor in old_dic.items():
            # Only the trailing component is renamed; replacing every "b"/"w" would mangle the rest of the name.
            if k.endswith(".b"):
                new_dic[k[: -len(".b")] + ".bias"] = tensor
            elif k.endswith(".w"):
                new_dic[k[: -len(".w")] + ".weight"] = tensor
            elif "level_2" not in dict_name and "cond.model." in k:
                new_dic[k.replace(".blocks.", ".model.")] = tensor
            else:
                new_dic[k] = tensor

        # Checkpoint files list the priors in the reverse of `model.priors` order.
        key_prefix = "vqvae" if i == 0 else f"priors.{nb_priors - i}"
        new_dic = fix_jukebox_keys(new_dic, model.state_dict(), key_prefix, mapping)
        weight_dict.append(new_dic)

    vqvae_state_dict = weight_dict.pop(0)
    model.vqvae.load_state_dict(vqvae_state_dict)
    for i in range(len(weight_dict)):
        model.priors[i].load_state_dict(weight_dict[len(weight_dict) - 1 - i])

    with open(f"{pytorch_dump_folder_path}/mapping.json", "w") as txtfile:
        json.dump(mapping, txtfile)

    print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)

    return weight_dict
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
if __name__ == "__main__":
    # Command-line entry point: both options are strings with the defaults below.
    cli = argparse.ArgumentParser()
    for flag, default_value, description in (
        ("--model_name", "jukebox-5b-lyrics", "Name of the model you'd like to convert."),
        (
            "--pytorch_dump_folder_path",
            "jukebox-5b-lyrics-converted",
            "Path to the output PyTorch model directory.",
        ),
    ):
        cli.add_argument(flag, default=default_value, type=str, help=description)

    options = cli.parse_args()
    convert_openai_checkpoint(options.model_name, options.pytorch_dump_folder_path)
|
docs/transformers/build/lib/transformers/models/deprecated/jukebox/modeling_jukebox.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
docs/transformers/build/lib/transformers/models/deprecated/jukebox/tokenization_jukebox.py
ADDED
|
@@ -0,0 +1,407 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 The Open AI Team Authors and The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Tokenization classes for OpenAI Jukebox."""
|
| 16 |
+
|
| 17 |
+
import json
|
| 18 |
+
import os
|
| 19 |
+
import re
|
| 20 |
+
import unicodedata
|
| 21 |
+
from json.encoder import INFINITY
|
| 22 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
import regex
|
| 26 |
+
|
| 27 |
+
from ....tokenization_utils import AddedToken, PreTrainedTokenizer
|
| 28 |
+
from ....tokenization_utils_base import BatchEncoding
|
| 29 |
+
from ....utils import TensorType, is_flax_available, is_tf_available, is_torch_available, logging
|
| 30 |
+
from ....utils.generic import _is_jax, _is_numpy
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
logger = logging.get_logger(__name__)
|
| 34 |
+
|
| 35 |
+
# File name used for each of the three vocabularies, keyed by the name of the
# matching `JukeboxTokenizer.__init__` argument.
VOCAB_FILES_NAMES = dict(
    artists_file="artists.json",
    lyrics_file="lyrics.json",
    genres_file="genres.json",
)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class JukeboxTokenizer(PreTrainedTokenizer):
|
| 43 |
+
"""
|
| 44 |
+
Constructs a Jukebox tokenizer. Jukebox can be conditioned on 3 different inputs :
|
| 45 |
+
- Artists, unique ids are associated to each artist from the provided dictionary.
|
| 46 |
+
- Genres, unique ids are associated to each genre from the provided dictionary.
|
| 47 |
+
- Lyrics, character based tokenization. Must be initialized with the list of characters that are inside the
|
| 48 |
+
vocabulary.
|
| 49 |
+
|
| 50 |
+
This tokenizer does not require training. It should be able to process a different number of inputs:
|
| 51 |
+
as the conditioning of the model can be done on the three different queries. If None is provided, default values will be used.
|
| 52 |
+
|
| 53 |
+
Depending on the number of genres on which the model should be conditioned (`n_genres`).
|
| 54 |
+
```python
|
| 55 |
+
>>> from transformers import JukeboxTokenizer
|
| 56 |
+
|
| 57 |
+
>>> tokenizer = JukeboxTokenizer.from_pretrained("openai/jukebox-1b-lyrics")
|
| 58 |
+
>>> tokenizer("Alan Jackson", "Country Rock", "old town road")["input_ids"]
|
| 59 |
+
[tensor([[ 0, 0, 0, 6785, 546, 41, 38, 30, 76, 46, 41, 49,
|
| 60 |
+
40, 76, 44, 41, 27, 30]]), tensor([[ 0, 0, 0, 145, 0]]), tensor([[ 0, 0, 0, 145, 0]])]
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
|
| 64 |
+
call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.
|
| 65 |
+
|
| 66 |
+
<Tip>
|
| 67 |
+
|
| 68 |
+
If nothing is provided, the genres and the artist will either be selected randomly or set to None
|
| 69 |
+
|
| 70 |
+
</Tip>
|
| 71 |
+
|
| 72 |
+
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to:
|
| 73 |
+
this superclass for more information regarding those methods.
|
| 74 |
+
|
| 75 |
+
However the code does not allow that and only supports composing from various genres.
|
| 76 |
+
|
| 77 |
+
Args:
|
| 78 |
+
artists_file (`str`):
|
| 79 |
+
Path to the vocabulary file which contains a mapping between artists and ids. The default file supports
|
| 80 |
+
both "v2" and "v3"
|
| 81 |
+
genres_file (`str`):
|
| 82 |
+
Path to the vocabulary file which contains a mapping between genres and ids.
|
| 83 |
+
lyrics_file (`str`):
|
| 84 |
+
Path to the vocabulary file which contains the accepted characters for the lyrics tokenization.
|
| 85 |
+
version (`List[str]`, `optional`, defaults to `["v3", "v2", "v2"]`):
|
| 86 |
+
List of the tokenizer versions. The `5b-lyrics`'s top level prior model was trained using `v3` instead of
|
| 87 |
+
`v2`.
|
| 88 |
+
n_genres (`int`, `optional`, defaults to 5):
|
| 89 |
+
Maximum number of genres to use for composition.
|
| 90 |
+
max_n_lyric_tokens (`int`, `optional`, defaults to 512):
|
| 91 |
+
Maximum number of lyric tokens to keep.
|
| 92 |
+
unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
| 93 |
+
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
| 94 |
+
token instead.
|
| 95 |
+
"""
|
| 96 |
+
|
| 97 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
| 98 |
+
model_input_names = ["input_ids", "attention_mask"]
|
| 99 |
+
|
| 100 |
+
def __init__(
    self,
    artists_file,
    genres_file,
    lyrics_file,
    version=["v3", "v2", "v2"],
    max_n_lyric_tokens=512,
    n_genres=5,
    unk_token="<|endoftext|>",
    **kwargs,
):
    """
    Load the three vocabularies and set up the lyrics filter.

    Args:
        artists_file (`str`):
            Path to the JSON file mapping artists to ids.
        genres_file (`str`):
            Path to the JSON file mapping genres to ids.
        lyrics_file (`str`):
            Path to the JSON file mapping lyric characters to ids.
        version (`List[str]`, *optional*, defaults to `["v3", "v2", "v2"]`):
            Tokenizer version for each entry; a private copy of the list is stored on the instance.
        max_n_lyric_tokens (`int`, *optional*, defaults to 512):
            Stored on the instance and forwarded to the parent constructor.
        n_genres (`int`, *optional*, defaults to 5):
            Length to which each list of genre ids is padded.
        unk_token (`str` or `AddedToken`, *optional*, defaults to `"<|endoftext|>"`):
            The unknown token; a plain string is wrapped in an `AddedToken`.
        **kwargs:
            Forwarded unchanged to the parent constructor.
    """
    unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
    # The default `version` list is one object shared by every call that omits the argument. Keep a private
    # copy so that changing one tokenizer's `version` can never leak into the default or the caller's list.
    version = list(version)
    self.version = version
    self.max_n_lyric_tokens = max_n_lyric_tokens
    self.n_genres = n_genres
    self._added_tokens_decoder = {0: unk_token}

    with open(artists_file, encoding="utf-8") as vocab_handle:
        self.artists_encoder = json.load(vocab_handle)

    with open(genres_file, encoding="utf-8") as vocab_handle:
        self.genres_encoder = json.load(vocab_handle)

    with open(lyrics_file, encoding="utf-8") as vocab_handle:
        self.lyrics_encoder = json.load(vocab_handle)

    oov = r"[^A-Za-z0-9.,:;!?\-'\"()\[\] \t\n]+"
    # In v2, we had a n_vocab=80 and in v3 we missed + and so n_vocab=79 of characters.
    if len(self.lyrics_encoder) == 79:
        oov = oov.replace(r"\-'", r"\-+'")

    # Pattern matching runs of characters that fall outside the accepted lyric characters.
    self.out_of_vocab = regex.compile(oov)
    # Reverse (id -> name) lookups for the three vocabularies.
    self.artists_decoder = {v: k for k, v in self.artists_encoder.items()}
    self.genres_decoder = {v: k for k, v in self.genres_encoder.items()}
    self.lyrics_decoder = {v: k for k, v in self.lyrics_encoder.items()}
    super().__init__(
        unk_token=unk_token,
        n_genres=n_genres,
        version=version,
        max_n_lyric_tokens=max_n_lyric_tokens,
        **kwargs,
    )
|
| 142 |
+
|
| 143 |
+
@property
def vocab_size(self):
    """Total number of entries across the artists, genres and lyrics vocabularies."""
    encoders = (self.artists_encoder, self.genres_encoder, self.lyrics_encoder)
    return sum(len(encoder) for encoder in encoders)
|
| 146 |
+
|
| 147 |
+
def get_vocab(self):
    """Return the artists, genres and lyrics encoders in a single dictionary, one key per encoder."""
    vocab = {}
    vocab["artists_encoder"] = self.artists_encoder
    vocab["genres_encoder"] = self.genres_encoder
    vocab["lyrics_encoder"] = self.lyrics_encoder
    return vocab
|
| 153 |
+
|
| 154 |
+
def _convert_token_to_id(self, list_artists, list_genres, list_lyrics):
|
| 155 |
+
"""Converts the artist, genre and lyrics tokens to their index using the vocabulary.
|
| 156 |
+
The total_length, offset and duration have to be provided in order to select relevant lyrics and add padding to
|
| 157 |
+
the lyrics token sequence.
|
| 158 |
+
"""
|
| 159 |
+
artists_id = [self.artists_encoder.get(artist, 0) for artist in list_artists]
|
| 160 |
+
for genres in range(len(list_genres)):
|
| 161 |
+
list_genres[genres] = [self.genres_encoder.get(genre, 0) for genre in list_genres[genres]]
|
| 162 |
+
list_genres[genres] = list_genres[genres] + [-1] * (self.n_genres - len(list_genres[genres]))
|
| 163 |
+
|
| 164 |
+
lyric_ids = [[self.lyrics_encoder.get(character, 0) for character in list_lyrics[0]], [], []]
|
| 165 |
+
return artists_id, list_genres, lyric_ids
|
| 166 |
+
|
| 167 |
+
def _tokenize(self, lyrics):
|
| 168 |
+
"""
|
| 169 |
+
Converts a string into a sequence of tokens (string), using the tokenizer. Split in words for word-based
|
| 170 |
+
vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).
|
| 171 |
+
|
| 172 |
+
Do NOT take care of added tokens. Only the lyrics are split into character for the character-based vocabulary.
|
| 173 |
+
"""
|
| 174 |
+
# only lyrics are not tokenized, but character based is easily handled
|
| 175 |
+
return list(lyrics)
|
| 176 |
+
|
| 177 |
+
def tokenize(self, artist, genre, lyrics, **kwargs):
    """
    Prepare the artist, genre and lyrics inputs with `prepare_for_tokenization`, then tokenize the lyrics.

    Returns the prepared artist, the prepared genre and the tokenized lyrics. Extra keyword arguments are
    accepted but not used.
    """
    prepared = self.prepare_for_tokenization(artist, genre, lyrics)
    artist, genre, lyrics = prepared
    return artist, genre, self._tokenize(lyrics)
|
| 184 |
+
|
| 185 |
+
def prepare_for_tokenization(
    self, artists: List[str], genres: List[str], lyrics: str, is_split_into_words: bool = False
) -> Tuple[List[str], List[List[str]], Tuple[str, list, list]]:
    """
    Performs any necessary transformations before tokenization.

    `artists` and `genres` are indexed once per entry of `self.version` and are modified in place.

    Args:
        artists (`List[str]`):
            One artist name per tokenizer version. A "v3" entry is lower-cased; any other entry is normalized
            with `_normalize` and suffixed with ".v2".
        genres (`List[str]`):
            One genre string per tokenizer version. A "v3" entry becomes a one-element list holding the
            lower-cased string; any other entry is split on "_" and each part is normalized and suffixed with
            ".v2".
        lyrics (`str`):
            The lyrics to prepare. Accents are stripped, backslashes are replaced by newlines and every
            character matched by the out-of-vocabulary pattern is removed.
        is_split_into_words (`bool`, *optional*, defaults to `False`):
            Accepted but not used by this implementation.

    Returns:
        The `artists` list, the `genres` list (now holding one list of genre names per version) and a tuple
        made of the cleaned lyrics followed by two empty lists.
    """
    for idx, version in enumerate(self.version):
        if version == "v3":
            artists[idx] = artists[idx].lower()
            genres[idx] = [genres[idx].lower()]
        else:
            artists[idx] = self._normalize(artists[idx]) + ".v2"
            # split is for the full dictionary with combined genres
            genres[idx] = [self._normalize(genre) + ".v2" for genre in genres[idx].split("_")]

    if self.version[0] == "v2":
        # When the first version is "v2" the lyrics vocabulary is rebuilt from a fixed character string,
        # with id 0 reserved for unknown characters, and replaces the one loaded from file.
        self.out_of_vocab = regex.compile(r"[^A-Za-z0-9.,:;!?\-'\"()\[\] \t\n]+")
        vocab = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789.,:;!?-+'\"()[] \t\n"
        self.vocab = {char: index for index, char in enumerate(vocab, start=1)}
        self.vocab["<unk>"] = 0
        self.n_vocab = len(vocab) + 1
        self.lyrics_encoder = self.vocab
        self.lyrics_decoder = {v: k for k, v in self.vocab.items()}
        self.lyrics_decoder[0] = ""
    else:
        self.out_of_vocab = regex.compile(r"[^A-Za-z0-9.,:;!?\-+'\"()\[\] \t\n]+")

    lyrics = self._run_strip_accents(lyrics)
    lyrics = lyrics.replace("\\", "\n")
    # The lyrics come back as a 3-tuple; only its first slot carries text.
    lyrics = self.out_of_vocab.sub("", lyrics), [], []
    return artists, genres, lyrics
|
| 229 |
+
|
| 230 |
+
def _run_strip_accents(self, text):
|
| 231 |
+
"""Strips accents from a piece of text."""
|
| 232 |
+
text = unicodedata.normalize("NFD", text)
|
| 233 |
+
output = []
|
| 234 |
+
for char in text:
|
| 235 |
+
cat = unicodedata.category(char)
|
| 236 |
+
if cat == "Mn":
|
| 237 |
+
continue
|
| 238 |
+
output.append(char)
|
| 239 |
+
return "".join(output)
|
| 240 |
+
|
| 241 |
+
def _normalize(self, text: str) -> str:
|
| 242 |
+
"""
|
| 243 |
+
Normalizes the input text. This process is for the genres and the artist
|
| 244 |
+
|
| 245 |
+
Args:
|
| 246 |
+
text (`str`):
|
| 247 |
+
Artist or Genre string to normalize
|
| 248 |
+
"""
|
| 249 |
+
|
| 250 |
+
accepted = (
|
| 251 |
+
[chr(i) for i in range(ord("a"), ord("z") + 1)]
|
| 252 |
+
+ [chr(i) for i in range(ord("A"), ord("Z") + 1)]
|
| 253 |
+
+ [chr(i) for i in range(ord("0"), ord("9") + 1)]
|
| 254 |
+
+ ["."]
|
| 255 |
+
)
|
| 256 |
+
accepted = frozenset(accepted)
|
| 257 |
+
pattern = re.compile(r"_+")
|
| 258 |
+
text = "".join([c if c in accepted else "_" for c in text.lower()])
|
| 259 |
+
text = pattern.sub("_", text).strip("_")
|
| 260 |
+
return text
|
| 261 |
+
|
| 262 |
+
def convert_lyric_tokens_to_string(self, lyrics: List[str]) -> str:
    """Join the lyric tokens into one string, separated by single spaces."""
    separator = " "
    return separator.join(lyrics)
|
| 264 |
+
|
| 265 |
+
def convert_to_tensors(
    self, inputs, tensor_type: Optional[Union[str, TensorType]] = None, prepend_batch_axis: bool = False
):
    """
    Convert `inputs` to a tensor of the requested framework.

    Args:
        inputs:
            The data to convert. It is not converted again when the selected framework's own tensor check
            already accepts it.
        tensor_type (`str` or [`~utils.TensorType`], *optional*):
            The type of tensors to use. A value that is not already a [`~utils.TensorType`] is passed to the
            `TensorType` constructor.
        prepend_batch_axis (`bool`, *optional*, defaults to `False`):
            Whether or not to add the batch dimension during the conversion.

    Raises:
        ImportError: If TensorFlow, PyTorch or JAX tensors are requested but that framework is not installed.
        ValueError: If the tensor cannot be created from `inputs`.
    """
    # Convert to TensorType
    if not isinstance(tensor_type, TensorType):
        tensor_type = TensorType(tensor_type)

    # Get a function reference for the correct framework
    if tensor_type == TensorType.TENSORFLOW:
        if not is_tf_available():
            raise ImportError(
                "Unable to convert output to TensorFlow tensors format, TensorFlow is not installed."
            )
        import tensorflow as tf

        as_tensor = tf.constant
        is_tensor = tf.is_tensor
    elif tensor_type == TensorType.PYTORCH:
        if not is_torch_available():
            raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
        import torch

        as_tensor = torch.tensor
        is_tensor = torch.is_tensor
    elif tensor_type == TensorType.JAX:
        if not is_flax_available():
            raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")
        import jax.numpy as jnp  # noqa: F811

        as_tensor = jnp.array
        is_tensor = _is_jax
    else:
        as_tensor = np.asarray
        is_tensor = _is_numpy

    # Do the tensor conversion in batch
    try:
        if prepend_batch_axis:
            inputs = [inputs]

        if not is_tensor(inputs):
            inputs = as_tensor(inputs)
    except Exception as err:
        # Only ordinary errors are translated, so KeyboardInterrupt and SystemExit still propagate; the
        # original error is chained to keep the real cause visible in the traceback.
        raise ValueError(
            "Unable to create tensor, you should probably activate truncation and/or padding "
            "with 'padding=True' 'truncation=True' to have batched tensors with the same length."
        ) from err

    return inputs
|
| 325 |
+
|
| 326 |
+
def __call__(self, artist, genres, lyrics="", return_tensors="pt") -> BatchEncoding:
    """Convert the raw strings to token ids, one sequence per entry of `self.version`.

    Args:
        artist (`str`):
            Name of the artist.
        genres (`str`):
            Genres that will be mixed to condition the audio
        lyrics (`str`, *optional*, defaults to `""`):
            Lyrics used to condition the generation
        return_tensors (`str`, *optional*, defaults to `"pt"`):
            Tensor type handed to `convert_to_tensors` for each sequence of ids.

    Returns:
        A `BatchEncoding` with an `"input_ids"` list holding one converted sequence per version, and an
        `"attention_masks"` list with one `-INFINITY` entry per id of the last lyrics sequence.
    """
    n_versions = len(self.version)
    # Every sequence starts with three zeros, followed by the artist id, the genre ids and the lyric ids.
    prefix = [0, 0, 0]
    artist = [artist] * n_versions
    genres = [genres] * n_versions

    artists_tokens, genres_tokens, lyrics_tokens = self.tokenize(artist, genres, lyrics)
    artists_id, genres_ids, full_tokens = self._convert_token_to_id(artists_tokens, genres_tokens, lyrics_tokens)

    attention_masks = [-INFINITY] * len(full_tokens[-1])

    input_ids = []
    for i in range(n_versions):
        sequence = prefix + [artists_id[i]] + genres_ids[i] + full_tokens[i]
        input_ids.append(self.convert_to_tensors([sequence], tensor_type=return_tensors))

    return BatchEncoding({"input_ids": input_ids, "attention_masks": attention_masks})
|
| 352 |
+
|
| 353 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
    """
    Write the artists, genres and lyrics encoders as JSON files into `save_directory`.

    Args:
        save_directory (`str`):
            Directory that receives the files. If it is not an existing directory, an error is logged and
            nothing is written.

        filename_prefix (`Optional[str]`, *optional*):
            A prefix to add to the names of the files saved by the tokenizer; it is joined to the file name
            with "-".

    Returns:
        The paths of the artists, genres and lyrics files, in that order, or `None` when `save_directory` is
        not a directory.
    """
    if not os.path.isdir(save_directory):
        logger.error(f"Vocabulary path ({save_directory}) should be a directory")
        return

    prefix = filename_prefix + "-" if filename_prefix else ""
    # Written in this order: artists, genres, lyrics.
    encoders = (
        ("artists_file", self.artists_encoder),
        ("genres_file", self.genres_encoder),
        ("lyrics_file", self.lyrics_encoder),
    )

    written = []
    for file_key, encoder in encoders:
        path = os.path.join(save_directory, prefix + VOCAB_FILES_NAMES[file_key])
        with open(path, "w", encoding="utf-8") as handle:
            handle.write(json.dumps(encoder, ensure_ascii=False))
        written.append(path)

    return tuple(written)
|
| 388 |
+
|
| 389 |
+
def _convert_id_to_token(self, artists_index, genres_index, lyric_index):
|
| 390 |
+
"""
|
| 391 |
+
Converts an index (integer) in a token (str) using the vocab.
|
| 392 |
+
|
| 393 |
+
Args:
|
| 394 |
+
artists_index (`int`):
|
| 395 |
+
Index of the artist in its corresponding dictionary.
|
| 396 |
+
genres_index (`Union[List[int], int]`):
|
| 397 |
+
Index of the genre in its corresponding dictionary.
|
| 398 |
+
lyric_index (`List[int]`):
|
| 399 |
+
List of character indices, which each correspond to a character.
|
| 400 |
+
"""
|
| 401 |
+
artist = self.artists_decoder.get(artists_index)
|
| 402 |
+
genres = [self.genres_decoder.get(genre) for genre in genres_index]
|
| 403 |
+
lyrics = [self.lyrics_decoder.get(character) for character in lyric_index]
|
| 404 |
+
return artist, genres, lyrics
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
__all__ = ["JukeboxTokenizer"]
|
docs/transformers/build/lib/transformers/models/deprecated/mctct/__init__.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ....utils import _LazyModule
|
| 17 |
+
from ....utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if not TYPE_CHECKING:
    # At runtime, replace this package in `sys.modules` with a `_LazyModule` built from the import structure
    # that `define_import_structure` derives from this file.
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
else:
    # Static type checkers get the real submodule contents instead.
    from .configuration_mctct import *
    from .feature_extraction_mctct import *
    from .modeling_mctct import *
    from .processing_mctct import *
|
docs/transformers/build/lib/transformers/models/deprecated/mctct/configuration_mctct.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""M-CTC-T model configuration"""
|
| 16 |
+
|
| 17 |
+
from ....configuration_utils import PretrainedConfig
|
| 18 |
+
from ....utils import logging
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
logger = logging.get_logger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class MCTCTConfig(PretrainedConfig):
|
| 25 |
+
r"""
|
| 26 |
+
This is the configuration class to store the configuration of a [`MCTCTModel`]. It is used to instantiate an
|
| 27 |
+
M-CTC-T model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
| 28 |
+
with the defaults will yield a similar configuration to that of the M-CTC-T
|
| 29 |
+
[speechbrain/m-ctc-t-large](https://huggingface.co/speechbrain/m-ctc-t-large) architecture.
|
| 30 |
+
|
| 31 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 32 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
vocab_size (`int`, *optional*, defaults to 8065):
|
| 37 |
+
Vocabulary size of the M-CTC-T model. Defines the number of different tokens that can be represented by the
|
| 38 |
+
`inputs_ids` passed when calling [`MCTCTModel`].
|
| 39 |
+
hidden_size (`int`, *optional*, defaults to 1536):
|
| 40 |
+
Dimension of the encoder layers and the pooler layer.
|
| 41 |
+
num_hidden_layers (`int`, *optional*, defaults to 36):
|
| 42 |
+
Number of hidden layers in the Transformer encoder.
|
| 43 |
+
intermediate_size (`int`, *optional*, defaults to 6144):
|
| 44 |
+
Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
| 45 |
+
num_attention_heads (`int`, *optional*, defaults to 4):
|
| 46 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 47 |
+
attention_head_dim (`int`, *optional*, defaults to 384):
|
| 48 |
+
Dimensions of each attention head for each attention layer in the Transformer encoder.
|
| 49 |
+
max_position_embeddings (`int`, *optional*, defaults to 920):
|
| 50 |
+
The maximum sequence length that this model might ever be used with (after log-mel spectrogram extraction).
|
| 51 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-05):
|
| 52 |
+
The epsilon used by the layer normalization layers.
|
| 53 |
+
layerdrop (`float`, *optional*, defaults to 0.3):
|
| 54 |
+
The probability of dropping an encoder layer during training. The default 0.3 value is used in the original
|
| 55 |
+
implementation.
|
| 56 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"relu"`):
|
| 57 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
| 58 |
+
`"relu"`, `"selu"` and `"gelu_new"` are supported.
|
| 59 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 60 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 61 |
+
hidden_dropout_prob (`float`, *optional*, defaults to 0.3):
|
| 62 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
| 63 |
+
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.3):
|
| 64 |
+
The dropout ratio for the attention probabilities.
|
| 65 |
+
pad_token_id (`int`, *optional*, defaults to 1):
|
| 66 |
+
The tokenizer index of the pad token.
|
| 67 |
+
bos_token_id (`int`, *optional*, defaults to 0):
|
| 68 |
+
The tokenizer index of the bos token.
|
| 69 |
+
eos_token_id (`int`, *optional*, defaults to 2):
|
| 70 |
+
The tokenizer index of the eos token.
|
| 71 |
+
conv_glu_dim (`int`, *optional*, defaults to 1):
|
| 72 |
+
The dimension of the output of the `Conv1dSubsampler` layer in which GLU is applied on. Though the original
|
| 73 |
+
Flashlight code uses the value of 2, here it's adapted to 1 due to transposition differences.
|
| 74 |
+
conv_dropout (`int`, *optional*, defaults to 0.3):
|
| 75 |
+
The probability of randomly dropping the `Conv1dSubsampler` layer during training.
|
| 76 |
+
num_conv_layers (`int`, *optional*, defaults to 1):
|
| 77 |
+
Number of convolution layers before applying transformer encoder layers.
|
| 78 |
+
conv_kernel (`Sequence[int]`, *optional*, defaults to `(7,)`):
|
| 79 |
+
The kernel size of the 1D convolution applied before transformer layers. `len(conv_kernel)` must be equal
|
| 80 |
+
to `num_conv_layers`.
|
| 81 |
+
conv_stride (`Sequence[int]`, *optional*, defaults to `(3,)`):
|
| 82 |
+
The stride length of the 1D convolution applied before transformer layers. `len(conv_stride)` must be equal
|
| 83 |
+
to `num_conv_layers`.
|
| 84 |
+
input_feat_per_channel (`int`, *optional*, defaults to 80):
|
| 85 |
+
Feature dimensions of the channels of the input to the Conv1D layer.
|
| 86 |
+
input_channels (`int`, *optional*, defaults to 1):
|
| 87 |
+
Number of input channels of the input to the Conv1D layer.
|
| 88 |
+
conv_channels (`List[int]`, *optional*):
|
| 89 |
+
Channel sizes of intermediate Conv1D layers.
|
| 90 |
+
ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`):
|
| 91 |
+
Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
|
| 92 |
+
instance of [`MCTCTForCTC`].
|
| 93 |
+
ctc_zero_infinity (`bool`, *optional*, defaults to `False`):
|
| 94 |
+
Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
|
| 95 |
+
occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
|
| 96 |
+
of [`MCTCTForCTC`].
|
| 97 |
+
|
| 98 |
+
Example:
|
| 99 |
+
|
| 100 |
+
```python
|
| 101 |
+
>>> from transformers import MCTCTConfig, MCTCTModel
|
| 102 |
+
|
| 103 |
+
>>> # Initializing a M-CTC-T mctct-large style configuration
|
| 104 |
+
>>> configuration = MCTCTConfig()
|
| 105 |
+
|
| 106 |
+
>>> # Initializing a model (with random weights) from the mctct-large style configuration
|
| 107 |
+
>>> model = MCTCTModel(configuration)
|
| 108 |
+
|
| 109 |
+
>>> # Accessing the model configuration
|
| 110 |
+
>>> configuration = model.config
|
| 111 |
+
```"""
|
| 112 |
+
|
| 113 |
+
model_type = "mctct"
|
| 114 |
+
|
| 115 |
+
def __init__(
|
| 116 |
+
self,
|
| 117 |
+
vocab_size=8065,
|
| 118 |
+
hidden_size=1536,
|
| 119 |
+
num_hidden_layers=36,
|
| 120 |
+
intermediate_size=6144,
|
| 121 |
+
num_attention_heads=4,
|
| 122 |
+
attention_head_dim=384,
|
| 123 |
+
max_position_embeddings=920,
|
| 124 |
+
layer_norm_eps=1e-5,
|
| 125 |
+
layerdrop=0.3,
|
| 126 |
+
hidden_act="relu",
|
| 127 |
+
initializer_range=0.02,
|
| 128 |
+
hidden_dropout_prob=0.3,
|
| 129 |
+
attention_probs_dropout_prob=0.3,
|
| 130 |
+
pad_token_id=1,
|
| 131 |
+
bos_token_id=0,
|
| 132 |
+
eos_token_id=2,
|
| 133 |
+
conv_glu_dim=1,
|
| 134 |
+
conv_dropout=0.3,
|
| 135 |
+
num_conv_layers=1,
|
| 136 |
+
conv_kernel=(7,),
|
| 137 |
+
conv_stride=(3,),
|
| 138 |
+
input_feat_per_channel=80,
|
| 139 |
+
input_channels=1,
|
| 140 |
+
conv_channels=None,
|
| 141 |
+
ctc_loss_reduction="sum",
|
| 142 |
+
ctc_zero_infinity=False,
|
| 143 |
+
**kwargs,
|
| 144 |
+
):
|
| 145 |
+
super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
|
| 146 |
+
self.vocab_size = vocab_size
|
| 147 |
+
self.hidden_size = hidden_size
|
| 148 |
+
self.num_hidden_layers = num_hidden_layers
|
| 149 |
+
self.intermediate_size = intermediate_size
|
| 150 |
+
self.num_attention_heads = num_attention_heads
|
| 151 |
+
self.attention_head_dim = attention_head_dim
|
| 152 |
+
self.max_position_embeddings = max_position_embeddings
|
| 153 |
+
self.layer_norm_eps = layer_norm_eps
|
| 154 |
+
self.layerdrop = layerdrop
|
| 155 |
+
self.hidden_act = hidden_act
|
| 156 |
+
self.initializer_range = initializer_range
|
| 157 |
+
self.hidden_dropout_prob = hidden_dropout_prob
|
| 158 |
+
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
| 159 |
+
self.pad_token_id = pad_token_id
|
| 160 |
+
self.bos_token_id = bos_token_id
|
| 161 |
+
self.eos_token_id = eos_token_id
|
| 162 |
+
self.conv_glu_dim = conv_glu_dim
|
| 163 |
+
self.conv_dropout = conv_dropout
|
| 164 |
+
self.num_conv_layers = num_conv_layers
|
| 165 |
+
self.input_feat_per_channel = input_feat_per_channel
|
| 166 |
+
self.input_channels = input_channels
|
| 167 |
+
self.conv_channels = conv_channels
|
| 168 |
+
self.ctc_loss_reduction = ctc_loss_reduction
|
| 169 |
+
self.ctc_zero_infinity = ctc_zero_infinity
|
| 170 |
+
|
| 171 |
+
# prevents config testing fail with exporting to json
|
| 172 |
+
self.conv_kernel = list(conv_kernel)
|
| 173 |
+
self.conv_stride = list(conv_stride)
|
| 174 |
+
|
| 175 |
+
if len(self.conv_kernel) != self.num_conv_layers:
|
| 176 |
+
raise ValueError(
|
| 177 |
+
"Configuration for convolutional module is incorrect. "
|
| 178 |
+
"It is required that `len(config.conv_kernel)` == `config.num_conv_layers` "
|
| 179 |
+
f"but is `len(config.conv_kernel) = {len(self.conv_kernel)}`, "
|
| 180 |
+
f"`config.num_conv_layers = {self.num_conv_layers}`."
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
__all__ = ["MCTCTConfig"]
|
docs/transformers/build/lib/transformers/models/deprecated/mctct/feature_extraction_mctct.py
ADDED
|
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""
|
| 16 |
+
Feature extractor class for M-CTC-T
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from typing import List, Optional, Union
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
|
| 23 |
+
from ....audio_utils import mel_filter_bank, optimal_fft_length, spectrogram, window_function
|
| 24 |
+
from ....feature_extraction_sequence_utils import SequenceFeatureExtractor
|
| 25 |
+
from ....feature_extraction_utils import BatchFeature
|
| 26 |
+
from ....file_utils import PaddingStrategy, TensorType
|
| 27 |
+
from ....utils import logging
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
logger = logging.get_logger(__name__)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class MCTCTFeatureExtractor(SequenceFeatureExtractor):
|
| 34 |
+
r"""
|
| 35 |
+
Constructs a M-CTC-T feature extractor.
|
| 36 |
+
|
| 37 |
+
This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
|
| 38 |
+
most of the main methods. Users should refer to this superclass for more information regarding those methods. This
|
| 39 |
+
code has been adapted from Flashlight's C++ code. For more information about the implementation, one can refer to
|
| 40 |
+
this [notebook](https://colab.research.google.com/drive/1GLtINkkhzms-IsdcGy_-tVCkv0qNF-Gt#scrollTo=pMCRGMmUC_an)
|
| 41 |
+
that takes the user step-by-step in the implementation.
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
feature_size (`int`, defaults to 80):
|
| 45 |
+
The feature dimension of the extracted features. This is the number of mel filters in the filter bank.
|
| 46 |
+
sampling_rate (`int`, defaults to 16000):
|
| 47 |
+
The sampling rate at which the audio files should be digitized, expressed in hertz (Hz).
|
| 48 |
+
padding_value (`float`, defaults to 0.0):
|
| 49 |
+
The value that is used to fill the padding values.
|
| 50 |
+
hop_length (`int`, defaults to 10):
|
| 51 |
+
Number of milliseconds between successive windows (converted to a stride in samples using `sampling_rate`). Otherwise referred to as "shift" in many papers.
|
| 52 |
+
win_length (`int`, defaults to 25):
|
| 53 |
+
Number of ms per window
|
| 54 |
+
win_function (`str`, defaults to `"hamming_window"`):
|
| 55 |
+
Name of the window function used for windowing; it is passed as `name` to the `window_function` helper from `audio_utils`.
|
| 56 |
+
frame_signal_scale (`float`, defaults to 32768.0):
|
| 57 |
+
Constant multiplied in creating the frames before applying DFT.
|
| 58 |
+
preemphasis_coeff (`float`, defaults to 0.97):
|
| 59 |
+
Constant multiplied in applying Pre-emphasis before DFT.
|
| 60 |
+
mel_floor (`float`, defaults to 1.0):
|
| 61 |
+
Minimum value of mel frequency banks.
|
| 62 |
+
normalize_means (`bool`, *optional*, defaults to `True`):
|
| 63 |
+
Whether or not to zero-mean normalize the extracted features.
|
| 64 |
+
normalize_vars (`bool`, *optional*, defaults to `True`):
|
| 65 |
+
Whether or not to unit-variance normalize the extracted features.
|
| 66 |
+
"""
|
| 67 |
+
|
| 68 |
+
model_input_names = ["input_features", "attention_mask"]
|
| 69 |
+
|
| 70 |
+
def __init__(
|
| 71 |
+
self,
|
| 72 |
+
feature_size=80,
|
| 73 |
+
sampling_rate=16000,
|
| 74 |
+
padding_value=0.0,
|
| 75 |
+
hop_length=10,
|
| 76 |
+
win_length=25,
|
| 77 |
+
win_function="hamming_window",
|
| 78 |
+
frame_signal_scale=32768.0,
|
| 79 |
+
preemphasis_coeff=0.97,
|
| 80 |
+
mel_floor=1.0,
|
| 81 |
+
normalize_means=True,
|
| 82 |
+
normalize_vars=True,
|
| 83 |
+
return_attention_mask=False,
|
| 84 |
+
**kwargs,
|
| 85 |
+
):
|
| 86 |
+
super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
|
| 87 |
+
|
| 88 |
+
self.feature_size = feature_size
|
| 89 |
+
self.sampling_rate = sampling_rate
|
| 90 |
+
self.padding_value = padding_value
|
| 91 |
+
self.hop_length = hop_length
|
| 92 |
+
self.win_length = win_length
|
| 93 |
+
self.frame_signal_scale = frame_signal_scale
|
| 94 |
+
self.preemphasis_coeff = preemphasis_coeff
|
| 95 |
+
self.mel_floor = mel_floor
|
| 96 |
+
self.normalize_means = normalize_means
|
| 97 |
+
self.normalize_vars = normalize_vars
|
| 98 |
+
self.win_function = win_function
|
| 99 |
+
self.return_attention_mask = return_attention_mask
|
| 100 |
+
|
| 101 |
+
self.sample_size = win_length * sampling_rate // 1000
|
| 102 |
+
self.sample_stride = hop_length * sampling_rate // 1000
|
| 103 |
+
|
| 104 |
+
self.n_fft = optimal_fft_length(self.sample_size)
|
| 105 |
+
self.n_freqs = (self.n_fft // 2) + 1
|
| 106 |
+
|
| 107 |
+
def _extract_mfsc_features(self, one_waveform: np.array) -> np.ndarray:
|
| 108 |
+
"""
|
| 109 |
+
Extracts MFSC Features for one waveform vector (unbatched). Adapted from Flashlight's C++ MFSC code.
|
| 110 |
+
"""
|
| 111 |
+
if self.win_function == "hamming_window":
|
| 112 |
+
window = window_function(window_length=self.sample_size, name=self.win_function, periodic=False)
|
| 113 |
+
else:
|
| 114 |
+
window = window_function(window_length=self.sample_size, name=self.win_function)
|
| 115 |
+
|
| 116 |
+
fbanks = mel_filter_bank(
|
| 117 |
+
num_frequency_bins=self.n_freqs,
|
| 118 |
+
num_mel_filters=self.feature_size,
|
| 119 |
+
min_frequency=0.0,
|
| 120 |
+
max_frequency=self.sampling_rate / 2.0,
|
| 121 |
+
sampling_rate=self.sampling_rate,
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
msfc_features = spectrogram(
|
| 125 |
+
one_waveform * self.frame_signal_scale,
|
| 126 |
+
window=window,
|
| 127 |
+
frame_length=self.sample_size,
|
| 128 |
+
hop_length=self.sample_stride,
|
| 129 |
+
fft_length=self.n_fft,
|
| 130 |
+
center=False,
|
| 131 |
+
preemphasis=self.preemphasis_coeff,
|
| 132 |
+
mel_filters=fbanks,
|
| 133 |
+
mel_floor=self.mel_floor,
|
| 134 |
+
log_mel="log",
|
| 135 |
+
)
|
| 136 |
+
return msfc_features.T
|
| 137 |
+
|
| 138 |
+
def _normalize_one(self, x, input_length, padding_value):
|
| 139 |
+
# make sure we normalize float32 arrays
|
| 140 |
+
if self.normalize_means:
|
| 141 |
+
mean = x[:input_length].mean(axis=0)
|
| 142 |
+
x = np.subtract(x, mean)
|
| 143 |
+
if self.normalize_vars:
|
| 144 |
+
std = x[:input_length].std(axis=0)
|
| 145 |
+
x = np.divide(x, std)
|
| 146 |
+
|
| 147 |
+
if input_length < x.shape[0]:
|
| 148 |
+
x[input_length:] = padding_value
|
| 149 |
+
|
| 150 |
+
# make sure array is in float32
|
| 151 |
+
x = x.astype(np.float32)
|
| 152 |
+
|
| 153 |
+
return x
|
| 154 |
+
|
| 155 |
+
def normalize(
|
| 156 |
+
self, input_features: List[np.ndarray], attention_mask: Optional[np.ndarray] = None
|
| 157 |
+
) -> List[np.ndarray]:
|
| 158 |
+
lengths = attention_mask.sum(-1) if attention_mask is not None else [x.shape[0] for x in input_features]
|
| 159 |
+
return [self._normalize_one(x, n, self.padding_value) for x, n in zip(input_features, lengths)]
|
| 160 |
+
|
| 161 |
+
def __call__(
|
| 162 |
+
self,
|
| 163 |
+
raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
|
| 164 |
+
padding: Union[bool, str, PaddingStrategy] = False,
|
| 165 |
+
max_length: Optional[int] = None,
|
| 166 |
+
truncation: bool = False,
|
| 167 |
+
pad_to_multiple_of: Optional[int] = None,
|
| 168 |
+
return_attention_mask: Optional[bool] = None,
|
| 169 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 170 |
+
sampling_rate: Optional[int] = None,
|
| 171 |
+
**kwargs,
|
| 172 |
+
) -> BatchFeature:
|
| 173 |
+
"""
|
| 174 |
+
Main method to featurize and prepare for the model one or several sequence(s). It returns the
|
| 175 |
+
log-mel spectrogram of the input audio, as implemented in the original Flashlight MFSC feature extraction code.
|
| 176 |
+
|
| 177 |
+
Args:
|
| 178 |
+
raw_speech (`torch.Tensor`, `np.ndarray`, `List[float]`, `List[torch.Tensor]`, `List[np.ndarray]`, `List[List[float]]`):
|
| 179 |
+
The sequence or batch of sequences to be padded. Each sequence can be a tensor, a numpy array, a list
|
| 180 |
+
of float values, a list of tensors, a list of numpy arrays or a list of list of float values. Must be
|
| 181 |
+
mono channel audio, not stereo, i.e. single float per timestep.
|
| 182 |
+
padding (`bool`, `str` or [`~file_utils.PaddingStrategy`], *optional*, defaults to `False`):
|
| 183 |
+
Select a strategy to pad the returned sequences (according to the model's padding side and padding
|
| 184 |
+
index) among:
|
| 185 |
+
|
| 186 |
+
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
|
| 187 |
+
sequence is provided).
|
| 188 |
+
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
|
| 189 |
+
acceptable input length for the model if that argument is not provided.
|
| 190 |
+
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
|
| 191 |
+
lengths).
|
| 192 |
+
max_length (`int`, *optional*):
|
| 193 |
+
Maximum length of the returned list and optionally padding length (see above).
|
| 194 |
+
truncation (`bool`):
|
| 195 |
+
Activates truncation to cut input sequences longer than *max_length* to *max_length*.
|
| 196 |
+
pad_to_multiple_of (`int`, *optional*):
|
| 197 |
+
If set will pad the sequence to a multiple of the provided value.
|
| 198 |
+
|
| 199 |
+
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
|
| 200 |
+
`>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
|
| 201 |
+
return_attention_mask (`bool`, *optional*):
|
| 202 |
+
Whether to return the attention mask. If left to the default, will return the attention mask according
|
| 203 |
+
to the specific feature_extractor's default.
|
| 204 |
+
|
| 205 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 206 |
+
|
| 207 |
+
return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
|
| 208 |
+
If set, will return tensors instead of list of python integers. Acceptable values are:
|
| 209 |
+
|
| 210 |
+
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
| 211 |
+
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
| 212 |
+
- `'np'`: Return Numpy `np.ndarray` objects.
|
| 213 |
+
sampling_rate (`int`, *optional*):
|
| 214 |
+
The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
|
| 215 |
+
`sampling_rate` at the forward call to prevent silent errors.
|
| 216 |
+
padding_value (`float`, defaults to 0.0):
|
| 217 |
+
"""
|
| 218 |
+
|
| 219 |
+
if sampling_rate is not None:
|
| 220 |
+
if sampling_rate != self.sampling_rate:
|
| 221 |
+
raise ValueError(
|
| 222 |
+
f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
|
| 223 |
+
f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with"
|
| 224 |
+
f" {self.sampling_rate} and not {sampling_rate}."
|
| 225 |
+
)
|
| 226 |
+
else:
|
| 227 |
+
logger.warning(
|
| 228 |
+
"It is strongly recommended to pass the ``sampling_rate`` argument to this function. "
|
| 229 |
+
"Failing to do so can result in silent errors that might be hard to debug."
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
|
| 233 |
+
if is_batched_numpy and len(raw_speech.shape) > 2:
|
| 234 |
+
raise ValueError(f"Only mono-channel audio is supported for input to {self}")
|
| 235 |
+
is_batched = is_batched_numpy or (
|
| 236 |
+
isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
if is_batched:
|
| 240 |
+
raw_speech = [np.asarray(speech, dtype=np.float32) for speech in raw_speech]
|
| 241 |
+
elif not is_batched and not isinstance(raw_speech, np.ndarray):
|
| 242 |
+
raw_speech = np.asarray(raw_speech, dtype=np.float32)
|
| 243 |
+
elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
|
| 244 |
+
raw_speech = raw_speech.astype(np.float32)
|
| 245 |
+
|
| 246 |
+
# always return batch
|
| 247 |
+
if not is_batched:
|
| 248 |
+
raw_speech = [raw_speech]
|
| 249 |
+
|
| 250 |
+
# extract fbank features
|
| 251 |
+
features = [self._extract_mfsc_features(one_waveform) for one_waveform in raw_speech]
|
| 252 |
+
|
| 253 |
+
# convert into correct format for padding
|
| 254 |
+
encoded_inputs = BatchFeature({"input_features": features})
|
| 255 |
+
|
| 256 |
+
padded_inputs = self.pad(
|
| 257 |
+
encoded_inputs,
|
| 258 |
+
padding=padding,
|
| 259 |
+
max_length=max_length,
|
| 260 |
+
truncation=truncation,
|
| 261 |
+
pad_to_multiple_of=pad_to_multiple_of,
|
| 262 |
+
return_attention_mask=True,
|
| 263 |
+
**kwargs,
|
| 264 |
+
)
|
| 265 |
+
# make sure list is in array format
|
| 266 |
+
input_features = padded_inputs.get("input_features")
|
| 267 |
+
if isinstance(input_features[0], list):
|
| 268 |
+
padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]
|
| 269 |
+
|
| 270 |
+
attention_mask = padded_inputs.get("attention_mask")
|
| 271 |
+
if attention_mask is not None:
|
| 272 |
+
padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.int32) for array in attention_mask]
|
| 273 |
+
|
| 274 |
+
if self.normalize_means or self.normalize_vars:
|
| 275 |
+
attention_mask = (
|
| 276 |
+
np.array(attention_mask, dtype=np.int32)
|
| 277 |
+
if self._get_padding_strategies(padding, max_length=max_length) is not PaddingStrategy.DO_NOT_PAD
|
| 278 |
+
and padding
|
| 279 |
+
else None
|
| 280 |
+
)
|
| 281 |
+
padded_inputs["input_features"] = self.normalize(
|
| 282 |
+
padded_inputs["input_features"], attention_mask=attention_mask
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
if return_tensors is not None:
|
| 286 |
+
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
|
| 287 |
+
|
| 288 |
+
return padded_inputs
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
__all__ = ["MCTCTFeatureExtractor"]
|
docs/transformers/build/lib/transformers/models/deprecated/mctct/modeling_mctct.py
ADDED
|
@@ -0,0 +1,791 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch M-CTC-T model."""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from typing import Optional, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.utils.checkpoint
|
| 22 |
+
from torch import nn
|
| 23 |
+
|
| 24 |
+
from ....activations import ACT2FN
|
| 25 |
+
from ....file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
|
| 26 |
+
from ....integrations.deepspeed import is_deepspeed_zero3_enabled
|
| 27 |
+
from ....integrations.fsdp import is_fsdp_managed_module
|
| 28 |
+
from ....modeling_attn_mask_utils import _prepare_4d_attention_mask
|
| 29 |
+
from ....modeling_outputs import BaseModelOutput, CausalLMOutput
|
| 30 |
+
from ....modeling_utils import (
|
| 31 |
+
PreTrainedModel,
|
| 32 |
+
apply_chunking_to_forward,
|
| 33 |
+
find_pruneable_heads_and_indices,
|
| 34 |
+
prune_linear_layer,
|
| 35 |
+
)
|
| 36 |
+
from ....utils import logging
|
| 37 |
+
from .configuration_mctct import MCTCTConfig
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
logger = logging.get_logger(__name__)
|
| 41 |
+
|
| 42 |
+
_HIDDEN_STATES_START_POSITION = 1
|
| 43 |
+
|
| 44 |
+
_CONFIG_FOR_DOC = "MCTCTConfig"
|
| 45 |
+
|
| 46 |
+
# Base docstring
|
| 47 |
+
_CHECKPOINT_FOR_DOC = "speechbrain/m-ctc-t-large"
|
| 48 |
+
_EXPECTED_OUTPUT_SHAPE = [1, 195, 1536]
|
| 49 |
+
|
| 50 |
+
# CTC docstring
|
| 51 |
+
_CTC_EXPECTED_OUTPUT = '"Mr. Quilter is the apostle of the middle classes, and we\'re glad to welcome his gospel."'
|
| 52 |
+
_CTC_EXPECTED_LOSS = 1885.65
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class MCTCTConv1dSubsampler(nn.Module):
|
| 56 |
+
"""
|
| 57 |
+
Convolutional subsampler: a stack of 1D convolution (along temporal dimension) followed by non-linear activation
|
| 58 |
+
via gated linear units (https://arxiv.org/abs/1911.08460)
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
def __init__(self, config):
|
| 62 |
+
super().__init__()
|
| 63 |
+
self.config = config
|
| 64 |
+
self.glu_dim = config.conv_glu_dim
|
| 65 |
+
|
| 66 |
+
self.dropout = nn.Dropout(config.conv_dropout)
|
| 67 |
+
|
| 68 |
+
self.num_layers = config.num_conv_layers
|
| 69 |
+
self.in_channels = config.input_feat_per_channel * config.input_channels
|
| 70 |
+
|
| 71 |
+
if self.num_layers > 1:
|
| 72 |
+
if config.conv_channels is None:
|
| 73 |
+
raise ValueError(
|
| 74 |
+
"Need to specify `conv_channels` configuration in `MCTCTConfig` to use multiple convolution"
|
| 75 |
+
" layers."
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
self.mid_channels = config.conv_channels
|
| 79 |
+
else:
|
| 80 |
+
self.mid_channels = None
|
| 81 |
+
|
| 82 |
+
self.out_channels = config.hidden_size * 2 # considering GLU halving
|
| 83 |
+
self.kernel_size = config.conv_kernel
|
| 84 |
+
self.stride = config.conv_stride
|
| 85 |
+
|
| 86 |
+
# NOTE: MCTCT by construction only uses one convolution kernel. I've made this flexible to allow for
|
| 87 |
+
# multiple layers of convolutions, but not sure if this model definition should just restrict it
|
| 88 |
+
# to one layer. This becomes especially relevant when considering the padding like line 1 of forward().
|
| 89 |
+
self.conv_layers = nn.ModuleList(
|
| 90 |
+
nn.Conv1d(
|
| 91 |
+
self.in_channels if i == 0 else self.mid_channels[i],
|
| 92 |
+
self.mid_channels[i] if i < self.num_layers - 1 else self.out_channels,
|
| 93 |
+
kernel_size=k,
|
| 94 |
+
stride=self.stride[i],
|
| 95 |
+
padding="valid",
|
| 96 |
+
)
|
| 97 |
+
for i, k in enumerate(self.kernel_size)
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
def forward(self, input_features):
|
| 101 |
+
# NOTE: in reference to the NOTE in __init__, right now it just calculates padding as if
|
| 102 |
+
# there will be just one conv layer.
|
| 103 |
+
padding = sum([size // 2 for size in self.kernel_size]) # (7, 7) -> (3, 3)
|
| 104 |
+
|
| 105 |
+
input_features = torch.nn.functional.pad(input_features, (0, 0, padding, padding), "constant", 0)
|
| 106 |
+
hidden_states = input_features.transpose(1, 2).contiguous() # -> Batch x Frame x Time
|
| 107 |
+
for conv in self.conv_layers:
|
| 108 |
+
hidden_states = conv(hidden_states)
|
| 109 |
+
hidden_states = nn.functional.glu(hidden_states, dim=self.glu_dim)
|
| 110 |
+
hidden_states = self.dropout(hidden_states)
|
| 111 |
+
|
| 112 |
+
hidden_states = hidden_states.transpose(1, 2).contiguous() # -> Batch x Time x Frame
|
| 113 |
+
return hidden_states
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class MCTCTEmbeddings(nn.Module):
|
| 117 |
+
"""Construct the embeddings from word, position and token_type embeddings."""
|
| 118 |
+
|
| 119 |
+
def __init__(self, config):
|
| 120 |
+
super().__init__()
|
| 121 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
| 122 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
| 123 |
+
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
|
| 124 |
+
|
| 125 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
| 126 |
+
# any TensorFlow checkpoint file
|
| 127 |
+
# self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 128 |
+
self.LayerNorm = MCTCTLayerNorm()
|
| 129 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 130 |
+
|
| 131 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
| 132 |
+
self.register_buffer(
|
| 133 |
+
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
|
| 134 |
+
)
|
| 135 |
+
self.register_buffer(
|
| 136 |
+
"token_type_ids",
|
| 137 |
+
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
|
| 138 |
+
persistent=False,
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
def forward(
|
| 142 |
+
self, input_features=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
|
| 143 |
+
):
|
| 144 |
+
input_shape = input_features.size() if input_features is not None else inputs_embeds.size()[:-1]
|
| 145 |
+
|
| 146 |
+
seq_length = input_shape[1]
|
| 147 |
+
|
| 148 |
+
if position_ids is None:
|
| 149 |
+
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
| 150 |
+
|
| 151 |
+
# Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
|
| 152 |
+
# when it's auto-generated. The registered buffer helps users when tracing the model without passing token_type_ids, and solves
|
| 153 |
+
# issue #5664
|
| 154 |
+
if token_type_ids is None:
|
| 155 |
+
if hasattr(self, "token_type_ids"):
|
| 156 |
+
buffered_token_type_ids = self.token_type_ids[:, :seq_length]
|
| 157 |
+
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
|
| 158 |
+
token_type_ids = buffered_token_type_ids_expanded
|
| 159 |
+
else:
|
| 160 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
|
| 161 |
+
|
| 162 |
+
if inputs_embeds is None:
|
| 163 |
+
inputs_embeds = self.word_embeddings(input_features)
|
| 164 |
+
|
| 165 |
+
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
| 166 |
+
|
| 167 |
+
embeddings = inputs_embeds + token_type_embeddings
|
| 168 |
+
|
| 169 |
+
embeddings = self.LayerNorm(embeddings)
|
| 170 |
+
embeddings = self.dropout(embeddings)
|
| 171 |
+
return embeddings
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class MCTCTSelfAttention(nn.Module):
    """Multi-head self-attention with an additive relative-position score term.

    Query, key and value are bias-free linear projections of the input. Besides the usual query/key dot
    product, every query vector is also multiplied against all rows of ``distance_embedding``; that result
    is realigned by `relative_position_embedding_rotate` and added to the attention logits before softmax.

    Args:
        config: Configuration object. Reads ``hidden_size``, ``num_attention_heads``,
            ``attention_head_dim``, ``attention_probs_dropout_prob``, ``max_position_embeddings`` and
            ``is_decoder``.

    Raises:
        ValueError: If ``hidden_size`` is not a multiple of ``num_attention_heads`` and the config has no
            ``embedding_size`` attribute.
    """

    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.attention_head_dim
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=False)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

        self.max_position_embeddings = config.max_position_embeddings
        # One embedding row per entry of the relative-position table (2 * max_position_embeddings - 1 rows).
        self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x):
        """Split the last axis into (heads, head_size) and move the head axis to position 1."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def reshape_fortran(self, x, shape):
        """Reshape `x` to `shape` by reversing its axes, reshaping to the reversed shape and reversing again."""
        if len(x.shape) > 0:
            x = x.permute(*reversed(range(len(x.shape))))
        return x.reshape(*reversed(shape)).permute(*reversed(range(len(shape))))

    def relative_position_embedding_rotate(self, scores):
        """Realign query-vs-distance-embedding scores so they can be added to the attention logits.

        Args:
            scores: Tensor laid out as ``(batch, heads, num_distance_rows, seq_len)``, as produced by the
                einsum in `forward`.

        Returns:
            Tensor with the same dtype and device as `scores`; for the shapes used in `forward` it is
            ``(batch, heads, seq_len, seq_len)``.
        """
        # NOTE: should re-evaluate whether this re-implementation was truly necessary
        # or the reason why my complete re-haul worked was due to some other part
        # of the code. Adding this and the reshape fortran code seems very undesirable.
        scores = scores.permute(0, 2, 3, 1)  # e.g. [10, 1839, 14, 4]

        batch, hidden_state, seq_len, heads = scores.shape

        # e.g. [10, 1853, 14, 4]
        # The padding must share the scores' dtype and device: a default float32 pad would make torch.cat
        # promote half-precision scores to float32 and leak that dtype into the attention logits.
        padding = scores.new_zeros((batch, seq_len, seq_len, heads))
        scores = torch.cat((scores, padding), dim=1)

        # e.g. [10, 25942, 1, 4]
        scores = self.reshape_fortran(scores, [batch, (hidden_state + seq_len) * seq_len, 1, heads])

        # e.g. [10, 25928, 1, 4]
        scores = scores[:, : (seq_len + hidden_state - 1) * seq_len]

        # e.g. [10, 1852, 14, 4]
        scores = self.reshape_fortran(scores, [batch, hidden_state + seq_len - 1, seq_len, heads])

        halfpoint = hidden_state // 2
        scores = scores[:, halfpoint : halfpoint + seq_len].transpose(1, 2)  # e.g. [10, 14, 14, 4]

        return scores.permute(0, 3, 1, 2)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
    ):
        """Run self-attention over `hidden_states`.

        Args:
            hidden_states: Input tensor whose last axis has size ``config.hidden_size``.
            attention_mask: Optional additive mask that is summed with the attention logits.
            head_mask: Optional multiplicative mask applied to the attention probabilities.
            output_attentions: Whether to also return the attention probabilities.

        Returns:
            ``(context_layer,)`` or ``(context_layer, attention_probs)`` when `output_attentions` is set.
        """
        mixed_query_layer = self.query(hidden_states)
        # Scaling is applied to the query up front, so it affects both the content and the position scores.
        mixed_query_layer = mixed_query_layer / math.sqrt(self.attention_head_size)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        # relative key position embeddings
        positional_embedding = self.distance_embedding.weight
        relative_position_scores = torch.einsum("lh, bche -> bcle", positional_embedding, query_layer.transpose(2, 3))

        relative_position_scores = self.relative_position_embedding_rotate(relative_position_scores)
        attention_scores = attention_scores + relative_position_scores

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in MCTCTModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        # Back to (batch, seq, heads, head_size), then merge the two trailing axes.
        context_layer = context_layer.permute(0, 2, 1, 3).flatten(start_dim=-2)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
class MCTCTLayerNorm(nn.Module):
    """Scalar affine transform: a single learnable weight and bias shared by every element of the input."""

    def __init__(self):
        super().__init__()
        # Both parameters hold exactly one element; they start as the identity map (weight 1, bias 0).
        self.singleton_weight = nn.Parameter(torch.ones(1))
        self.singleton_bias = nn.Parameter(torch.zeros(1))

    def forward(self, hidden_states):
        """Return `hidden_states` scaled by the weight and shifted by the bias."""
        scaled = hidden_states * self.singleton_weight
        return scaled + self.singleton_bias
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
class MCTCTSelfOutput(nn.Module):
    """Post-attention projection: bias-free dense layer, dropout, then a residual LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        width = config.hidden_size
        self.dense = nn.Linear(width, width, bias=False)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Project `hidden_states`, apply dropout, add `input_tensor` and normalize the sum."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
class MCTCTAttention(nn.Module):
    """Pairs the self-attention module with its output projection and supports head pruning."""

    def __init__(self, config):
        super().__init__()
        self.self = MCTCTSelfAttention(config)
        self.output = MCTCTSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads from the projections and record them in `pruned_heads`."""
        if len(heads) == 0:
            return

        attn = self.self
        heads, index = find_pruneable_heads_and_indices(
            heads, attn.num_attention_heads, attn.attention_head_size, self.pruned_heads
        )

        # Prune linear layers: the three input projections, then the input side of the output projection.
        for proj_name in ("query", "key", "value"):
            setattr(attn, proj_name, prune_linear_layer(getattr(attn, proj_name), index))
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        attn.num_attention_heads = attn.num_attention_heads - len(heads)
        attn.all_head_size = attn.attention_head_size * attn.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
    ):
        """Run self-attention followed by the residual output projection.

        Returns:
            A tuple whose first item is the attention output; the attention probabilities follow when the
            self-attention module returned them.
        """
        context, *extras = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions,
        )
        # `hidden_states` is the residual input added inside the output block.
        return (self.output(context, hidden_states), *extras)
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
class MCTCTIntermediate(nn.Module):
    """Feed-forward expansion: bias-free dense layer to `intermediate_size` followed by the activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        # `hidden_act` is either a key into ACT2FN or already a callable.
        activation = config.hidden_act
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states):
        """Apply the dense projection and then the activation function."""
        return self.intermediate_act_fn(self.dense(hidden_states))
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
class MCTCTOutput(nn.Module):
    """Feed-forward contraction: bias-free dense layer back to `hidden_size`, dropout, residual LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Project `hidden_states`, apply dropout, add `input_tensor` and normalize the sum."""
        contracted = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(contracted + input_tensor)
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
class MCTCTLayer(nn.Module):
    """One encoder layer: attention block followed by a (optionally chunked) feed-forward block."""

    def __init__(self, config: MCTCTConfig):
        super().__init__()

        # Chunking of the feed-forward pass happens along axis 1.
        self.seq_len_dim = 1
        self.chunk_size_feed_forward = config.chunk_size_feed_forward

        self.intermediate = MCTCTIntermediate(config)
        self.attention = MCTCTAttention(config)
        self.is_decoder = config.is_decoder
        self.output = MCTCTOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
    ):
        """Apply attention and the feed-forward block.

        Returns:
            A tuple whose first item is the layer output; any extra items returned by the attention block
            (the attention weights, when requested) follow.
        """
        attended, *attention_extras = self.attention(
            hidden_states, attention_mask, head_mask, output_attentions=output_attentions
        )

        feed_forward_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attended
        )

        return (feed_forward_output, *attention_extras)

    def feed_forward_chunk(self, attention_output):
        """Run the intermediate and output blocks on one chunk; `attention_output` is also the residual."""
        return self.output(self.intermediate(attention_output), attention_output)
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
class MCTCTPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = MCTCTConfig
    base_model_prefix = "mctct"
    main_input_name = "input_features"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        std = self.config.initializer_range

        def draw_weight_and_clear_bias(layer):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            layer.weight.data.normal_(mean=0.0, std=std)
            if layer.bias is not None:
                layer.bias.data.zero_()

        if isinstance(module, nn.Linear):
            draw_weight_and_clear_bias(module)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, MCTCTLayerNorm):
            module.singleton_weight.data.fill_(1.0)
            module.singleton_bias.data.zero_()

        # Separate pass covering Conv1d; nn.Linear modules are drawn a second time here.
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            draw_weight_and_clear_bias(module)

    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
        """
        Computes the output length of the convolutional layers
        """
        dilation = 1
        conv_specs = zip(range(self.config.num_conv_layers), self.config.conv_kernel, self.config.conv_stride)
        for _, kernel_sz, stride in conv_specs:
            padding = kernel_sz // 2
            # Convolution output-length arithmetic, applied once per layer.
            input_lengths = input_lengths + 2 * padding - dilation * (kernel_sz - 1) - 1
            input_lengths = torch.div(input_lengths, stride, rounding_mode="trunc") + 1

        return input_lengths

    def _get_feature_vector_attention_mask(self, feature_vector_length, attention_mask):
        """Build a mask over `feature_vector_length` positions from a mask over the raw input frames."""
        # generate creates 3D attention mask, because of the shape of input_features
        # convert it to 2D if thats the case
        if len(attention_mask.shape) > 2:
            attention_mask = attention_mask[:, :, -1]

        subsampled_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
        batch_size = attention_mask.size()[0]
        subsampled_mask = torch.zeros(
            (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
        )

        # Mark the last valid position of every row, then a reversed cumulative sum
        # makes sure that all values before the output lengths indices are attended to
        row_index = torch.arange(batch_size, device=subsampled_mask.device)
        subsampled_mask[row_index, subsampled_lengths - 1] = 1
        return subsampled_mask.flip([-1]).cumsum(-1).flip([-1]).long()
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
MCTCT_START_DOCSTRING = r"""
|
| 489 |
+
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
|
| 490 |
+
it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
|
| 491 |
+
behavior.
|
| 492 |
+
|
| 493 |
+
Parameters:
|
| 494 |
+
config ([`MCTCTConfig`]): Model configuration class with all the parameters of the model.
|
| 495 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 496 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 497 |
+
"""
|
| 498 |
+
|
| 499 |
+
MCTCT_INPUTS_DOCSTRING = r"""
|
| 500 |
+
Args:
|
| 501 |
+
input_features (`torch.LongTensor` of shape `({0})`):
|
| 502 |
+
Indices of input sequence tokens in the vocabulary.
|
| 503 |
+
|
| 504 |
+
Indices can be obtained using [`Wav2Vec2CTCTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 505 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 506 |
+
|
| 507 |
+
[What are input IDs?](../glossary#input-ids)
|
| 508 |
+
attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
|
| 509 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 510 |
+
|
| 511 |
+
- 1 for tokens that are **not masked**,
|
| 512 |
+
- 0 for tokens that are **masked**.
|
| 513 |
+
|
| 514 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 515 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
| 516 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
| 517 |
+
|
| 518 |
+
- 1 indicates the head is **not masked**,
|
| 519 |
+
- 0 indicates the head is **masked**.
|
| 520 |
+
output_attentions (`bool`, *optional*):
|
| 521 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 522 |
+
tensors for more detail.
|
| 523 |
+
output_hidden_states (`bool`, *optional*):
|
| 524 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 525 |
+
more detail.
|
| 526 |
+
return_dict (`bool`, *optional*):
|
| 527 |
+
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
|
| 528 |
+
"""
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
class MCTCTEncoder(MCTCTPreTrainedModel):
    """Encoder stack: scalar layer norm, convolutional subsampling, then `num_hidden_layers` MCTCT layers.

    Args:
        config: Model configuration. Reads ``hidden_dropout_prob``, ``num_hidden_layers`` and, in `forward`,
            ``layerdrop``, ``output_attentions``, ``output_hidden_states`` and ``use_return_dict``.
    """

    def __init__(self, config: MCTCTConfig):
        super().__init__(config)
        self.hidden_dropout_prob = config.hidden_dropout_prob

        self.layer_norm = MCTCTLayerNorm()
        self.conv = MCTCTConv1dSubsampler(config)
        self.layers = nn.ModuleList([MCTCTLayer(config) for _ in range(config.num_hidden_layers)])

        self.gradient_checkpointing = False

    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: torch.Tensor,
        head_mask: torch.Tensor,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[Tuple, BaseModelOutput]:
        """Encode `input_features`.

        Args:
            input_features: Input features; normalized and subsampled before entering the layer stack.
            attention_mask: Optional mask over the input frames; it is subsampled to the convolution output
                length and expanded to 4D before use.
            head_mask: Optional per-layer head mask whose first axis must equal the number of layers. The
                slice for each layer is handed to that layer; a 1-D slice is reshaped to
                ``(1, num_heads, 1, 1)`` so it broadcasts over the head axis of the attention probabilities.
            output_attentions: Whether to collect the attention probabilities of every layer.
            output_hidden_states: Whether to collect the hidden states before every layer and after the last.
            return_dict: Whether to return a `BaseModelOutput` instead of a plain tuple.

        Returns:
            A `BaseModelOutput`, or a tuple of its non-None fields when `return_dict` is false.

        Raises:
            ValueError: If `head_mask` is given for a different number of layers than the encoder has.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        input_features = self.layer_norm(input_features)

        inputs_embeds = self.conv(input_features)

        # subsample attention mask if necessary
        if attention_mask is not None:
            attention_mask = self._get_feature_vector_attention_mask(inputs_embeds.shape[1], attention_mask)

        hidden_states = nn.functional.dropout(inputs_embeds, p=self.hidden_dropout_prob, training=self.training)

        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            if head_mask.size()[0] != len(self.layers):
                raise ValueError(
                    f"The head_mask should be specified for {len(self.layers)} layers, "
                    f"but it is for {head_mask.size()[0]}."
                )

        synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)

            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = torch.rand([])

            skip_the_layer = bool(self.training and (dropout_probability < self.config.layerdrop))
            if not skip_the_layer or synced_gpus:
                # Resolve this layer's head mask once so that both execution paths below apply it.
                # Previously only the gradient-checkpointing path forwarded it, so the mask was
                # silently ignored otherwise.
                layer_head_mask = head_mask[idx] if head_mask is not None else None
                if layer_head_mask is not None and layer_head_mask.dim() == 1:
                    # Attention probabilities are (batch, heads, query, key); align the mask with the head axis.
                    layer_head_mask = layer_head_mask.reshape(1, -1, 1, 1)

                # under fsdp or deepspeed zero3 all gpus must run in sync
                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        encoder_layer.__call__,
                        hidden_states,
                        attention_mask,
                        layer_head_mask,
                        output_attentions,
                    )
                else:
                    layer_outputs = encoder_layer(
                        hidden_states=hidden_states,
                        attention_mask=attention_mask,
                        head_mask=layer_head_mask,
                        output_attentions=output_attentions,
                    )

                hidden_states = layer_outputs[0]

            if skip_the_layer:
                layer_outputs = (None, None)

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )
|
| 625 |
+
|
| 626 |
+
|
| 627 |
+
@add_start_docstrings(
    "The bare M-CTC-T Model transformer outputting raw hidden-states without any specific head on top.",
    MCTCT_START_DOCSTRING,
)
class MCTCTModel(MCTCTPreTrainedModel):
    # Thin wrapper around MCTCTEncoder; documentation is attached by the decorators, so no docstrings here.
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.encoder = MCTCTEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(MCTCT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        # Flags left as None fall back to the values stored on the config.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        if input_features is None:
            raise ValueError("You have to specify input_features.")

        encoder_outputs = self.encoder(
            input_features,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]

        if return_dict:
            return BaseModelOutput(
                last_hidden_state=sequence_output,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        # Tuple mode: the encoder's remaining items are passed through unchanged.
        return (sequence_output,) + encoder_outputs[1:]
|
| 685 |
+
|
| 686 |
+
|
| 687 |
+
@add_start_docstrings(
    """MCTCT Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
    MCTCT_START_DOCSTRING,
)
class MCTCTForCTC(MCTCTPreTrainedModel):
    # MCTCTModel followed by a linear CTC head over the vocabulary. Class documentation is attached by the
    # decorator above, so it is described in comments rather than a docstring.
    def __init__(self, config):
        super().__init__(config)

        self.mctct = MCTCTModel(config)

        if config.vocab_size is None:
            raise ValueError(
                f"You are trying to instantiate {self.__class__} with a configuration that "
                "does not define the vocabulary size of the language model head. Please "
                "instantiate the model as follows: `MCTCTForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
                "or define `vocab_size` of your model's configuration."
            )
        output_hidden_size = config.hidden_size

        self.ctc_head = nn.Linear(output_hidden_size, config.vocab_size)

        # Initialize weights and apply final processing
        self.post_init()

    # The inputs docstring contains a `{0}` shape placeholder; format it here exactly as MCTCTModel does so
    # the rendered documentation does not show the raw placeholder.
    @add_start_docstrings_to_model_forward(MCTCT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=CausalLMOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_CTC_EXPECTED_OUTPUT,
        expected_loss=_CTC_EXPECTED_LOSS,
    )
    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
            Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
            the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
            All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
            config.vocab_size - 1]`.
        """
        # A label equal to vocab_size is already out of range, so the message states a strict bound.
        if labels is not None and labels.max() >= self.config.vocab_size:
            raise ValueError(f"Label values must be < vocab_size: {self.config.vocab_size}")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.mctct(
            input_features,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]

        logits = self.ctc_head(hidden_states)

        loss = None
        if labels is not None:
            # retrieve loss input_lengths from attention_mask
            # (without a mask, every frame of every sample counts as valid)
            attention_mask = (
                attention_mask
                if attention_mask is not None
                else torch.ones(input_features.shape[:-1], dtype=torch.long)
            )
            input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
            # assuming that padded tokens are filled with -100
            # when not being attended to
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            # ctc_loss doesn't support fp16
            # The transpose puts the time axis first, as ctc_loss expects.
            log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)

            with torch.backends.cudnn.flags(enabled=False):
                loss = nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    input_lengths,
                    target_lengths,
                    blank=self.config.pad_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )
|
| 789 |
+
|
| 790 |
+
|
| 791 |
+
__all__ = ["MCTCTForCTC", "MCTCTModel", "MCTCTPreTrainedModel"]
|
docs/transformers/build/lib/transformers/models/deprecated/mctct/processing_mctct.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""
|
| 16 |
+
Speech processor class for M-CTC-T
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import warnings
|
| 20 |
+
from contextlib import contextmanager
|
| 21 |
+
|
| 22 |
+
from ....processing_utils import ProcessorMixin
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class MCTCTProcessor(ProcessorMixin):
|
| 26 |
+
r"""
|
| 27 |
+
Constructs a MCTCT processor which wraps a MCTCT feature extractor and a MCTCT tokenizer into a single processor.
|
| 28 |
+
|
| 29 |
+
[`MCTCTProcessor`] offers all the functionalities of [`MCTCTFeatureExtractor`] and [`AutoTokenizer`]. See the
|
| 30 |
+
[`~MCTCTProcessor.__call__`] and [`~MCTCTProcessor.decode`] for more information.
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
feature_extractor (`MCTCTFeatureExtractor`):
|
| 34 |
+
An instance of [`MCTCTFeatureExtractor`]. The feature extractor is a required input.
|
| 35 |
+
tokenizer (`AutoTokenizer`):
|
| 36 |
+
An instance of [`AutoTokenizer`]. The tokenizer is a required input.
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
feature_extractor_class = "MCTCTFeatureExtractor"
|
| 40 |
+
tokenizer_class = "AutoTokenizer"
|
| 41 |
+
|
| 42 |
+
def __init__(self, feature_extractor, tokenizer):
    """Wrap a feature extractor and a tokenizer.

    Args:
        feature_extractor: Forwarded to the parent initializer; used by `__call__` for audio inputs.
        tokenizer: Forwarded to the parent initializer; used by `__call__` for text inputs.
    """
    super().__init__(feature_extractor, tokenizer)
    # `__call__` delegates to `current_processor` only while `_in_target_context_manager` is set.
    self._in_target_context_manager = False
    self.current_processor = self.feature_extractor
|
| 46 |
+
|
| 47 |
+
def __call__(self, *args, **kwargs):
|
| 48 |
+
"""
|
| 49 |
+
When used in normal mode, this method forwards all its arguments to MCTCTFeatureExtractor's
|
| 50 |
+
[`~MCTCTFeatureExtractor.__call__`] and returns its output. If used in the context
|
| 51 |
+
[`~MCTCTProcessor.as_target_processor`] this method forwards all its arguments to AutoTokenizer's
|
| 52 |
+
[`~AutoTokenizer.__call__`]. Please refer to the docstring of the above two methods for more information.
|
| 53 |
+
"""
|
| 54 |
+
# For backward compatibility
|
| 55 |
+
if self._in_target_context_manager:
|
| 56 |
+
return self.current_processor(*args, **kwargs)
|
| 57 |
+
|
| 58 |
+
if "raw_speech" in kwargs:
|
| 59 |
+
warnings.warn("Using `raw_speech` as a keyword argument is deprecated. Use `audio` instead.")
|
| 60 |
+
audio = kwargs.pop("raw_speech")
|
| 61 |
+
else:
|
| 62 |
+
audio = kwargs.pop("audio", None)
|
| 63 |
+
sampling_rate = kwargs.pop("sampling_rate", None)
|
| 64 |
+
text = kwargs.pop("text", None)
|
| 65 |
+
if len(args) > 0:
|
| 66 |
+
audio = args[0]
|
| 67 |
+
args = args[1:]
|
| 68 |
+
|
| 69 |
+
if audio is None and text is None:
|
| 70 |
+
raise ValueError("You need to specify either an `audio` or `text` input to process.")
|
| 71 |
+
|
| 72 |
+
if audio is not None:
|
| 73 |
+
inputs = self.feature_extractor(audio, *args, sampling_rate=sampling_rate, **kwargs)
|
| 74 |
+
if text is not None:
|
| 75 |
+
encodings = self.tokenizer(text, **kwargs)
|
| 76 |
+
|
| 77 |
+
if text is None:
|
| 78 |
+
return inputs
|
| 79 |
+
elif audio is None:
|
| 80 |
+
return encodings
|
| 81 |
+
else:
|
| 82 |
+
inputs["labels"] = encodings["input_ids"]
|
| 83 |
+
return inputs
|
| 84 |
+
|
| 85 |
+
def batch_decode(self, *args, **kwargs):
|
| 86 |
+
"""
|
| 87 |
+
This method forwards all its arguments to AutoTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please refer
|
| 88 |
+
to the docstring of this method for more information.
|
| 89 |
+
"""
|
| 90 |
+
return self.tokenizer.batch_decode(*args, **kwargs)
|
| 91 |
+
|
| 92 |
+
def pad(self, *args, **kwargs):
|
| 93 |
+
"""
|
| 94 |
+
When used in normal mode, this method forwards all its arguments to MCTCTFeatureExtractor's
|
| 95 |
+
[`~MCTCTFeatureExtractor.pad`] and returns its output. If used in the context
|
| 96 |
+
[`~MCTCTProcessor.as_target_processor`] this method forwards all its arguments to PreTrainedTokenizer's
|
| 97 |
+
[`~PreTrainedTokenizer.pad`]. Please refer to the docstring of the above two methods for more information.
|
| 98 |
+
"""
|
| 99 |
+
# For backward compatibility
|
| 100 |
+
if self._in_target_context_manager:
|
| 101 |
+
return self.current_processor.pad(*args, **kwargs)
|
| 102 |
+
|
| 103 |
+
input_features = kwargs.pop("input_features", None)
|
| 104 |
+
labels = kwargs.pop("labels", None)
|
| 105 |
+
if len(args) > 0:
|
| 106 |
+
input_features = args[0]
|
| 107 |
+
args = args[1:]
|
| 108 |
+
|
| 109 |
+
if input_features is not None:
|
| 110 |
+
input_features = self.feature_extractor.pad(input_features, *args, **kwargs)
|
| 111 |
+
if labels is not None:
|
| 112 |
+
labels = self.tokenizer.pad(labels, **kwargs)
|
| 113 |
+
|
| 114 |
+
if labels is None:
|
| 115 |
+
return input_features
|
| 116 |
+
elif input_features is None:
|
| 117 |
+
return labels
|
| 118 |
+
else:
|
| 119 |
+
input_features["labels"] = labels["input_ids"]
|
| 120 |
+
return input_features
|
| 121 |
+
|
| 122 |
+
def decode(self, *args, **kwargs):
|
| 123 |
+
"""
|
| 124 |
+
This method forwards all its arguments to AutoTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to the
|
| 125 |
+
docstring of this method for more information.
|
| 126 |
+
"""
|
| 127 |
+
return self.tokenizer.decode(*args, **kwargs)
|
| 128 |
+
|
| 129 |
+
@contextmanager
|
| 130 |
+
def as_target_processor(self):
|
| 131 |
+
"""
|
| 132 |
+
Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning MCTCT.
|
| 133 |
+
"""
|
| 134 |
+
warnings.warn(
|
| 135 |
+
"`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your "
|
| 136 |
+
"labels by using the argument `text` of the regular `__call__` method (either in the same call as "
|
| 137 |
+
"your audio inputs, or in a separate call."
|
| 138 |
+
)
|
| 139 |
+
self._in_target_context_manager = True
|
| 140 |
+
self.current_processor = self.tokenizer
|
| 141 |
+
yield
|
| 142 |
+
self.current_processor = self.feature_extractor
|
| 143 |
+
self._in_target_context_manager = False
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
__all__ = ["MCTCTProcessor"]
|
docs/transformers/build/lib/transformers/models/deprecated/mega/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ....utils import _LazyModule
|
| 17 |
+
from ....utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_mega import *
|
| 22 |
+
from .modeling_mega import *
|
| 23 |
+
else:
|
| 24 |
+
import sys
|
| 25 |
+
|
| 26 |
+
_file = globals()["__file__"]
|
| 27 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
docs/transformers/build/lib/transformers/models/deprecated/mega/configuration_mega.py
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The Mega Authors and The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""MEGA configuration"""
|
| 16 |
+
|
| 17 |
+
from collections import OrderedDict
|
| 18 |
+
from typing import Mapping
|
| 19 |
+
|
| 20 |
+
from ....configuration_utils import PretrainedConfig
|
| 21 |
+
from ....onnx import OnnxConfig
|
| 22 |
+
from ....utils import logging
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
logger = logging.get_logger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class MegaConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MegaModel`]. It is used to instantiate a Mega
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the Mega
    [mnaylor/mega-base-wikitext](https://huggingface.co/mnaylor/mega-base-wikitext) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 30522):
            Vocabulary size of the Mega model. Defines the number of different tokens that can be represented by the
            `input_ids` passed when calling [`MegaModel`].
        hidden_size (`int`, *optional*, defaults to 128):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 4):
            Number of hidden layers in the Mega encoder.
        intermediate_size (`int`, *optional*, defaults to 256):
            Dimensionality of the hidden size (self-attention value projection) within the Mega encoder
        ema_projection_size (`int`, *optional*, defaults to 16):
            Dimensionality of the MegaMultiDimensionDampedEma
        bidirectional (`bool`, *optional*, defaults to `True`):
            Whether the MegaMultiDimensionDampedEma used in Mega's self-attention should work bidirectionally (`True`)
            or unidirectionally (`False`). Bidirectional EMA is incompatible with causal decoding, so this should be
            False if you intend to use the model as a decoder.
        shared_representation_size (`int`, *optional*, defaults to 64):
            Dimensionality of the linear projection for shared representation of self-attention queries and keys
        use_chunking (`bool`, *optional*, defaults to `False`):
            Whether to chunk inputs for linear self-attention complexity (described as Mega-chunk in the paper)
        chunk_size (`int`, *optional*, defaults to -1):
            If `use_chunking` is set to `True`, determines the size of the chunks to apply to the input sequence. If
            chunking is used, input sequences must be padded to a multiple of `chunk_size`
        truncation (`int`, *optional*):
            If specified, the sequence length for which to truncate MegaMultiDimensionDampedEma
        normalize_before_mega (`bool`, *optional*, defaults to `True`):
            Whether to normalize before (`True`) or after (`False`) passing through Mega encoder blocks
        normalization_type (`str`, *optional*, defaults to `"scalenorm"`):
            Type of normalization to use in Mega encoder blocks. Choose one of `"scalenorm"`, `"layernorm"`,
            `"rmsnorm"`, `"batchnorm"`, or `"syncbatchnorm"` (GPU required for syncbatchnorm)
        norm_affine (`bool`, *optional*, defaults to `True`):
            If `True`, applies a parameterized affine transformation to inputs during normalization
        activation (`str`, *optional*, defaults to `"silu"`):
            Activation function to apply within Mega encoder blocks. Choose one of `"silu"`, `"relu"`, `"linear"`,
            `"gelu"`, or `"gelu_accurate"`
        attention_activation (`str`, *optional*, defaults to `"softmax"`):
            Activation function to apply for single-headed self-attention (a la Transformer). Choose one of
            `"softmax"`, `"laplace"`, or `"relu2"`
        dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability for EMA self-attention
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        use_feature_dropout (`bool`, *optional*, defaults to `False`):
            Whether to use feature-based (`True`) or standard dropout (`False`)
        use_normalized_ffn (`bool`, *optional*, defaults to `True`):
            Whether to use the normalized feed-forward sub-layer in Mega blocks (`True`) or pass Mega encoder output
            as-is (`False`)
        nffn_hidden_size (`int`, *optional*, defaults to 256):
            If using the normalized feed-forward network (NFFN) layer within Mega (`use_normalized_ffn = True`), this
            is the hidden size of the NFFN
        normalize_before_ffn (`bool`, *optional*, defaults to `True`):
            Whether to normalize before (`True`) or after (`False`) the feed-forward portion of NFFN
        nffn_activation_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the NFFN component.
        max_positions (`int`, *optional*, defaults to 2048):
            The maximum sequence length to use for positional representations. For `"simple"` relative positional bias,
            this is a hard limit on input length; `"rotary"` relative positional bias will extrapolate to longer
            sequences
        add_token_type_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to account for token types in embeddings. Left as optional to maintain compatibility with original
            implementation while adding support for token types.
        type_vocab_size (`int`, *optional*, defaults to 2):
            The vocabulary size of the `token_type_ids` passed when calling [`MegaModel`]. Only used if
            `add_token_type_embeddings = True`
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        ema_delta_alpha_range (`float`, *optional*, defaults to 0.2):
            The standard deviation for initializing the delta (damping factor) and alpha (decay factor) parameters in
            MegaMultiDimensionDampedEma.
        ema_beta_range (`float`, *optional*, defaults to 0.02):
            The standard deviation for initializing the beta parameter (expansion matrix) in
            MegaMultiDimensionDampedEma.
        ema_gamma_omega_range (`float`, *optional*, defaults to 1.0):
            The standard deviation for initializing the gamma (projection matrix) and omega (residual weight)
            parameters in MegaMultiDimensionDampedEma.
        pad_token_id (`int`, *optional*, defaults to 1):
            Id of the padding token. Forwarded to [`PretrainedConfig`].
        bos_token_id (`int`, *optional*, defaults to 0):
            Id of the beginning-of-sequence token. Forwarded to [`PretrainedConfig`].
        eos_token_id (`int`, *optional*, defaults to 2):
            Id of the end-of-sequence token. Forwarded to [`PretrainedConfig`].
        relative_positional_bias (`str`, *optional*, defaults to `"rotary"`):
            Type of relative positional encoding. Choose one of `"rotary"` or `"simple"`. If `"simple"` is selected,
            `max_positions` is used as a limit on input size, while `"rotary"` extrapolates beyond `max_positions`.
        is_decoder (`bool`, *optional*, defaults to `False`):
            Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        classifier_dropout (`float`, *optional*):
            The dropout ratio for the classification head.
        add_lm_hidden_dense_layer (`bool`, *optional*, defaults to `True`):
            Whether to include a hidden layer for projection between encoder outputs and LM heads (`True`) or pass
            hidden states directly to LM head (`False`). Remains optional for compatibility with original
            implementation

    Examples:

    ```python
    >>> from transformers import MegaConfig, MegaModel

    >>> # Initializing a Mega configuration
    >>> configuration = MegaConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = MegaModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "mega"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=128,
        num_hidden_layers=4,
        intermediate_size=256,
        ema_projection_size=16,
        bidirectional=True,
        shared_representation_size=64,
        use_chunking=False,
        chunk_size=-1,
        truncation=None,
        normalize_before_mega=True,
        normalization_type="scalenorm",
        norm_affine=True,
        activation="silu",
        attention_activation="softmax",
        dropout_prob=0.1,
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        use_feature_dropout=False,
        use_normalized_ffn=True,
        nffn_hidden_size=256,
        normalize_before_ffn=True,
        nffn_activation_dropout_prob=0.1,
        max_positions=2048,
        add_token_type_embeddings=False,
        type_vocab_size=2,
        initializer_range=0.02,
        ema_delta_alpha_range=0.2,
        ema_beta_range=0.02,
        ema_gamma_omega_range=1.0,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        relative_positional_bias="rotary",
        classifier_dropout=None,
        use_cache=True,
        add_lm_hidden_dense_layer=True,
        **kwargs,
    ):
        # Special-token ids and any extra keyword arguments are handled by the base configuration class.
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.activation = activation
        self.attention_activation = attention_activation
        self.intermediate_size = intermediate_size
        self.ema_projection_size = ema_projection_size
        self.bidirectional = bidirectional
        self.shared_representation_size = shared_representation_size
        self.use_chunking = use_chunking
        self.chunk_size = chunk_size
        self.truncation = truncation
        self.normalize_before_mega = normalize_before_mega
        self.normalization_type = normalization_type
        self.norm_affine = norm_affine
        self.dropout_prob = dropout_prob
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.use_feature_dropout = use_feature_dropout
        self.use_normalized_ffn = use_normalized_ffn
        self.nffn_hidden_size = nffn_hidden_size
        self.normalize_before_ffn = normalize_before_ffn
        self.nffn_activation_dropout_prob = nffn_activation_dropout_prob
        self.max_positions = max_positions
        self.add_token_type_embeddings = add_token_type_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.ema_delta_alpha_range = ema_delta_alpha_range
        self.ema_beta_range = ema_beta_range
        self.ema_gamma_omega_range = ema_gamma_omega_range
        self.relative_positional_bias = relative_positional_bias
        self.use_cache = use_cache
        self.classifier_dropout = classifier_dropout
        self.add_lm_hidden_dense_layer = add_lm_hidden_dense_layer
        self.num_attention_heads = 1  # not used but required by Hugging Face
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
class MegaOnnxConfig(OnnxConfig):
    """ONNX configuration for Mega: declares the model inputs and their dynamic axes."""

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """
        Map each input name (`input_ids`, `attention_mask`) to its dynamic axes.

        The `"multiple-choice"` task has an extra `choice` axis between `batch` and `sequence`; every other task
        uses `batch` and `sequence` only. Both inputs share the same axes mapping.
        """
        axes = {0: "batch", 1: "sequence"}
        if self.task == "multiple-choice":
            axes = {0: "batch", 1: "choice", 2: "sequence"}
        return OrderedDict((input_name, axes) for input_name in ("input_ids", "attention_mask"))
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
__all__ = ["MegaConfig", "MegaOnnxConfig"]
|
docs/transformers/build/lib/transformers/models/deprecated/mega/convert_mega_original_pytorch_checkpoint_to_pytorch.py
ADDED
|
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""
|
| 17 |
+
Convert Mega pretrained checkpoint. Built to convert the Masked LM checkpoint located at
|
| 18 |
+
https://huggingface.co/mnaylor/mega-wikitext-103
|
| 19 |
+
|
| 20 |
+
Requirements:
|
| 21 |
+
- clone the Mega repo and install fairseq from there
|
| 22 |
+
1. git clone https://github.com/facebookresearch/mega.git
|
| 23 |
+
2. cd mega && pip install -e .
|
| 24 |
+
- clone the pretrained weights for the original implementation from the hugging face repo
|
| 25 |
+
* use this location as the path for pretrained weights
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
import argparse
|
| 29 |
+
|
| 30 |
+
# utilities to import the model weights and config file
|
| 31 |
+
import os
|
| 32 |
+
import pickle as pkl
|
| 33 |
+
|
| 34 |
+
# PyTorch + new model classes
|
| 35 |
+
import torch
|
| 36 |
+
from torch import nn
|
| 37 |
+
|
| 38 |
+
from transformers import AutoTokenizer, MegaConfig, MegaForMaskedLM
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# import the EncoderLayer class used to pretrain
|
| 42 |
+
# !! NOTE !! this requires the version of fairseq that is built when you install the Mega source
|
| 43 |
+
try:
|
| 44 |
+
from fairseq.modules.mega_layer import MegaEncoderLayer
|
| 45 |
+
except ImportError:
|
| 46 |
+
raise ImportError("You need to install the version of fairseq from the Mega repo!")
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# define the wrapper classes used to train the MLM (see colab notebook below)
|
| 50 |
+
# https://colab.research.google.com/drive/1qfUO6o5HRdxBblWlw058HVyvaEPhPpH8?usp=sharing
|
| 51 |
+
# MegaLM outputs hidden states
|
| 52 |
+
class MegaLM(nn.Module):
    """
    The base class for our Mega encoder - given input IDs, embed text and return encoder output.

    Args:
        mega_args: Argument object for the original Mega implementation. `encoder_embed_dim` is read here for the
            embedding size, and the whole object is passed to every `MegaEncoderLayer`.
        depth: Number of `MegaEncoderLayer` blocks to stack.
        vocab_size: Number of rows in the token embedding table.
    """

    def __init__(self, mega_args, depth, vocab_size):
        super().__init__()
        self.mega_args = mega_args
        self.embedding_layer = nn.Embedding(vocab_size, self.mega_args.encoder_embed_dim)
        self.encoders = nn.ModuleList([MegaEncoderLayer(self.mega_args) for _ in range(depth)])
        self.depth = depth

    def forward(self, input_ids, attention_mask, batch_first=True, ignore_mask_value=0):
        """
        Code for a forward pass - expects input_ids and attention_mask to come from a Hugging Face tokenizer as PyTorch
        tensors, and returns the encoder hidden states (not classification logits).

        Other options:
          - batch_first: boolean indicating whether the batch dimension is first in input_ids (default: True, which
            aligns with the HF tokenizer behavior)
          - ignore_mask_value: the value in attention_mask that identifies tokens that should be ignored (default: 0,
            which aligns with HF tokenizer)

        Returns:
            A tensor of shape (batch, time, embedding size) if `batch_first`, otherwise (time, batch, embedding size).
        """

        # Mega expects embeddings to be (time, batch, embedding size), but
        # Hugging Face returns tokens as (batch, time)
        if batch_first:
            input_ids = input_ids.T

        # to make things more confusing, Mega expects the attention mask to
        # be (batch, time), but with values of 0 (normal token) and 1 (ignore token)
        # which is the opposite of what HF returns
        if ignore_mask_value == 0:
            attention_mask = 1 - attention_mask

        # get token embeddings from IDs
        hidden_states = self.embedding_layer(input_ids)

        # pass through the Mega layers
        # input is (time, batch, encoder dim) and output is the same
        for encoder in self.encoders:
            hidden_states = encoder(hidden_states, attention_mask)

        # return according to the shape specified
        if not batch_first:
            return hidden_states
        # (T, B, H) --> (B, T, H)
        return torch.transpose(hidden_states, 0, 1)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# renamed from MegaForMaskedLM to avoid confusion with new module
|
| 102 |
+
class OriginalMegaForMaskedLM(nn.Module):
    """A wrapper class for doing masked language modeling with Mega"""

    def __init__(self, mega_args, depth, vocab_size):
        super().__init__()
        embed_dim = mega_args.encoder_embed_dim
        # Submodule names and registration order are kept as-is so the state-dict keys stay the same.
        self.mega = MegaLM(mega_args, depth, vocab_size)
        self.mlm_head = nn.Linear(embed_dim, vocab_size)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, input_ids, attention_mask, batch_first=True, ignore_mask_value=0):
        """
        Run the Mega encoder, apply dropout, and project onto the vocabulary with the masked LM head.

        With `batch_first` (the default, matching Hugging Face tokenizer output) the logits have shape
        (Batch size, Sequence length, Vocab size); otherwise (S, B, V).
        """
        hidden_states = self.mega(input_ids, attention_mask, batch_first, ignore_mask_value)
        dropped = self.dropout(hidden_states)
        logits = self.mlm_head(dropped)
        return logits
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
# code to convert the checkpoint located in the user-specified location
|
| 124 |
+
def _rename_mega_state_dict_key(module_name):
    """
    Map one parameter name from the original Mega encoder to its Hugging Face name.

    Returns the new name, or `None` when the key is kept as-is.
    """
    # gamma, beta, and alpha are used by several modules in the original repository, so the enclosing module
    # name decides the replacement. Names containing "beta" or "gamma" are not safe in the Hugging Face
    # ecosystem (they are rewritten on load because of flax/tf weights), so they are always renamed.
    if "beta" in module_name:
        # beta is used in EMA, MovingAverageGatedAttention, and RotaryRelativePositionalBias;
        # EMA sub-layers were always called "move" in the original repo
        if "move.beta" in module_name:
            return module_name.replace("move.beta", "ema_gate.ema_expansion_matrix")
        if "mega_layer.beta" in module_name:
            return module_name.replace("beta", "qk_bias")
        return module_name.replace("beta", "b_param")
    if "gamma" in module_name:
        # gamma is used in EMA and MovingAverageGatedAttention (plus any remaining module, handled last)
        if "move.gamma" in module_name:
            return module_name.replace("move.gamma", "ema_gate.kernel_projection_matrix")
        if "mega_layer.gamma" in module_name:
            return module_name.replace("gamma", "qk_weight")
        return module_name.replace("gamma", "g_param")
    # alpha, delta and omega are only renamed inside the EMA sub-layer, purely for readability
    if "move.alpha" in module_name:
        return module_name.replace("move.alpha", "ema_gate.decay_factor")
    if "move.delta" in module_name:
        return module_name.replace("move.delta", "ema_gate.damping_factor")
    if "move.omega" in module_name:
        return module_name.replace("move.omega", "ema_gate.residual_weight")
    return None


# code to convert the checkpoint located in the user-specified location
def convert_checkpoint_to_huggingface(pretrained_checkpoint_path, output_path, includes_tokenizer):
    """
    Convert an original Mega masked-LM checkpoint into a Hugging Face `MegaForMaskedLM` checkpoint.

    The original model is rebuilt from `model_args.pkl`, `encoder_weights.pt` and `mlm_head_weights.pt`, its
    weights are copied (with renamed keys) into a freshly configured Hugging Face model, and both models are run
    on the same random batch. The converted model is saved only if the outputs agree.

    Args:
        pretrained_checkpoint_path: Directory holding the three checkpoint files listed above (and the tokenizer
            files when `includes_tokenizer` is set).
        output_path: Directory passed to `save_pretrained` for the model (and tokenizer).
        includes_tokenizer: When truthy, the tokenizer found in `pretrained_checkpoint_path` is re-saved to
            `output_path`.

    Raises:
        RuntimeError: If the original and converted outputs are not close (`atol=1e-3`).
    """
    # NOTE(security): unpickling can execute arbitrary code; only run this on checkpoints you trust.
    with open(os.path.join(pretrained_checkpoint_path, "model_args.pkl"), "rb") as f:
        mega_original_args = pkl.load(f)

    # load the original encoder
    original_mlm = OriginalMegaForMaskedLM(**mega_original_args).eval()

    # load its weights
    encoder_weights = torch.load(
        os.path.join(pretrained_checkpoint_path, "encoder_weights.pt"), map_location="cpu", weights_only=True
    )
    print("Original Mega encoder:", original_mlm.mega.load_state_dict(encoder_weights))

    # the MLM head weights are needed by both models, so read them from disk once
    mlm_head_weights = torch.load(
        os.path.join(pretrained_checkpoint_path, "mlm_head_weights.pt"), map_location="cpu", weights_only=True
    )
    print("Original Mega MLM layer:", original_mlm.mlm_head.load_state_dict(mlm_head_weights))

    # create a new config from the old one
    mega_args = mega_original_args["mega_args"]
    hf_config = MegaConfig(
        num_hidden_layers=mega_original_args["depth"],
        vocab_size=mega_original_args["vocab_size"],
        hidden_size=mega_args.encoder_embed_dim,
        shared_representation_size=mega_args.encoder_z_dim,
        intermediate_size=mega_args.encoder_hidden_dim,
        ema_projection_size=mega_args.encoder_n_dim,
        dropout_prob=mega_args.dropout,
        attention_probs_dropout_prob=mega_args.attention_dropout,
        hidden_dropout_prob=mega_args.hidden_dropout,
        activation=mega_args.activation_fn,
        attention_activation=mega_args.attention_activation_fn,
        bidirectional=mega_args.bidirectional,
        use_chunking=mega_args.encoder_chunk_size > 0,
        chunk_size=mega_args.encoder_chunk_size,
        truncation=mega_args.truncation_length,
        normalization_type=mega_args.normalization_type,
        normalize_before_mega=True,
        norm_affine=True,
        use_feature_dropout=mega_args.feature_dropout,
        relative_positional_bias=mega_args.rel_pos_bias,
        max_positions=mega_args.max_source_positions,
        nffn_hidden_size=mega_args.encoder_ffn_embed_dim,
        normalize_before_ffn=mega_args.normalize_before,
        # new arguments added for HF implementation
        nffn_activation_dropout_prob=0.0,
        add_token_type_embeddings=False,
        add_lm_hidden_dense_layer=False,
    )

    hf_mlm = MegaForMaskedLM(hf_config).eval()

    # the original checkpoint just uses nn.Embedding for the word embeddings
    # we use a wrapper module for embeddings to add support for positional embeddings
    hf_mlm.mega.embedding_layer.word_embeddings.weight = original_mlm.mega.embedding_layer.weight

    # modify the state dictionary of the original checkpoint to account for naming issues in the Hugging Face
    # ecosystem -- any names containing "beta" or "gamma" aren't safe to use and are renamed upon _load_pretrained,
    # also renaming previously confusing parameter names
    original_state_dict = original_mlm.mega.encoders.state_dict()
    updated_keys = {}
    for module_name in original_state_dict.keys():
        new_module_name = _rename_mega_state_dict_key(module_name)
        if new_module_name:
            updated_keys[module_name] = new_module_name

    if len(updated_keys) != 0:
        print(f"Renaming these keys: {updated_keys.keys()}")
    else:
        print("No need to rename state dict entries")
    # renaming is done in a second pass so the dict is not mutated while it is being iterated
    for old, new in updated_keys.items():
        original_state_dict[new] = original_state_dict.pop(old)

    # now attempt to load the state dictionary with updated names
    # note that we now call it `mega.layers` instead of `mega.encoders` due to hugging face style
    print("HF Mega encoder:", hf_mlm.mega.layers.load_state_dict(original_state_dict))

    # load the MLM head weights directly
    print("HF Mega MLM layer:", hf_mlm.mlm_head.load_state_dict(mlm_head_weights))

    # test on a randomly generated input sequence
    input_ids = torch.randint(0, hf_config.vocab_size, size=(4, 256))
    input_mask = torch.ones_like(input_ids)
    # mask a few tokens to make sure masking is applied appropriately :)
    input_mask[:, -10:] = 0

    # run forward passes; this is a pure comparison, so no autograd graph is needed
    with torch.no_grad():
        original_output = original_mlm(input_ids, input_mask, batch_first=True, ignore_mask_value=0)
        hf_output = hf_mlm(input_ids, input_mask)[0]

    # print shapes and diff
    print(f"original output {original_output.shape}")
    print(f"hf output {hf_output.shape}")
    # use the absolute difference: a signed max would hide large negative deviations
    print(f"max diff: {(original_output - hf_output).abs().max()}")  # 0.0
    success = torch.allclose(original_output, hf_output, atol=1e-3)

    if success:
        print("Yay!")
        hf_mlm.save_pretrained(output_path)
    else:
        raise RuntimeError(f"Something's broken :(\nOriginal:\n{original_output}\n\nHF\n{hf_output}\n{hf_mlm}")

    if includes_tokenizer:
        print("Transferring tokenizer")
        tokenizer = AutoTokenizer.from_pretrained(pretrained_checkpoint_path)
        tokenizer.save_pretrained(output_path)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
# Command-line entry point: parse the three options and run the conversion.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--pretrained_checkpoint_path",
        default=None,
        type=str,
        required=True,
        help="Point to the directory containing your model weights using the official Mega repo",
    )

    parser.add_argument(
        "--output_path", default=None, type=str, required=True, help="Location to save the Hugging Face version"
    )

    # boolean switch: absent means False, present means True
    parser.add_argument(
        "--includes_tokenizer",
        action="store_true",
        help="Use this flag if there is a Hugging Face tokenizer in the original checkpoint repo",
    )

    args = parser.parse_args()

    convert_checkpoint_to_huggingface(args.pretrained_checkpoint_path, args.output_path, args.includes_tokenizer)
|
docs/transformers/build/lib/transformers/models/deprecated/mega/modeling_mega.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
docs/transformers/build/lib/transformers/models/deprecated/mmbt/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ....utils import _LazyModule
|
| 17 |
+
from ....utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
    # Static type checkers see the real symbols; this branch never runs at import time.
    from .configuration_mmbt import *
    from .modeling_mmbt import *
else:
    import sys

    # At runtime the package object is swapped for a `_LazyModule` built from this file's import structure.
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
docs/transformers/build/lib/transformers/models/deprecated/mmbt/configuration_mmbt.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
# Copyright (c) HuggingFace Inc. team.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""MMBT configuration"""
|
| 17 |
+
|
| 18 |
+
from ....utils import logging
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
logger = logging.get_logger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class MMBTConfig:
    """
    This is the configuration class to store the configuration of a [`MMBTModel`]. It is used to instantiate a MMBT
    model according to the specified arguments, defining the model architecture.

    Note that the attribute dictionary of `config` is adopted by reference rather than copied: attributes set on
    the resulting `MMBTConfig` (including `modal_hidden_size` and `num_labels`) are also visible on `config`.

    Args:
        config ([`PreTrainedConfig`]):
            Config of the underlying Transformer models. Its attributes are shared with this object so that a
            single config can be used.
        num_labels (`int`, *optional*):
            Size of final Linear layer for classification. Only applied when truthy; otherwise any `num_labels`
            already present on `config` is left untouched.
        modal_hidden_size (`int`, *optional*, defaults to 2048):
            Embedding dimension of the non-text modality encoder.
    """

    def __init__(self, config, num_labels=None, modal_hidden_size=2048):
        # Alias (not copy) the wrapped config's attributes; later writes go to the same dict.
        self.__dict__ = config.__dict__
        self.modal_hidden_size = modal_hidden_size
        if num_labels:
            self.num_labels = num_labels
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
__all__ = ["MMBTConfig"]
|
docs/transformers/build/lib/transformers/models/deprecated/mmbt/modeling_mmbt.py
ADDED
|
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
# Copyright (c) HuggingFace Inc. team.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""PyTorch MMBT model."""
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
from torch import nn
|
| 20 |
+
from torch.nn import CrossEntropyLoss, MSELoss
|
| 21 |
+
|
| 22 |
+
from ....modeling_outputs import BaseModelOutputWithPooling, SequenceClassifierOutput
|
| 23 |
+
from ....modeling_utils import ModuleUtilsMixin
|
| 24 |
+
from ....utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
logger = logging.get_logger(__name__)
|
| 28 |
+
|
| 29 |
+
_CONFIG_FOR_DOC = "MMBTConfig"
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class ModalEmbeddings(nn.Module):
    """Generic Modal Embeddings which takes in an encoder, and a transformer embedding."""

    def __init__(self, config, encoder, embeddings):
        """
        Args:
            config: Object providing `modal_hidden_size`, `hidden_size` and `hidden_dropout_prob`.
            encoder: Callable applied to the modal input before projection.
            embeddings: Embedding module of the text transformer; its `position_embeddings`,
                `token_type_embeddings`, `word_embeddings` and `LayerNorm` sub-modules are reused (not copied).
        """
        super().__init__()
        self.config = config
        self.encoder = encoder
        # projects encoder features into the transformer's hidden size
        self.proj_embeddings = nn.Linear(config.modal_hidden_size, config.hidden_size)
        self.position_embeddings = embeddings.position_embeddings
        self.token_type_embeddings = embeddings.token_type_embeddings
        self.word_embeddings = embeddings.word_embeddings
        self.LayerNorm = embeddings.LayerNorm
        self.dropout = nn.Dropout(p=config.hidden_dropout_prob)

    def forward(self, input_modal, start_token=None, end_token=None, position_ids=None, token_type_ids=None):
        """
        Encode `input_modal`, optionally wrap it with start/end token embeddings, and add position and token-type
        embeddings followed by LayerNorm and dropout.

        When `position_ids` is omitted, positions `0..seq_length-1` are used for every batch row; when
        `token_type_ids` is omitted, all positions get token type 0.
        """
        token_embeddings = self.proj_embeddings(self.encoder(input_modal))
        seq_length = token_embeddings.size(1)
        batch_size = input_modal.size(0)
        device = input_modal.device

        if start_token is not None:
            # prepend one embedded token per batch row
            start_token_embeds = self.word_embeddings(start_token)
            seq_length += 1
            token_embeddings = torch.cat([start_token_embeds.unsqueeze(1), token_embeddings], dim=1)

        if end_token is not None:
            # append one embedded token per batch row
            end_token_embeds = self.word_embeddings(end_token)
            seq_length += 1
            token_embeddings = torch.cat([token_embeddings, end_token_embeds.unsqueeze(1)], dim=1)

        if position_ids is None:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).expand(batch_size, seq_length)

        if token_type_ids is None:
            token_type_ids = torch.zeros((batch_size, seq_length), dtype=torch.long, device=device)

        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = token_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
MMBT_START_DOCSTRING = r"""
|
| 78 |
+
MMBT model was proposed in [Supervised Multimodal Bitransformers for Classifying Images and
|
| 79 |
+
Text](https://github.com/facebookresearch/mmbt) by Douwe Kiela, Suvrat Bhooshan, Hamed Firooz, Davide Testuggine.
|
| 80 |
+
It's a supervised multimodal bitransformer model that fuses information from text and other image encoders, and
|
| 81 |
+
obtains state-of-the-art performance on various multimodal classification benchmark tasks.
|
| 82 |
+
|
| 83 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 84 |
+
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 85 |
+
etc.)
|
| 86 |
+
|
| 87 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 88 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 89 |
+
and behavior.
|
| 90 |
+
|
| 91 |
+
Parameters:
|
| 92 |
+
config ([`MMBTConfig`]): Model configuration class with all the parameters of the model.
|
| 93 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 94 |
+
configuration.
|
| 95 |
+
transformer (`nn.Module`): A text transformer that is used by MMBT.
|
| 96 |
+
It should have embeddings, encoder, and pooler attributes.
|
| 97 |
+
encoder (`nn.Module`): Encoder for the second modality.
|
| 98 |
+
It should take in a batch of modal inputs and return k, n dimension embeddings.
|
| 99 |
+
"""
|
| 100 |
+
|
| 101 |
+
MMBT_INPUTS_DOCSTRING = r"""
|
| 102 |
+
Args:
|
| 103 |
+
input_modal (`torch.FloatTensor` of shape `(batch_size, ***)`):
|
| 104 |
+
The other modality data. It will be the shape that the encoder for that type expects. e.g. With an Image
|
| 105 |
+
Encoder, the shape would be (batch_size, channels, height, width)
|
| 106 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 107 |
+
Indices of input sequence tokens in the vocabulary. It does not expect [CLS] token to be added as it's
|
| 108 |
+
appended to the end of other modality embeddings. Indices can be obtained using [`AutoTokenizer`]. See
|
| 109 |
+
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
|
| 110 |
+
|
| 111 |
+
[What are input IDs?](../glossary#input-ids)
|
| 112 |
+
modal_start_tokens (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 113 |
+
Optional start token to be added to Other Modality Embedding. [CLS] Most commonly used for classification
|
| 114 |
+
tasks.
|
| 115 |
+
modal_end_tokens (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 116 |
+
Optional end token to be added to Other Modality Embedding. [SEP] Most commonly used.
|
| 117 |
+
attention_mask (*optional*) `torch.FloatTensor` of shape `(batch_size, sequence_length)`:
|
| 118 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 119 |
+
|
| 120 |
+
- 1 for tokens that are **not masked**,
|
| 121 |
+
- 0 for tokens that are **masked**.
|
| 122 |
+
|
| 123 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 124 |
+
token_type_ids (*optional*) `torch.LongTensor` of shape `(batch_size, sequence_length)`:
|
| 125 |
+
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
|
| 126 |
+
1]`:
|
| 127 |
+
|
| 128 |
+
- 0 corresponds to a *sentence A* token,
|
| 129 |
+
- 1 corresponds to a *sentence B* token.
|
| 130 |
+
|
| 131 |
+
[What are token type IDs?](../glossary#token-type-ids)
|
| 132 |
+
modal_token_type_ids (*optional*) `torch.LongTensor` of shape `(batch_size, modal_sequence_length)`:
|
| 133 |
+
Segment token indices to indicate different portions of the non-text modality. The embeddings from these
|
| 134 |
+
tokens will be summed with the respective token embeddings for the non-text modality.
|
| 135 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 136 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 137 |
+
config.max_position_embeddings - 1]`.
|
| 138 |
+
|
| 139 |
+
[What are position IDs?](../glossary#position-ids)
|
| 140 |
+
modal_position_ids (`torch.LongTensor` of shape `(batch_size, modal_sequence_length)`, *optional*):
|
| 141 |
+
Indices of positions of each input sequence tokens in the position embeddings for the non-text modality.
|
| 142 |
+
Selected in the range `[0, config.max_position_embeddings - 1]`.
|
| 143 |
+
|
| 144 |
+
[What are position IDs?](../glossary#position-ids)
|
| 145 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
| 146 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
| 147 |
+
|
| 148 |
+
- 1 indicates the head is **not masked**,
|
| 149 |
+
- 0 indicates the head is **masked**.
|
| 150 |
+
|
| 151 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, embedding_dim)`, *optional*):
|
| 152 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 153 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
| 154 |
+
model's internal embedding lookup matrix.
|
| 155 |
+
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 156 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
| 157 |
+
the model is configured as a decoder.
|
| 158 |
+
encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 159 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
| 160 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
|
| 161 |
+
|
| 162 |
+
- 1 for tokens that are **not masked**,
|
| 163 |
+
- 0 for tokens that are **masked**.
|
| 164 |
+
|
| 165 |
+
output_attentions (`bool`, *optional*):
|
| 166 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 167 |
+
tensors for more detail.
|
| 168 |
+
output_hidden_states (`bool`, *optional*):
|
| 169 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 170 |
+
more detail.
|
| 171 |
+
return_dict (`bool`, *optional*):
|
| 172 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 173 |
+
"""
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
@add_start_docstrings(
    "The bare MMBT Model outputting raw hidden-states without any specific head on top.",
    MMBT_START_DOCSTRING,
)
class MMBTModel(nn.Module, ModuleUtilsMixin):
    # No class docstring on purpose: `add_start_docstrings` builds it from the strings above.

    def __init__(self, config, transformer, encoder):
        super().__init__()
        self.config = config
        self.transformer = transformer
        # the modal branch reuses the text transformer's embedding sub-modules
        self.modal_encoder = ModalEmbeddings(config, encoder, transformer.embeddings)

    @add_start_docstrings_to_model_forward(MMBT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_modal,
        input_ids=None,
        modal_start_tokens=None,
        modal_end_tokens=None,
        attention_mask=None,
        token_type_ids=None,
        modal_token_type_ids=None,
        position_ids=None,
        modal_position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Returns:

        Examples:

        ```python
        # For example purposes. Not runnable.
        transformer = BertModel.from_pretrained("google-bert/bert-base-uncased")
        encoder = ImageEncoder(args)
        mmbt = MMBTModel(config, transformer, encoder)
        ```"""
        # explicit arguments win; otherwise fall back to the config defaults
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # exactly one of input_ids / inputs_embeds must describe the text input
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_txt_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_txt_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        modal_embeddings = self.modal_encoder(
            input_modal,
            start_token=modal_start_tokens,
            end_token=modal_end_tokens,
            position_ids=modal_position_ids,
            token_type_ids=modal_token_type_ids,
        )

        input_modal_shape = modal_embeddings.size()[:-1]

        # text tokens default to token type 1 (the modal branch defaults to 0 when no ids are given)
        if token_type_ids is None:
            token_type_ids = torch.ones(input_txt_shape, dtype=torch.long, device=device)

        txt_embeddings = self.transformer.embeddings(
            input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )

        # the modal sequence is placed in front of the text sequence
        embedding_output = torch.cat([modal_embeddings, txt_embeddings], 1)

        input_shape = embedding_output.size()[:-1]

        # a user-supplied mask only covers the text part, so it is left-padded with ones for the modal part
        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        else:
            attention_mask = torch.cat(
                [torch.ones(input_modal_shape, device=device, dtype=torch.long), attention_mask], dim=1
            )
        if encoder_attention_mask is None:
            encoder_attention_mask = torch.ones(input_shape, device=device)
        else:
            encoder_attention_mask = torch.cat(
                [torch.ones(input_modal_shape, device=device), encoder_attention_mask], dim=1
            )

        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
        encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        encoder_outputs = self.transformer.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = encoder_outputs[0]
        pooled_output = self.transformer.pooler(sequence_output)

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def get_input_embeddings(self):
        # This class never defines `self.embeddings`; the word embeddings live on the wrapped transformer.
        return self.transformer.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.transformer.embeddings.word_embeddings = value
        # The modal encoder holds its own reference to the word embeddings (used for the start/end tokens),
        # so it must be updated as well to keep both branches on the same table.
        self.modal_encoder.word_embeddings = value
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
@add_start_docstrings(
    """
    MMBT Model with a sequence classification/regression head on top (a linear layer on top of the pooled output)
    """,
    MMBT_START_DOCSTRING,
    MMBT_INPUTS_DOCSTRING,
)
class MMBTForClassification(nn.Module):
    r"""
    **labels**: (*optional*) `torch.LongTensor` of shape `(batch_size,)`:
        Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
        config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
        `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

    Returns: *Tuple* comprising various elements depending on the configuration (config) and inputs: **loss**:
    (*optional*, returned when `labels` is provided) `torch.FloatTensor` of shape `(1,)`: Classification (or
    regression if config.num_labels==1) loss. **logits**:
        `torch.FloatTensor` of shape `(batch_size, config.num_labels)` Classification (or regression if
        config.num_labels==1) scores (before SoftMax).
    **hidden_states**: (*optional*, returned when `output_hidden_states=True`) list of `torch.FloatTensor` (one for
    the output of each layer + the output of the embeddings) of shape `(batch_size, sequence_length, hidden_size)`:
    Hidden-states of the model at the output of each layer plus the initial embedding outputs. **attentions**:
    (*optional*, returned when `output_attentions=True`) list of `torch.FloatTensor` (one for each layer) of shape
    `(batch_size, num_heads, sequence_length, sequence_length)`: Attentions weights after the attention softmax, used
    to compute the weighted average in the self-attention heads.

    Examples:

    ```python
    # For example purposes. Not runnable.
    transformer = BertModel.from_pretrained("google-bert/bert-base-uncased")
    encoder = ImageEncoder(args)
    model = MMBTForClassification(config, transformer, encoder)
    outputs = model(input_modal, input_ids, labels=labels)
    loss, logits = outputs[:2]
    ```"""

    def __init__(self, config, transformer, encoder):
        super().__init__()
        # `forward` falls back to `self.config.use_return_dict` when `return_dict` is None.
        # This class is a plain `nn.Module`, so the config has to be stored explicitly;
        # without this assignment the default call path raises AttributeError.
        self.config = config
        self.num_labels = config.num_labels

        self.mmbt = MMBTModel(config, transformer, encoder)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(
        self,
        input_modal,
        input_ids=None,
        modal_start_tokens=None,
        modal_end_tokens=None,
        attention_mask=None,
        token_type_ids=None,
        modal_token_type_ids=None,
        position_ids=None,
        modal_position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        return_dict=None,
    ):
        # An explicit `return_dict` wins; otherwise defer to the stored configuration.
        # NOTE(review): assumes the config object exposes `use_return_dict` - confirm for the config type used.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.mmbt(
            input_modal=input_modal,
            input_ids=input_ids,
            modal_start_tokens=modal_start_tokens,
            modal_end_tokens=modal_end_tokens,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            modal_token_type_ids=modal_token_type_ids,
            position_ids=position_ids,
            modal_position_ids=modal_position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            return_dict=return_dict,
        )

        # Index 1 of the backbone output is the pooled representation in both tuple and dict form.
        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            # Tuple form: (loss?, logits, *extra backbone outputs after the pooled output).
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
__all__ = ["MMBTForClassification", "MMBTModel", "ModalEmbeddings"]
|
docs/transformers/build/lib/transformers/models/deprecated/nat/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ....utils import _LazyModule
|
| 17 |
+
from ....utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_nat import *
|
| 22 |
+
from .modeling_nat import *
|
| 23 |
+
else:
|
| 24 |
+
import sys
|
| 25 |
+
|
| 26 |
+
_file = globals()["__file__"]
|
| 27 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
docs/transformers/build/lib/transformers/models/deprecated/nat/configuration_nat.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Neighborhood Attention Transformer model configuration"""
|
| 16 |
+
|
| 17 |
+
from ....configuration_utils import PretrainedConfig
|
| 18 |
+
from ....utils import logging
|
| 19 |
+
from ....utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
logger = logging.get_logger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class NatConfig(BackboneConfigMixin, PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`NatModel`]. It is used to instantiate a Nat model
    according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the Nat
    [shi-labs/nat-mini-in1k-224](https://huggingface.co/shi-labs/nat-mini-in1k-224) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        patch_size (`int`, *optional*, defaults to 4):
            The size (resolution) of each patch. NOTE: Only patch size of 4 is supported at the moment.
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        embed_dim (`int`, *optional*, defaults to 64):
            Dimensionality of patch embedding.
        depths (`List[int]`, *optional*, defaults to `[3, 4, 6, 5]`):
            Number of layers in each level of the encoder.
        num_heads (`List[int]`, *optional*, defaults to `[2, 4, 8, 16]`):
            Number of attention heads in each layer of the Transformer encoder.
        kernel_size (`int`, *optional*, defaults to 7):
            Neighborhood Attention kernel size.
        mlp_ratio (`float`, *optional*, defaults to 3.0):
            Ratio of MLP hidden dimensionality to embedding dimensionality.
        qkv_bias (`bool`, *optional*, defaults to `True`):
            Whether or not a learnable bias should be added to the queries, keys and values.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings and encoder.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        drop_path_rate (`float`, *optional*, defaults to 0.1):
            Stochastic depth rate.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`,
            `"selu"` and `"gelu_new"` are supported.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the layer normalization layers.
        layer_scale_init_value (`float`, *optional*, defaults to 0.0):
            The initial value for the layer scale. Disabled if <=0.
        out_features (`List[str]`, *optional*):
            If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
            (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
            corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
            same order as defined in the `stage_names` attribute.
        out_indices (`List[int]`, *optional*):
            If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
            many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
            If unset and `out_features` is unset, will default to the last stage. Must be in the
            same order as defined in the `stage_names` attribute.

    Example:

    ```python
    >>> from transformers import NatConfig, NatModel

    >>> # Initializing a Nat shi-labs/nat-mini-in1k-224 style configuration
    >>> configuration = NatConfig()

    >>> # Initializing a model (with random weights) from the shi-labs/nat-mini-in1k-224 style configuration
    >>> model = NatModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "nat"

    # Generic attribute names resolved to the Nat-specific ones set in `__init__`.
    attribute_map = {
        "num_attention_heads": "num_heads",
        "num_hidden_layers": "num_layers",
    }

    def __init__(
        self,
        patch_size=4,
        num_channels=3,
        embed_dim=64,
        depths=[3, 4, 6, 5],
        num_heads=[2, 4, 8, 16],
        kernel_size=7,
        mlp_ratio=3.0,
        qkv_bias=True,
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        drop_path_rate=0.1,
        hidden_act="gelu",
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        layer_scale_init_value=0.0,
        out_features=None,
        out_indices=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.patch_size = patch_size
        self.num_channels = num_channels
        self.embed_dim = embed_dim
        # The signature keeps its list defaults for backward compatibility, but the default
        # objects are shared between calls. Store copies so that mutating one config's
        # `depths` / `num_heads` in place cannot leak into every config created afterwards.
        self.depths = list(depths)
        self.num_layers = len(self.depths)
        self.num_heads = list(num_heads)
        self.kernel_size = kernel_size
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.drop_path_rate = drop_path_rate
        self.hidden_act = hidden_act
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
        # we set the hidden_size attribute in order to make Nat work with VisionEncoderDecoderModel
        # this indicates the channel dimension after the last stage of the model
        self.hidden_size = int(embed_dim * 2 ** (len(self.depths) - 1))
        self.layer_scale_init_value = layer_scale_init_value
        # One "stem" entry followed by one named stage per entry in `depths`.
        self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(self.depths) + 1)]
        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
        )
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
__all__ = ["NatConfig"]
|
docs/transformers/build/lib/transformers/models/deprecated/nat/modeling_nat.py
ADDED
|
@@ -0,0 +1,953 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 SHI Labs and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch Neighborhood Attention Transformer model."""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.utils.checkpoint
|
| 23 |
+
from torch import nn
|
| 24 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 25 |
+
|
| 26 |
+
from ....activations import ACT2FN
|
| 27 |
+
from ....modeling_outputs import BackboneOutput
|
| 28 |
+
from ....modeling_utils import PreTrainedModel
|
| 29 |
+
from ....pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
| 30 |
+
from ....utils import (
|
| 31 |
+
ModelOutput,
|
| 32 |
+
OptionalDependencyNotAvailable,
|
| 33 |
+
add_code_sample_docstrings,
|
| 34 |
+
add_start_docstrings,
|
| 35 |
+
add_start_docstrings_to_model_forward,
|
| 36 |
+
is_natten_available,
|
| 37 |
+
logging,
|
| 38 |
+
replace_return_docstrings,
|
| 39 |
+
requires_backends,
|
| 40 |
+
)
|
| 41 |
+
from ....utils.backbone_utils import BackboneMixin
|
| 42 |
+
from .configuration_nat import NatConfig
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
if not is_natten_available():
    # NATTEN is missing: expose same-named stand-ins so that the failure surfaces
    # when one of the kernels is actually called.

    def natten2dqkrpb(*args, **kwargs):
        """Stand-in for `natten.functional.natten2dqkrpb`; always raises `OptionalDependencyNotAvailable`."""
        raise OptionalDependencyNotAvailable()

    def natten2dav(*args, **kwargs):
        """Stand-in for `natten.functional.natten2dav`; always raises `OptionalDependencyNotAvailable`."""
        raise OptionalDependencyNotAvailable()

else:
    from natten.functional import natten2dav, natten2dqkrpb
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
logger = logging.get_logger(__name__)
|
| 57 |
+
|
| 58 |
+
# General docstring
|
| 59 |
+
_CONFIG_FOR_DOC = "NatConfig"
|
| 60 |
+
|
| 61 |
+
# Base docstring
|
| 62 |
+
_CHECKPOINT_FOR_DOC = "shi-labs/nat-mini-in1k-224"
|
| 63 |
+
_EXPECTED_OUTPUT_SHAPE = [1, 7, 7, 512]
|
| 64 |
+
|
| 65 |
+
# Image classification docstring
|
| 66 |
+
_IMAGE_CLASS_CHECKPOINT = "shi-labs/nat-mini-in1k-224"
|
| 67 |
+
_IMAGE_CLASS_EXPECTED_OUTPUT = "tiger cat"
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# drop_path and NatDropPath are from the timm library.
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
@dataclass
class NatEncoderOutput(ModelOutput):
    """
    Output container for the Nat encoder: the final hidden state plus optional per-stage hidden states and
    attentions.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Hidden states produced by the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            One `torch.FloatTensor` for the embedding output followed by one per stage, each of shape
            `(batch_size, sequence_length, hidden_size)`.

            These are the hidden states after each stage together with the initial embedding output.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            One `torch.FloatTensor` per stage, of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Post-softmax attention weights, i.e. the weights used for the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            One `torch.FloatTensor` for the embedding output followed by one per stage, each of shape
            `(batch_size, hidden_size, height, width)`.

            The same hidden states as `hidden_states`, rearranged so that the spatial dimensions are explicit.
    """

    # NOTE(review): the documented shape is three-dimensional while `_EXPECTED_OUTPUT_SHAPE` in this module
    # lists four dimensions - confirm which layout the encoder really emits.
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
@dataclass
class NatModelOutput(ModelOutput):
    """
    Output container for the Nat model: the encoder outputs plus a pooled version of the last hidden state.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Hidden states produced by the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
            The last-layer hidden state after average pooling.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            One `torch.FloatTensor` for the embedding output followed by one per stage, each of shape
            `(batch_size, sequence_length, hidden_size)`.

            These are the hidden states after each stage together with the initial embedding output.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            One `torch.FloatTensor` per stage, of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Post-softmax attention weights, i.e. the weights used for the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            One `torch.FloatTensor` for the embedding output followed by one per stage, each of shape
            `(batch_size, hidden_size, height, width)`.

            The same hidden states as `hidden_states`, rearranged so that the spatial dimensions are explicit.
    """

    # NOTE(review): the documented shape is three-dimensional while `_EXPECTED_OUTPUT_SHAPE` in this module
    # lists four dimensions - confirm which layout the model really emits.
    last_hidden_state: Optional[torch.FloatTensor] = None
    pooler_output: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
@dataclass
class NatImageClassifierOutput(ModelOutput):
    """
    Output container for Nat image classification.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            The classification loss (a regression loss when config.num_labels==1).
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            The classification scores before SoftMax (regression scores when config.num_labels==1).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            One `torch.FloatTensor` for the embedding output followed by one per stage, each of shape
            `(batch_size, sequence_length, hidden_size)`.

            These are the hidden states after each stage together with the initial embedding output.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            One `torch.FloatTensor` per stage, of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Post-softmax attention weights, i.e. the weights used for the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            One `torch.FloatTensor` for the embedding output followed by one per stage, each of shape
            `(batch_size, hidden_size, height, width)`.

            The same hidden states as `hidden_states`, rearranged so that the spatial dimensions are explicit.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
class NatEmbeddings(nn.Module):
    """
    Embed pixel values into patches, then apply layer normalization and dropout.
    """

    def __init__(self, config):
        super().__init__()

        # Submodules are created in the same order as before so parameter registration is unchanged.
        self.patch_embeddings = NatPatchEmbeddings(config)
        self.norm = nn.LayerNorm(config.embed_dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor]:
        """Run `pixel_values` through patch embedding, normalization and dropout, in that order."""
        patches = self.patch_embeddings(pixel_values)
        normalized = self.norm(patches)
        return self.dropout(normalized)
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
class NatPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, height, width, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config):
        """
        Build the two-convolution patch projection.

        Reads `patch_size`, `num_channels` and `embed_dim` from `config`.

        Raises:
            ValueError: If `config.patch_size` is not 4.
        """
        super().__init__()
        patch_size = config.patch_size
        num_channels, hidden_size = config.num_channels, config.embed_dim
        self.num_channels = num_channels

        if patch_size != 4:
            # TODO: Support arbitrary patch sizes.
            raise ValueError("Nat only supports patch size of 4 at the moment.")

        # Two stride-2 3x3 convolutions: each halves the spatial resolution, giving the
        # overall downsampling factor of 4 that corresponds to `patch_size == 4`.
        self.projection = nn.Sequential(
            nn.Conv2d(self.num_channels, hidden_size // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.Conv2d(hidden_size // 2, hidden_size, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
        )

    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> torch.Tensor:
        """
        Project `pixel_values` to patch embeddings and move the channel axis last.

        Raises:
            ValueError: If the channel dimension differs from `config.num_channels` (or if `pixel_values` is
                not 4-dimensional, via the shape unpacking).
        """
        # Four-way unpacking is kept on purpose: it rejects inputs that are not 4-dimensional.
        _, num_channels, _, _ = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        embeddings = self.projection(pixel_values)
        # (batch, channels, height, width) -> (batch, height, width, channels)
        embeddings = embeddings.permute(0, 2, 3, 1)

        return embeddings
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class NatDownsampler(nn.Module):
    """
    Convolutional downsampling layer.

    Applies a bias-free 3x3 convolution with stride 2 that doubles the channel count, then a normalization layer
    over the new channel dimension. Input and output are channels-last tensors.

    Args:
        dim (`int`):
            Number of input channels.
        norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
            Normalization layer class.
    """

    def __init__(self, dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
        super().__init__()
        out_dim = 2 * dim
        self.dim = dim
        self.reduction = nn.Conv2d(dim, out_dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        self.norm = norm_layer(out_dim)

    def forward(self, input_feature: torch.Tensor) -> torch.Tensor:
        # Conv2d expects channels-first, so permute in, convolve, and permute back out.
        channels_first = input_feature.permute(0, 3, 1, 2)
        reduced = self.reduction(channels_first)
        channels_last = reduced.permute(0, 2, 3, 1)
        return self.norm(channels_last)
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.

    Args:
        input: Tensor whose first dimension indexes samples; one keep/drop decision is drawn per sample.
        drop_prob: Probability of zeroing a sample. `None` and `0.0` both disable dropping.
        training: Dropping is only applied when this is `True`.

    Returns:
        `input` itself when dropping is disabled; otherwise a tensor where each sample is either zeroed or scaled
        by `1 / (1 - drop_prob)`. With `drop_prob == 1.0` every sample is zeroed.
    """
    if drop_prob is None or drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize
    if keep_prob == 0.0:
        # The mask is all zeros here; skipping the rescale avoids `inf * 0 = NaN` from dividing by zero.
        return input * random_tensor
    output = input.div(keep_prob) * random_tensor
    return output
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
class NatDropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Args:
        drop_prob: Probability passed to `drop_path`. When left as `None` the module is a no-op.
    """

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # With the default `drop_prob=None`, computing `1 - drop_prob` would raise a TypeError in training mode,
        # so treat "no probability configured" as "never drop".
        if self.drop_prob is None:
            return hidden_states
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return f"p={self.drop_prob}"
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
class NeighborhoodAttention(nn.Module):
    """
    Multi-head neighborhood attention with learnable relative positional biases.

    Args:
        config: Configuration object providing `qkv_bias` and `attention_probs_dropout_prob`.
        dim: Hidden size of the input; must be divisible by `num_heads`.
        num_heads: Number of attention heads.
        kernel_size: Neighborhood size handed to the NATTEN kernels.

    Raises:
        ValueError: If `dim` is not a multiple of `num_heads`.
    """

    def __init__(self, config, dim, num_heads, kernel_size):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(
                f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
            )

        self.num_attention_heads = num_heads
        self.attention_head_size = int(dim / num_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.kernel_size = kernel_size

        # rpb is learnable relative positional biases; same concept is used Swin.
        bias_extent = 2 * self.kernel_size - 1
        self.rpb = nn.Parameter(torch.zeros(num_heads, bias_extent, bias_extent))

        use_bias = config.qkv_bias
        self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=use_bias)
        self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=use_bias)
        self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=use_bias)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        # Split the last axis into (heads, head_size), then move the heads axis to position 1.
        leading_dims = x.size()[:-1]
        split = x.view(*leading_dims, self.num_attention_heads, self.attention_head_size)
        return split.permute(0, 3, 1, 2, 4)

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Scale the query rather than the attention weights: the weights are typically the larger tensor,
        # and the result is the same because a scalar factor commutes with the matrix product.
        scale = math.sqrt(self.attention_head_size)
        query_layer = query_layer / scale

        # Raw neighborhood-attention scores between query and key, with relative positional biases added.
        attention_scores = natten2dqkrpb(query_layer, key_layer, self.rpb, self.kernel_size, 1)

        # Turn scores into probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # Dropout here drops entire tokens to attend to, as in the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = natten2dav(attention_probs, value_layer, self.kernel_size, 1)
        # Move heads back next to head_size and merge the two trailing axes into `all_head_size`.
        context_layer = context_layer.permute(0, 2, 3, 1, 4).contiguous()
        merged_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(merged_shape)

        if output_attentions:
            return (context_layer, attention_probs)
        return (context_layer,)
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
class NeighborhoodAttentionOutput(nn.Module):
    """
    Output projection applied after neighborhood attention: a `dim -> dim` linear layer followed by dropout.

    Args:
        config: Configuration object providing `attention_probs_dropout_prob`.
        dim: Feature size of the projection.
    """

    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # `input_tensor` is accepted for interface compatibility but is not used in this computation.
        projected = self.dense(hidden_states)
        return self.dropout(projected)
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
class NeighborhoodAttentionModule(nn.Module):
    """
    Pairs a `NeighborhoodAttention` block with its `NeighborhoodAttentionOutput` projection and tracks pruned heads.

    Args:
        config: Configuration object forwarded to both submodules.
        dim: Hidden size.
        num_heads: Number of attention heads.
        kernel_size: Neighborhood size for the attention.
    """

    def __init__(self, config, dim, num_heads, kernel_size):
        super().__init__()
        self.self = NeighborhoodAttention(config, dim, num_heads, kernel_size)
        self.output = NeighborhoodAttentionOutput(config, dim)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given heads from the query/key/value projections and the output projection."""
        if len(heads) == 0:
            return

        attention = self.self
        heads, index = find_pruneable_heads_and_indices(
            heads, attention.num_attention_heads, attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        for projection_name in ("query", "key", "value"):
            pruned = prune_linear_layer(getattr(attention, projection_name), index)
            setattr(attention, projection_name, pruned)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        attention.num_attention_heads = attention.num_attention_heads - len(heads)
        attention.all_head_size = attention.attention_head_size * attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        self_outputs = self.self(hidden_states, output_attentions)
        attention_output = self.output(self_outputs[0], hidden_states)
        # Anything after the first element (the attention probabilities, when requested) is passed through.
        return (attention_output,) + self_outputs[1:]
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
class NatIntermediate(nn.Module):
    """
    Expanding half of the feed-forward block: a linear layer to `int(config.mlp_ratio * dim)` features followed by
    the configured activation.

    Args:
        config: Configuration object providing `mlp_ratio` and `hidden_act` (an `ACT2FN` key or a callable).
        dim: Input feature size.
    """

    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
        activation = config.hidden_act
        # A string is looked up in ACT2FN; anything else is used as the activation directly.
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        expanded = self.dense(hidden_states)
        return self.intermediate_act_fn(expanded)
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
class NatOutput(nn.Module):
    """
    Contracting half of the feed-forward block: a linear layer from `int(config.mlp_ratio * dim)` features back to
    `dim`, followed by dropout.

    Args:
        config: Configuration object providing `mlp_ratio` and `hidden_dropout_prob`.
        dim: Output feature size.
    """

    def __init__(self, config, dim):
        super().__init__()
        expanded_dim = int(config.mlp_ratio * dim)
        self.dense = nn.Linear(expanded_dim, dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        contracted = self.dense(hidden_states)
        return self.dropout(contracted)
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
class NatLayer(nn.Module):
    """
    One pre-norm block: neighborhood attention and a feed-forward network, each wrapped in a residual connection
    with optional layer scale and stochastic depth.

    Args:
        config: Configuration object providing `chunk_size_feed_forward`, `kernel_size`, `layer_norm_eps` and
            `layer_scale_init_value`, and forwarded to the submodules.
        dim: Hidden size of the block.
        num_heads: Number of attention heads.
        drop_path_rate: Stochastic depth rate; `nn.Identity` is used when it is not positive.
    """

    def __init__(self, config, dim, num_heads, drop_path_rate=0.0):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.kernel_size = config.kernel_size
        self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.attention = NeighborhoodAttentionModule(config, dim, num_heads, kernel_size=self.kernel_size)
        if drop_path_rate > 0.0:
            self.drop_path = NatDropPath(drop_path_rate)
        else:
            self.drop_path = nn.Identity()
        self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.intermediate = NatIntermediate(config, dim)
        self.output = NatOutput(config, dim)
        # Row 0 scales the attention branch, row 1 the feed-forward branch; disabled when the init value is <= 0.
        if config.layer_scale_init_value > 0:
            self.layer_scale_parameters = nn.Parameter(
                config.layer_scale_init_value * torch.ones((2, dim)), requires_grad=True
            )
        else:
            self.layer_scale_parameters = None

    def maybe_pad(self, hidden_states, height, width):
        """Zero-pad the right/bottom so both spatial sizes reach `kernel_size`; also return the pad tuple used."""
        window_size = self.kernel_size
        if height >= window_size and width >= window_size:
            return hidden_states, (0, 0, 0, 0, 0, 0)

        pad_right = max(0, window_size - width)
        pad_bottom = max(0, window_size - height)
        # `nn.functional.pad` takes pairs starting from the last dimension: channels, then width, then height.
        pad_values = (0, 0, 0, pad_right, 0, pad_bottom)
        return nn.functional.pad(hidden_states, pad_values), pad_values

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        _, height, width, _ = hidden_states.size()
        shortcut = hidden_states

        normed = self.layernorm_before(hidden_states)
        # pad hidden_states if they are smaller than kernel size
        normed, pad_values = self.maybe_pad(normed, height, width)

        attention_outputs = self.attention(normed, output_attentions=output_attentions)
        attention_output = attention_outputs[0]

        # Crop back to the original spatial size if padding was added.
        if pad_values[3] > 0 or pad_values[5] > 0:
            attention_output = attention_output[:, :height, :width, :].contiguous()

        layer_scale = self.layer_scale_parameters
        if layer_scale is not None:
            attention_output = layer_scale[0] * attention_output

        hidden_states = shortcut + self.drop_path(attention_output)

        feed_forward = self.layernorm_after(hidden_states)
        feed_forward = self.output(self.intermediate(feed_forward))

        if layer_scale is not None:
            feed_forward = layer_scale[1] * feed_forward

        layer_output = hidden_states + self.drop_path(feed_forward)

        if output_attentions:
            return (layer_output, attention_outputs[1])
        return (layer_output,)
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
class NatStage(nn.Module):
    """
    A stack of `depth` `NatLayer` blocks at a fixed width, optionally followed by a downsampling module.

    Args:
        config: Configuration object forwarded to every layer.
        dim: Hidden size of the layers in this stage.
        depth: Number of layers.
        num_heads: Number of attention heads per layer.
        drop_path_rate: Indexable collection with one stochastic depth rate per layer.
        downsample: Class called as `downsample(dim=dim, norm_layer=nn.LayerNorm)`, or `None` for no downsampling.
    """

    def __init__(self, config, dim, depth, num_heads, drop_path_rate, downsample):
        super().__init__()
        self.config = config
        self.dim = dim

        blocks = []
        for layer_index in range(depth):
            blocks.append(
                NatLayer(
                    config=config,
                    dim=dim,
                    num_heads=num_heads,
                    drop_path_rate=drop_path_rate[layer_index],
                )
            )
        self.layers = nn.ModuleList(blocks)

        # patch merging layer
        self.downsample = None if downsample is None else downsample(dim=dim, norm_layer=nn.LayerNorm)

        self.pointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # The values are not used; the unpacking is kept so a non-4D input still fails at this point.
        _, _, _, _ = hidden_states.size()

        for layer_module in self.layers:
            layer_outputs = layer_module(hidden_states, output_attentions)
            hidden_states = layer_outputs[0]

        hidden_states_before_downsampling = hidden_states
        if self.downsample is not None:
            hidden_states = self.downsample(hidden_states_before_downsampling)

        stage_outputs = (hidden_states, hidden_states_before_downsampling)
        if not output_attentions:
            return stage_outputs
        # Only the attention outputs of the last layer in the stage are appended.
        return stage_outputs + layer_outputs[1:]
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
class NatEncoder(nn.Module):
    """
    Sequence of `NatStage` levels. The hidden size doubles at each level and every level except the last ends
    with a `NatDownsampler`. Stochastic depth rates increase linearly from 0 to `config.drop_path_rate` across
    all layers.

    Args:
        config: Configuration object providing `depths`, `num_heads`, `embed_dim` and `drop_path_rate`.
    """

    def __init__(self, config):
        super().__init__()
        depths = config.depths
        self.num_levels = len(depths)
        self.config = config

        # One drop-path rate per layer over the whole network, sliced per stage below.
        dpr = torch.linspace(0, config.drop_path_rate, sum(depths), device="cpu").tolist()

        stages = []
        first_layer = 0
        for i_layer in range(self.num_levels):
            depth = depths[i_layer]
            is_last_level = i_layer >= self.num_levels - 1
            stages.append(
                NatStage(
                    config=config,
                    dim=int(config.embed_dim * 2**i_layer),
                    depth=depth,
                    num_heads=config.num_heads[i_layer],
                    drop_path_rate=dpr[first_layer : first_layer + depth],
                    downsample=None if is_last_level else NatDownsampler,
                )
            )
            first_layer += depth
        self.levels = nn.ModuleList(stages)

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        output_hidden_states_before_downsampling: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, NatEncoderOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_reshaped_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if output_hidden_states:
            # The embedding output is recorded first; the reshaped copy is b h w c -> b c h w.
            all_hidden_states += (hidden_states,)
            all_reshaped_hidden_states += (hidden_states.permute(0, 3, 1, 2),)

        for layer_module in self.levels:
            layer_outputs = layer_module(hidden_states, output_attentions)

            hidden_states = layer_outputs[0]
            hidden_states_before_downsampling = layer_outputs[1]

            if output_hidden_states:
                # Record either the pre-downsampling or the post-downsampling state of this level.
                if output_hidden_states_before_downsampling:
                    recorded = hidden_states_before_downsampling
                else:
                    recorded = hidden_states
                all_hidden_states += (recorded,)
                # rearrange b h w c -> b c h w
                all_reshaped_hidden_states += (recorded.permute(0, 3, 1, 2),)

            if output_attentions:
                all_self_attentions += layer_outputs[2:]

        if return_dict:
            return NatEncoderOutput(
                last_hidden_state=hidden_states,
                hidden_states=all_hidden_states,
                attentions=all_self_attentions,
                reshaped_hidden_states=all_reshaped_hidden_states,
            )

        # The tuple form leaves out the reshaped hidden states and any entry that is None.
        return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
|
| 610 |
+
|
| 611 |
+
|
| 612 |
+
class NatPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = NatConfig
    base_model_prefix = "nat"
    main_input_name = "pixel_values"

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.LayerNorm):
            # Layer norms start as the identity transform: zero bias, unit weight.
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return

        if not isinstance(module, (nn.Linear, nn.Conv2d)):
            return

        # Slightly different from the TF version which uses truncated_normal for initialization
        # cf https://github.com/pytorch/pytorch/pull/5617
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if module.bias is not None:
            module.bias.data.zero_()
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
# Shared class-level documentation text; it is passed to the `add_start_docstrings` decorators in this file.
NAT_START_DOCSTRING = r"""
|
| 636 |
+
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
|
| 637 |
+
it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
|
| 638 |
+
behavior.
|
| 639 |
+
|
| 640 |
+
Parameters:
|
| 641 |
+
config ([`NatConfig`]): Model configuration class with all the parameters of the model.
|
| 642 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 643 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 644 |
+
"""
|
| 645 |
+
|
| 646 |
+
|
| 647 |
+
# Shared documentation of the forward arguments; it is passed to the
# `add_start_docstrings_to_model_forward` decorators on the `forward` methods in this file.
NAT_INPUTS_DOCSTRING = r"""
|
| 648 |
+
Args:
|
| 649 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
| 650 |
+
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
|
| 651 |
+
for details.
|
| 652 |
+
|
| 653 |
+
output_attentions (`bool`, *optional*):
|
| 654 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 655 |
+
tensors for more detail.
|
| 656 |
+
output_hidden_states (`bool`, *optional*):
|
| 657 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 658 |
+
more detail.
|
| 659 |
+
return_dict (`bool`, *optional*):
|
| 660 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 661 |
+
"""
|
| 662 |
+
|
| 663 |
+
|
| 664 |
+
@add_start_docstrings(
    "The bare Nat Model transformer outputting raw hidden-states without any specific head on top.",
    NAT_START_DOCSTRING,
)
class NatModel(NatPreTrainedModel):
    # Embeddings -> encoder -> final layer norm, with optional average pooling over all spatial positions.
    # No class docstring is written here because the decorator above composes the class documentation.

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)

        requires_backends(self, ["natten"])

        self.config = config
        self.num_levels = len(config.depths)
        # The hidden size doubles at every level after the first.
        self.num_features = int(config.embed_dim * 2 ** (self.num_levels - 1))

        self.embeddings = NatEmbeddings(config)
        self.encoder = NatEncoder(config)

        self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
        self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel

        `layer_num` indexes the `NatLayer` modules in order across all encoder levels.
        """
        # The encoder stores its blocks as `levels[i].layers[j]`; it has no `layer` attribute, so the previous
        # `self.encoder.layer[...]` lookup raised AttributeError. Flatten the levels to get a per-layer index.
        # NOTE(review): `NeighborhoodAttentionModule.prune_heads` does not resize the `rpb` bias parameter;
        # confirm the attention forward pass still works after pruning.
        all_layers = [layer for level in self.encoder.levels for layer in level.layers]
        for layer, heads in heads_to_prune.items():
            all_layers[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(NAT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=NatModelOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, NatModelOutput]:
        # Arguments left as None fall back to the values stored on the config.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedding_output = self.embeddings(pixel_values)

        encoder_outputs = self.encoder(
            embedding_output,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)

        pooled_output = None
        if self.pooler is not None:
            # Merge dims 1 and 2 into one axis, move it last, and average over it.
            pooled_output = self.pooler(sequence_output.flatten(1, 2).transpose(1, 2))
            pooled_output = torch.flatten(pooled_output, 1)

        if not return_dict:
            output = (sequence_output, pooled_output) + encoder_outputs[1:]

            return output

        return NatModelOutput(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
        )
|
| 751 |
+
|
| 752 |
+
|
| 753 |
+
@add_start_docstrings(
    """
    Nat Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
    the [CLS] token) e.g. for ImageNet.
    """,
    NAT_START_DOCSTRING,
)
class NatForImageClassification(NatPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        requires_backends(self, ["natten"])

        self.num_labels = config.num_labels
        self.nat = NatModel(config)

        # Classifier head: a linear layer over the pooled features, or a pass-through when there are no labels.
        if config.num_labels > 0:
            self.classifier = nn.Linear(self.nat.num_features, config.num_labels)
        else:
            self.classifier = nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(NAT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=NatImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, NatImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.nat(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 1 is the pooled output in both the tuple and the dataclass form of the base model output.
        pooled_output = outputs[1]
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            # Infer the problem type once from the label count and dtype, and cache it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            problem_type = self.config.problem_type
            if problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if return_dict:
            return NatImageClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
                reshaped_hidden_states=outputs.reshaped_hidden_states,
            )

        # Tuple form: the loss is prepended only when it was computed.
        output = (logits,) + outputs[2:]
        if loss is None:
            return output
        return (loss,) + output
|
| 845 |
+
|
| 846 |
+
|
| 847 |
+
@add_start_docstrings(
    "NAT backbone, to be used with frameworks like DETR and MaskFormer.",
    NAT_START_DOCSTRING,
)
class NatBackbone(NatPreTrainedModel, BackboneMixin):
    # Runs the NAT embeddings + encoder and returns, for every stage listed in
    # `self.out_features`, the pre-downsampling hidden state after a per-stage LayerNorm.

    def __init__(self, config):
        super().__init__(config)
        super()._init_backbone(config)

        # Checked right after the base initialisation, before any NAT sub-module is built.
        requires_backends(self, ["natten"])

        self.embeddings = NatEmbeddings(config)
        self.encoder = NatEncoder(config)

        # Width of the embedding output followed by the width of every encoder stage
        # (the width doubles from one stage to the next).
        embed_dim = config.embed_dim
        stage_dims = [int(embed_dim * 2**stage_idx) for stage_idx in range(len(config.depths))]
        self.num_features = [embed_dim] + stage_dims

        # One LayerNorm per requested output stage, keyed by the stage name.
        self.hidden_states_norms = nn.ModuleDict(
            {stage: nn.LayerNorm(num_channels) for stage, num_channels in zip(self.out_features, self.channels)}
        )

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # The patch-embedding module is the model's input embedding layer.
        return self.embeddings.patch_embeddings

    @add_start_docstrings_to_model_forward(NAT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> BackboneOutput:
        """
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("shi-labs/nat-mini-in1k-224")
        >>> model = AutoBackbone.from_pretrained(
        ... "shi-labs/nat-mini-in1k-224", out_features=["stage1", "stage2", "stage3", "stage4"]
        ... )

        >>> inputs = processor(image, return_tensors="pt")

        >>> outputs = model(**inputs)

        >>> feature_maps = outputs.feature_maps
        >>> list(feature_maps[-1].shape)
        [1, 512, 7, 7]
        ```"""
        # Any flag left as None falls back to the value stored on the config.
        if return_dict is None:
            return_dict = self.config.use_return_dict
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if output_attentions is None:
            output_attentions = self.config.output_attentions

        # Hidden states are always requested from the encoder because the feature maps are
        # built from them; the caller's flag only decides whether they are returned.
        encoder_outputs = self.encoder(
            self.embeddings(pixel_values),
            output_attentions=output_attentions,
            output_hidden_states=True,
            output_hidden_states_before_downsampling=True,
            return_dict=True,
        )

        selected_maps = []
        for stage, hidden_state in zip(self.stage_names, encoder_outputs.reshaped_hidden_states):
            if stage not in self.out_features:
                continue
            # LayerNorm normalises the last dimension, so move the channels last, normalise,
            # and move them back to the channels-first layout.
            channels_last = hidden_state.permute(0, 2, 3, 1).contiguous()
            normalized = self.hidden_states_norms[stage](channels_last)
            selected_maps.append(normalized.permute(0, 3, 1, 2).contiguous())
        feature_maps = tuple(selected_maps)

        if not return_dict:
            if output_hidden_states:
                return (feature_maps, encoder_outputs.hidden_states)
            return (feature_maps,)

        return BackboneOutput(
            feature_maps=feature_maps,
            hidden_states=encoder_outputs.hidden_states if output_hidden_states else None,
            attentions=encoder_outputs.attentions,
        )
|
| 951 |
+
|
| 952 |
+
|
| 953 |
+
__all__ = ["NatForImageClassification", "NatModel", "NatPreTrainedModel", "NatBackbone"]
|
docs/transformers/build/lib/transformers/models/deprecated/nezha/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ....utils import _LazyModule
|
| 17 |
+
from ....utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_nezha import *
|
| 22 |
+
from .modeling_nezha import *
|
| 23 |
+
else:
|
| 24 |
+
import sys
|
| 25 |
+
|
| 26 |
+
_file = globals()["__file__"]
|
| 27 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
docs/transformers/build/lib/transformers/models/deprecated/nezha/configuration_nezha.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .... import PretrainedConfig
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class NezhaConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of an [`NezhaModel`]. It is used to instantiate an Nezha
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the Nezha
    [sijunhe/nezha-cn-base](https://huggingface.co/sijunhe/nezha-cn-base) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, optional, defaults to 21128):
            Vocabulary size of the NEZHA model. Defines the different tokens that can be represented by the
            *inputs_ids* passed to the forward method of [`NezhaModel`].
        hidden_size (`int`, optional, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, optional, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, optional, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, optional, defaults to 3072):
            The dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `function`, optional, defaults to "gelu"):
            The non-linear activation function (function or string) in the encoder and pooler.
        hidden_dropout_prob (`float`, optional, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, optional, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (`int`, optional, defaults to 512):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            (e.g., 512 or 1024 or 2048).
        max_relative_position (`int`, optional, defaults to 64):
            The largest relative distance between two positions that gets its own relative position encoding.
            Distances outside `[-max_relative_position, max_relative_position]` are clipped to that range.
        type_vocab_size (`int`, optional, defaults to 2):
            The vocabulary size of the *token_type_ids* passed into [`NezhaModel`].
        initializer_range (`float`, optional, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, optional, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        classifier_dropout (`float`, optional, defaults to 0.1):
            The dropout ratio for attached classifiers.
        pad_token_id (`int`, optional, defaults to 0):
            Id of the padding token. Forwarded to [`PretrainedConfig`].
        bos_token_id (`int`, optional, defaults to 2):
            Id of the beginning-of-sequence token. Forwarded to [`PretrainedConfig`].
        eos_token_id (`int`, optional, defaults to 3):
            Id of the end-of-sequence token. Forwarded to [`PretrainedConfig`].
        use_cache (`bool`, optional, defaults to `True`):
            Stored as `config.use_cache`. Presumably controls whether key/value states are cached and returned
            when the model runs as a decoder; the code that reads it lives outside this class.
        is_decoder (`bool`, *optional*, defaults to `False`):
            Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.

    Example:

    ```python
    >>> from transformers import NezhaConfig, NezhaModel

    >>> # Initializing an Nezha configuration
    >>> configuration = NezhaConfig()

    >>> # Initializing a model (with random weights) from the Nezha-base style configuration model
    >>> model = NezhaModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "nezha"

    def __init__(
        self,
        vocab_size=21128,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        max_relative_position=64,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        classifier_dropout=0.1,
        pad_token_id=0,
        bos_token_id=2,
        eos_token_id=3,
        use_cache=True,
        **kwargs,
    ):
        # The special-token ids and any extra keyword arguments (e.g. `is_decoder`) are handled by the base class.
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

        # Everything else is stored verbatim as an attribute of the same name.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.max_relative_position = max_relative_position
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.classifier_dropout = classifier_dropout
        self.use_cache = use_cache
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
__all__ = ["NezhaConfig"]
|
docs/transformers/build/lib/transformers/models/deprecated/nezha/modeling_nezha.py
ADDED
|
@@ -0,0 +1,1697 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch Nezha model."""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
import os
|
| 19 |
+
import warnings
|
| 20 |
+
from dataclasses import dataclass
|
| 21 |
+
from typing import List, Optional, Tuple, Union
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
import torch.utils.checkpoint
|
| 25 |
+
from torch import nn
|
| 26 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 27 |
+
|
| 28 |
+
from ....activations import ACT2FN
|
| 29 |
+
from ....modeling_outputs import (
|
| 30 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
| 31 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
| 32 |
+
MaskedLMOutput,
|
| 33 |
+
MultipleChoiceModelOutput,
|
| 34 |
+
NextSentencePredictorOutput,
|
| 35 |
+
QuestionAnsweringModelOutput,
|
| 36 |
+
SequenceClassifierOutput,
|
| 37 |
+
TokenClassifierOutput,
|
| 38 |
+
)
|
| 39 |
+
from ....modeling_utils import PreTrainedModel
|
| 40 |
+
from ....pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
| 41 |
+
from ....utils import (
|
| 42 |
+
ModelOutput,
|
| 43 |
+
add_code_sample_docstrings,
|
| 44 |
+
add_start_docstrings,
|
| 45 |
+
add_start_docstrings_to_model_forward,
|
| 46 |
+
logging,
|
| 47 |
+
replace_return_docstrings,
|
| 48 |
+
)
|
| 49 |
+
from .configuration_nezha import NezhaConfig
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
logger = logging.get_logger(__name__)
|
| 53 |
+
|
| 54 |
+
_CHECKPOINT_FOR_DOC = "sijunhe/nezha-cn-base"
|
| 55 |
+
_CONFIG_FOR_DOC = "NezhaConfig"
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def load_tf_weights_in_nezha(model, config, tf_checkpoint_path):
    """Load tf checkpoints in a pytorch model.

    Every variable of the TensorFlow checkpoint at `tf_checkpoint_path` is read, its `/`-separated
    name is walked attribute by attribute starting from `model`, and the value is copied into the
    parameter the walk ends on.

    Args:
        model: PyTorch module whose parameters are overwritten in place.
        config: Unused by this function; kept so the signature stays unchanged.
        tf_checkpoint_path: Path to the TensorFlow checkpoint.

    Returns:
        The same `model` object, with the loaded weights.

    Raises:
        ImportError: If `re`, `numpy` or `tensorflow` cannot be imported (re-raised after logging).
        ValueError: If a checkpoint array and its target parameter have different shapes.
    """
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
    # which are not required for using pretrained model
    optimizer_scopes = ("adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step")

    for name, array in zip(names, arrays):
        name = name.split("/")
        if any(n in optimizer_scopes for n in name):
            logger.info(f"Skipping {'/'.join(name)}")
            continue
        pointer = model
        for m_name in name:
            # "layer_3" style scopes are split into the attribute name and an index.
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
            scope = scope_names[0]
            if scope in ("kernel", "gamma", "output_weights"):
                pointer = getattr(pointer, "weight")
            elif scope in ("output_bias", "beta"):
                pointer = getattr(pointer, "bias")
            elif scope == "squad":
                pointer = getattr(pointer, "classifier")
            else:
                try:
                    pointer = getattr(pointer, scope)
                except AttributeError:
                    logger.info(f"Skipping {'/'.join(name)}")
                    # NOTE: this only skips the current name component; the walk goes on with
                    # the next component from the same `pointer` (behaviour kept as is).
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        if m_name.endswith("_embeddings"):
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            array = np.transpose(array)
        # The former `try/except AssertionError` around this check could never trigger, since the
        # check raises ValueError; the mismatch is now reported directly.
        if pointer.shape != array.shape:
            raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
        logger.info(f"Initialize PyTorch weight {name}")
        pointer.data = torch.from_numpy(array)
    return model
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class NezhaRelativePositionsEncoding(nn.Module):
    """Implement the Functional Relative Position Encoding

    A fixed sinusoidal table with one row per clipped relative distance is built once, expanded to
    a `(length, length, depth)` tensor and kept as the non-persistent buffer `positions_encoding`
    (it is not part of the state dict).

    Args:
        length: Number of positions covered by the precomputed tensor.
        depth: Size of each encoding vector. Odd values are supported.
        max_relative_position: Relative distances are clipped to
            `[-max_relative_position, max_relative_position]` before the table lookup.
    """

    def __init__(self, length, depth, max_relative_position=127):
        super().__init__()
        vocab_size = max_relative_position * 2 + 1

        # relative_index[i, j] is the row of the table for the distance j - i, clipped and
        # shifted so that it lies in [0, vocab_size - 1].
        positions = torch.arange(length)
        distance = positions.unsqueeze(0) - positions.unsqueeze(1)
        relative_index = torch.clamp(distance, -max_relative_position, max_relative_position)
        relative_index = relative_index + max_relative_position

        embeddings_table = torch.zeros(vocab_size, depth)
        position = torch.arange(0, vocab_size, dtype=torch.int64).float().unsqueeze(1)
        div_term = torch.exp(torch.arange(0, depth, 2).float() * (-math.log(10000.0) / depth))
        angles = position * div_term
        embeddings_table[:, 0::2] = torch.sin(angles)
        # For an odd depth there is one cosine column fewer than sine columns.
        embeddings_table[:, 1::2] = torch.cos(angles[:, : depth // 2])

        # A direct row lookup gives the same result as multiplying a one-hot matrix with the
        # table, without materialising the (length * length, vocab_size) one-hot tensor.
        positions_encoding = embeddings_table[relative_index]
        self.register_buffer("positions_encoding", positions_encoding, persistent=False)

    def forward(self, length):
        """Return the top-left `(length, length, depth)` slice of the precomputed encodings."""
        return self.positions_encoding[:length, :length, :]
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class NezhaEmbeddings(nn.Module):
    """Sum word and token-type embeddings, then apply layer normalization and dropout."""

    def __init__(self, config):
        super().__init__()
        hidden_size = config.hidden_size
        self.word_embeddings = nn.Embedding(config.vocab_size, hidden_size, padding_idx=config.pad_token_id)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, hidden_size)

        # `LayerNorm` keeps the TensorFlow variable name (no snake case) so that TensorFlow
        # checkpoint files can be loaded.
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # All-zero token type ids, used when the caller does not pass any; not saved in the state dict.
        default_token_types = torch.zeros((1, config.max_position_embeddings), dtype=torch.long)
        self.register_buffer("token_type_ids", default_token_types, persistent=False)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        """Embed `input_ids` (or take `inputs_embeds` as given) and add the token-type embeddings.

        When `token_type_ids` is omitted, all-zero ids are used.
        """
        # The leading dimensions come from `input_ids` when given, otherwise from `inputs_embeds`.
        input_shape = inputs_embeds.size()[:-1] if input_ids is None else input_ids.size()
        seq_length = input_shape[1]

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        if token_type_ids is None:
            # Default to the all-zero buffer registered in the constructor; the registered buffer
            # helps users trace the model without passing token_type_ids (issue #5664).
            if hasattr(self, "token_type_ids"):
                token_type_ids = self.token_type_ids[:, :seq_length].expand(input_shape[0], seq_length)
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=inputs_embeds.device)

        summed = inputs_embeds + self.token_type_embeddings(token_type_ids)
        return self.dropout(self.LayerNorm(summed))
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
class NezhaSelfAttention(nn.Module):
    """Multi-head attention with additive relative-position terms.

    The output of ``self.relative_positions_encoding`` is used twice: once
    multiplied with the queries and added to the raw attention scores, and
    once multiplied with the attention probabilities and added to the
    context vectors.

    Args:
        config: Configuration object. This class reads ``hidden_size``,
            ``num_attention_heads``, ``attention_probs_dropout_prob``,
            ``max_position_embeddings``, ``max_relative_position`` and
            ``is_decoder`` from it.

    Raises:
        ValueError: If ``hidden_size`` is not a multiple of
            ``num_attention_heads``.
    """

    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        # Divisibility was checked above, so floor division is exact here.
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.relative_positions_encoding = NezhaRelativePositionsEncoding(
            length=config.max_position_embeddings,
            depth=self.attention_head_size,
            max_relative_position=config.max_relative_position,
        )
        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Split the last dimension into heads and move the head axis to position 1.

        A 3-D input of shape ``(a, b, heads * head_size)`` becomes
        ``(a, heads, b, head_size)``.
        """
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Compute attention over ``hidden_states`` (or over ``encoder_hidden_states``).

        Args:
            hidden_states: Input from which the queries are projected.
            attention_mask: Optional tensor added to the scaled scores before
                the softmax. Replaced by ``encoder_attention_mask`` when
                ``encoder_hidden_states`` is given.
            head_mask: Optional tensor multiplied into the attention
                probabilities after dropout.
            encoder_hidden_states: When not ``None`` the keys and values are
                projected from this tensor (cross-attention).
            encoder_attention_mask: Mask used in the cross-attention case.
            past_key_value: Optional cached ``(key, value)`` pair. It is
                reused as-is for cross-attention and concatenated along
                dim 2 with the new keys/values for self-attention.
            output_attentions: When truthy the attention probabilities are
                appended to the returned tuple.

        Returns:
            ``(context_layer,)``, followed by the attention probabilities if
            ``output_attentions`` is truthy, followed by the
            ``(key_layer, value_layer)`` pair if ``self.is_decoder`` is set.
        """
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        batch_size, num_attention_heads, from_seq_length, to_seq_length = attention_scores.size()
        relations_keys = self.relative_positions_encoding(to_seq_length)
        # Put the query position first so it acts as the batch axis of the matmul below.
        query_layer_t = query_layer.permute(2, 0, 1, 3)

        query_layer_r = query_layer_t.contiguous().view(
            from_seq_length, batch_size * num_attention_heads, self.attention_head_size
        )
        key_position_scores = torch.matmul(query_layer_r, relations_keys.permute(0, 2, 1))
        # NOTE(review): the last dimension of this view is `from_seq_length`, not
        # `to_seq_length`, so this presumably only works when both lengths are equal
        # (no cache, no cross-attention) -- confirm against the relative encoding's output shape.
        key_position_scores_r = key_position_scores.view(
            from_seq_length, batch_size, num_attention_heads, from_seq_length
        )
        key_position_scores_r_t = key_position_scores_r.permute(1, 2, 0, 3)
        attention_scores = attention_scores + key_position_scores_r_t

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in NezhaModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)
        relations_values = self.relative_positions_encoding(to_seq_length)
        attention_probs_t = attention_probs.permute(2, 0, 1, 3)
        attentions_probs_r = attention_probs_t.contiguous().view(
            from_seq_length, batch_size * num_attention_heads, to_seq_length
        )
        value_position_scores = torch.matmul(attentions_probs_r, relations_values)
        value_position_scores_r = value_position_scores.view(
            from_seq_length, batch_size, num_attention_heads, self.attention_head_size
        )
        value_position_scores_r_t = value_position_scores_r.permute(1, 2, 0, 3)
        context_layer = context_layer + value_position_scores_r_t

        # Merge the heads back into a single trailing dimension of size all_head_size.
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
class NezhaSelfOutput(nn.Module):
    """Linear projection, dropout, then a residual add followed by LayerNorm.

    Args:
        config: Configuration object; ``hidden_size``, ``layer_norm_eps`` and
            ``hidden_dropout_prob`` are read from it.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        # The residual is added before normalization.
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
class NezhaAttention(nn.Module):
    """Pairs a `NezhaSelfAttention` module with its `NezhaSelfOutput` block.

    Also keeps track of which attention heads have been pruned.
    """

    def __init__(self, config):
        super().__init__()
        self.self = NezhaSelfAttention(config)
        self.output = NezhaSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given heads from the query/key/value and output projections.

        Does nothing when ``heads`` is empty. Otherwise the linear layers are
        replaced by pruned versions, the head counts stored on ``self.self``
        are updated, and the pruned heads are recorded in ``self.pruned_heads``.
        """
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Run the attention module and apply the output block to its first result.

        All arguments are forwarded positionally to ``self.self``.

        Returns:
            A tuple whose first element is ``self.output(attn_result, hidden_states)``
            and whose remaining elements are whatever else ``self.self`` returned.
        """
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        # `hidden_states` is the residual input of the output block.
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
class NezhaIntermediate(nn.Module):
    """Linear layer from ``hidden_size`` to ``intermediate_size`` followed by an activation.

    Args:
        config: Configuration object; ``hidden_size``, ``intermediate_size``
            and ``hidden_act`` are read from it. When ``hidden_act`` is a
            string it is looked up in ``ACT2FN``; otherwise it is used
            directly as the activation callable.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``intermediate_act_fn(dense(hidden_states))``."""
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
class NezhaOutput(nn.Module):
    """Projects from ``intermediate_size`` back to ``hidden_size`` with a residual LayerNorm.

    Args:
        config: Configuration object; ``intermediate_size``, ``hidden_size``,
            ``layer_norm_eps`` and ``hidden_dropout_prob`` are read from it.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        # The residual is added before normalization.
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
class NezhaLayer(nn.Module):
    """One layer: self-attention, optional cross-attention, then a chunked feed-forward block.

    Args:
        config: Configuration object; ``chunk_size_feed_forward``,
            ``is_decoder`` and ``add_cross_attention`` are read here, and the
            object is passed on to the sub-modules.

    Raises:
        ValueError: If ``add_cross_attention`` is set while ``is_decoder`` is not.
    """

    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = NezhaAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            self.crossattention = NezhaAttention(config)
        self.intermediate = NezhaIntermediate(config)
        self.output = NezhaOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Apply the layer to ``hidden_states``.

        Returns:
            A tuple starting with the layer output, followed by any attention
            tensors returned by the attention modules, and -- when
            ``self.is_decoder`` is set -- ending with the present key/value cache.

        Raises:
            ValueError: If ``encoder_hidden_states`` is passed to a decoder
                layer that has no ``crossattention`` sub-module.
        """
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]

        # if decoder, the last output is tuple of self-attn cache
        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        cross_attn_present_key_value = None
        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )

            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights

            # add cross-attn cache to positions 3,4 of present_key_value tuple
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs

        # if decoder, return the attn key/values as the last output
        if self.is_decoder:
            outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        """Run the intermediate and output blocks; ``attention_output`` is also the residual input."""
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
class NezhaEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` `NezhaLayer` modules applied in sequence."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([NezhaLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        """Run every layer in order and collect the requested intermediate results.

        Args:
            hidden_states: Input to the first layer.
            attention_mask, encoder_hidden_states, encoder_attention_mask:
                Forwarded unchanged to every layer.
            head_mask: Indexed per layer (``head_mask[i]``) when not ``None``.
            past_key_values: Indexed per layer (``past_key_values[i]``) when
                not ``None``.
            use_cache: When truthy the last element of each layer's output is
                collected. Forced to ``False`` while gradient checkpointing
                is active in training mode.
            output_attentions: When truthy ``layer_outputs[1]`` (and
                ``layer_outputs[2]`` if ``config.add_cross_attention``) is
                collected from each layer.
            output_hidden_states: When truthy the input of every layer plus
                the final output is collected.
            return_dict: When falsy a plain tuple of the non-``None`` results
                is returned instead of the output dataclass.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                # The cache is taken from the last element of the layer output.
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            # Drop the entries that were not requested.
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
class NezhaPooler(nn.Module):
    """Linear layer plus ``tanh`` applied to index 0 of the second dimension of the input."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``tanh(dense(hidden_states[:, 0]))``."""
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
class NezhaPredictionHeadTransform(nn.Module):
    """Linear layer, activation and LayerNorm, all at ``hidden_size``.

    Args:
        config: Configuration object; ``hidden_size``, ``hidden_act`` and
            ``layer_norm_eps`` are read from it. When ``hidden_act`` is a
            string it is looked up in ``ACT2FN``; otherwise it is used
            directly as the activation callable.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``LayerNorm(transform_act_fn(dense(hidden_states)))``."""
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states
|
| 650 |
+
|
| 651 |
+
|
| 652 |
+
class NezhaLMPredictionHead(nn.Module):
    """Transforms hidden states and projects them to ``config.vocab_size`` scores.

    The decoder is created without its own bias; a separate ``bias``
    parameter is created and attached to it.
    """

    def __init__(self, config):
        super().__init__()
        self.transform = NezhaPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def _tie_weights(self):
        """Point the decoder's bias at ``self.bias`` again."""
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        """Return ``decoder(transform(hidden_states))``."""
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states
|
| 673 |
+
|
| 674 |
+
|
| 675 |
+
class NezhaOnlyMLMHead(nn.Module):
    """Wraps a single `NezhaLMPredictionHead` under the attribute ``predictions``."""

    def __init__(self, config):
        super().__init__()
        self.predictions = NezhaLMPredictionHead(config)

    def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
        """Return ``self.predictions(sequence_output)``."""
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores
|
| 683 |
+
|
| 684 |
+
|
| 685 |
+
class NezhaOnlyNSPHead(nn.Module):
    """Single linear layer mapping ``hidden_size`` features to 2 scores."""

    def __init__(self, config):
        super().__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        """Return ``self.seq_relationship(pooled_output)``."""
        seq_relationship_score = self.seq_relationship(pooled_output)
        return seq_relationship_score
|
| 693 |
+
|
| 694 |
+
|
| 695 |
+
class NezhaPreTrainingHeads(nn.Module):
    """Holds a `NezhaLMPredictionHead` and a 2-way linear classifier side by side."""

    def __init__(self, config):
        super().__init__()
        self.predictions = NezhaLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        """Return ``(predictions(sequence_output), seq_relationship(pooled_output))``."""
        prediction_scores = self.predictions(sequence_output)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score
|
| 705 |
+
|
| 706 |
+
|
| 707 |
+
class NezhaPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = NezhaConfig
    load_tf_weights = load_tf_weights_in_nezha
    base_model_prefix = "nezha"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            # Zero the padding row after the random init.
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
|
| 733 |
+
|
| 734 |
+
|
| 735 |
+
@dataclass
class NezhaForPreTrainingOutput(ModelOutput):
    """
    Output type of [`NezhaForPreTraining`].

    Args:
        loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
            Total loss as the sum of the masked language modeling loss and the next sequence prediction
            (classification) loss.
        prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
            before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # Every field defaults to None.
    loss: Optional[torch.FloatTensor] = None
    prediction_logits: Optional[torch.FloatTensor] = None
    seq_relationship_logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 767 |
+
|
| 768 |
+
|
| 769 |
+
# Shared introductory docstring; passed to `add_start_docstrings` on the model class below.
NEZHA_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`NezhaConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
|
| 784 |
+
|
| 785 |
+
# Shared description of the forward() inputs. The `{0}` placeholders are filled with the
# input shape via `.format(...)` where this constant is used.
NEZHA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
|
| 828 |
+
|
| 829 |
+
|
| 830 |
+
@add_start_docstrings(
|
| 831 |
+
"The bare Nezha Model transformer outputting raw hidden-states without any specific head on top.",
|
| 832 |
+
NEZHA_START_DOCSTRING,
|
| 833 |
+
)
|
| 834 |
+
class NezhaModel(NezhaPreTrainedModel):
|
| 835 |
+
"""
|
| 836 |
+
|
| 837 |
+
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
| 838 |
+
cross-attention is added between the self-attention layers, following the architecture described in [Attention is
|
| 839 |
+
all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
|
| 840 |
+
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
| 841 |
+
|
| 842 |
+
To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
|
| 843 |
+
to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
|
| 844 |
+
`add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
|
| 845 |
+
"""
|
| 846 |
+
|
| 847 |
+
def __init__(self, config, add_pooling_layer=True):
|
| 848 |
+
super().__init__(config)
|
| 849 |
+
self.config = config
|
| 850 |
+
|
| 851 |
+
self.embeddings = NezhaEmbeddings(config)
|
| 852 |
+
self.encoder = NezhaEncoder(config)
|
| 853 |
+
|
| 854 |
+
self.pooler = NezhaPooler(config) if add_pooling_layer else None
|
| 855 |
+
|
| 856 |
+
# Initialize weights and apply final processing
|
| 857 |
+
self.post_init()
|
| 858 |
+
|
| 859 |
+
def get_input_embeddings(self):
|
| 860 |
+
return self.embeddings.word_embeddings
|
| 861 |
+
|
| 862 |
+
def set_input_embeddings(self, value):
|
| 863 |
+
self.embeddings.word_embeddings = value
|
| 864 |
+
|
| 865 |
+
def _prune_heads(self, heads_to_prune):
|
| 866 |
+
"""
|
| 867 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
| 868 |
+
class PreTrainedModel
|
| 869 |
+
"""
|
| 870 |
+
for layer, heads in heads_to_prune.items():
|
| 871 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
| 872 |
+
|
| 873 |
+
@add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
    checkpoint=_CHECKPOINT_FOR_DOC,
    output_type=BaseModelOutputWithPoolingAndCrossAttentions,
    config_class=_CONFIG_FOR_DOC,
)
def forward(
    self,
    input_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    token_type_ids: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
    r"""
    encoder_hidden_states  (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
        Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
        the model is configured as a decoder.
    encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
        the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

        - 1 for tokens that are **not masked**,
        - 0 for tokens that are **masked**.
    past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
        Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

        If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
        don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
        `decoder_input_ids` of shape `(batch_size, sequence_length)`.
    use_cache (`bool`, *optional*):
        If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
        `past_key_values`).
    """
    # Fall back to the configuration for every output flag the caller left unset.
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if return_dict is None:
        return_dict = self.config.use_return_dict

    # The cache flag is forced off unless the model is configured as a decoder.
    if not self.config.is_decoder:
        use_cache = False
    elif use_cache is None:
        use_cache = self.config.use_cache

    # Exactly one of `input_ids` / `inputs_embeds` must be supplied.
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    if input_ids is None and inputs_embeds is None:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    if input_ids is not None:
        self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
        input_shape = input_ids.size()
        device = input_ids.device
    else:
        # Drop the trailing hidden dimension to recover (batch_size, seq_length).
        input_shape = inputs_embeds.size()[:-1]
        device = inputs_embeds.device

    batch_size, seq_length = input_shape

    # past_key_values_length
    past_key_values_length = 0
    if past_key_values is not None:
        past_key_values_length = past_key_values[0][0].shape[2]

    if attention_mask is None:
        # Default mask also covers the cached positions.
        full_mask_shape = (batch_size, seq_length + past_key_values_length)
        attention_mask = torch.ones(full_mask_shape, device=device)

    if token_type_ids is None:
        if hasattr(self.embeddings, "token_type_ids"):
            # Reuse the buffer stored on the embeddings module, cut to the current length.
            buffered_ids = self.embeddings.token_type_ids[:, :seq_length]
            token_type_ids = buffered_ids.expand(batch_size, seq_length)
        else:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

    # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
    # ourselves in which case we just need to make it broadcastable to all heads.
    extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

    # If a 2D or 3D attention mask is provided for the cross-attention
    # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
    encoder_extended_attention_mask = None
    if self.config.is_decoder and encoder_hidden_states is not None:
        encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
        if encoder_attention_mask is None:
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
        encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)

    # Prepare head mask if needed
    # 1.0 in head_mask indicate we keep the head
    # attention_probs has shape bsz x n_heads x N x N
    # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
    # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
    head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

    embedding_output = self.embeddings(
        input_ids=input_ids,
        token_type_ids=token_type_ids,
        inputs_embeds=inputs_embeds,
    )
    encoder_outputs = self.encoder(
        embedding_output,
        attention_mask=extended_attention_mask,
        head_mask=head_mask,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_extended_attention_mask,
        past_key_values=past_key_values,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    sequence_output = encoder_outputs[0]
    pooled_output = None if self.pooler is None else self.pooler(sequence_output)

    if return_dict:
        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )

    # Tuple output: sequence output, pooled output, then whatever else the encoder returned.
    return (sequence_output, pooled_output) + encoder_outputs[1:]
|
| 1005 |
+
|
| 1006 |
+
|
| 1007 |
+
@add_start_docstrings(
    """
    Nezha Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
    sentence prediction (classification)` head.
    """,
    NEZHA_START_DOCSTRING,
)
class NezhaForPreTraining(NezhaPreTrainedModel):
    _tied_weights_keys = ["cls.predictions.decoder"]

    def __init__(self, config):
        """Create the Nezha backbone (with pooler) and the two pre-training heads."""
        super().__init__(config)

        self.nezha = NezhaModel(config)
        self.cls = NezhaPreTrainingHeads(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        """Return the decoder module of the prediction head."""
        prediction_head = self.cls.predictions
        return prediction_head.decoder

    def set_output_embeddings(self, new_embeddings):
        """Install `new_embeddings` as the decoder and point the head's bias at `new_embeddings.bias`."""
        prediction_head = self.cls.predictions
        prediction_head.decoder = new_embeddings
        prediction_head.bias = new_embeddings.bias

    @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=NezhaForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        next_sentence_label: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], NezhaForPreTrainingOutput]:
        r"""
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
                config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked),
                the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
            next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                Labels for computing the next sequence prediction (classification) loss. Input should be a sequence
                pair (see `input_ids` docstring) Indices should be in `[0, 1]`:

                - 0 indicates sequence B is a continuation of sequence A,
                - 1 indicates sequence B is a random sequence.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, NezhaForPreTraining
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("sijunhe/nezha-cn-base")
        >>> model = NezhaForPreTraining.from_pretrained("sijunhe/nezha-cn-base")

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> prediction_logits = outputs.prediction_logits
        >>> seq_relationship_logits = outputs.seq_relationship_logits
        ```
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.nezha(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output, pooled_output = outputs[:2]
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        # The combined loss is only computed when both label tensors are supplied.
        total_loss = None
        if labels is not None and next_sentence_label is not None:
            criterion = CrossEntropyLoss()
            flat_prediction_scores = prediction_scores.view(-1, self.config.vocab_size)
            masked_lm_loss = criterion(flat_prediction_scores, labels.view(-1))
            flat_relationship_scores = seq_relationship_score.view(-1, 2)
            next_sentence_loss = criterion(flat_relationship_scores, next_sentence_label.view(-1))
            total_loss = masked_lm_loss + next_sentence_loss

        if return_dict:
            return NezhaForPreTrainingOutput(
                loss=total_loss,
                prediction_logits=prediction_scores,
                seq_relationship_logits=seq_relationship_score,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        # Tuple output: the loss (when available) is prepended to the logits and extra outputs.
        output = (prediction_scores, seq_relationship_score) + outputs[2:]
        if total_loss is None:
            return output
        return (total_loss,) + output
|
| 1114 |
+
|
| 1115 |
+
|
| 1116 |
+
@add_start_docstrings("""Nezha Model with a `language modeling` head on top.""", NEZHA_START_DOCSTRING)
class NezhaForMaskedLM(NezhaPreTrainedModel):
    _tied_weights_keys = ["cls.predictions.decoder"]

    def __init__(self, config):
        """Create the Nezha backbone (without pooler) and the masked-LM head."""
        super().__init__(config)

        if config.is_decoder:
            logger.warning(
                "If you want to use `NezhaForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        self.nezha = NezhaModel(config, add_pooling_layer=False)
        self.cls = NezhaOnlyMLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        """Return the decoder module of the prediction head."""
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        """Install `new_embeddings` as the decoder and point the head's bias at `new_embeddings.bias`."""
        self.cls.predictions.decoder = new_embeddings
        self.cls.predictions.bias = new_embeddings.bias

    @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.nezha(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            # outputs[1] is skipped here; only the entries from index 2 onward are forwarded.
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
        """
        Append one PAD "dummy" token to every sequence and extend the attention mask with a 0 for it.

        Args:
            input_ids: Token ids of shape `(batch_size, sequence_length)`.
            attention_mask: Optional mask matching `input_ids`. When omitted, an all-ones mask is built so
                that every supplied token is attended to.
            **model_kwargs: Accepted for interface compatibility; not used here.

        Returns:
            A dict with the extended `input_ids` and `attention_mask`.

        Raises:
            ValueError: If `config.pad_token_id` is not defined.
        """
        input_shape = input_ids.shape
        effective_batch_size = input_shape[0]

        #  add a dummy token
        if self.config.pad_token_id is None:
            raise ValueError("The PAD token should be defined for generation")

        if attention_mask is None:
            # The signature allows omitting the mask; previously this crashed on `None.new_zeros`.
            attention_mask = input_ids.new_ones(input_shape)

        attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
        dummy_token = torch.full(
            (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
        )
        input_ids = torch.cat([input_ids, dummy_token], dim=1)

        return {"input_ids": input_ids, "attention_mask": attention_mask}
|
| 1218 |
+
|
| 1219 |
+
|
| 1220 |
+
@add_start_docstrings(
    """Nezha Model with a `next sentence prediction (classification)` head on top.""",
    NEZHA_START_DOCSTRING,
)
class NezhaForNextSentencePrediction(NezhaPreTrainedModel):
    def __init__(self, config):
        """Create the Nezha backbone (with pooler) and the next-sentence-prediction head."""
        super().__init__(config)

        self.nezha = NezhaModel(config)
        self.cls = NezhaOnlyNSPHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
            (see `input_ids` docstring). Indices should be in `[0, 1]`:

            - 0 indicates sequence B is a continuation of sequence A,
            - 1 indicates sequence B is a random sequence.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, NezhaForNextSentencePrediction
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("sijunhe/nezha-cn-base")
        >>> model = NezhaForNextSentencePrediction.from_pretrained("sijunhe/nezha-cn-base")

        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
        >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
        >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")

        >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
        >>> logits = outputs.logits
        >>> assert logits[0, 0] < logits[0, 1]  # next sentence was random
        ```
        """

        # Legacy keyword: when present it overrides `labels`, after emitting a FutureWarning.
        if "next_sentence_label" in kwargs:
            warnings.warn(
                "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
                " `labels` instead.",
                FutureWarning,
            )
            labels = kwargs.pop("next_sentence_label")

        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.nezha(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]
        seq_relationship_scores = self.cls(pooled_output)

        next_sentence_loss = None
        if labels is not None:
            criterion = CrossEntropyLoss()
            next_sentence_loss = criterion(seq_relationship_scores.view(-1, 2), labels.view(-1))

        if return_dict:
            return NextSentencePredictorOutput(
                loss=next_sentence_loss,
                logits=seq_relationship_scores,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        # Tuple output: the loss (when available) is prepended to the logits and extra outputs.
        output = (seq_relationship_scores,) + outputs[2:]
        if next_sentence_loss is None:
            return output
        return (next_sentence_loss,) + output
|
| 1318 |
+
|
| 1319 |
+
|
| 1320 |
+
@add_start_docstrings(
    """
    Nezha Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    NEZHA_START_DOCSTRING,
)
class NezhaForSequenceClassification(NezhaPreTrainedModel):
    def __init__(self, config):
        """Create the Nezha backbone, a dropout layer and a linear classifier with `config.num_labels` outputs."""
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.nezha = NezhaModel(config)

        # `classifier_dropout` takes precedence; `hidden_dropout_prob` is the fallback when it is None.
        if config.classifier_dropout is not None:
            classifier_dropout = config.classifier_dropout
        else:
            classifier_dropout = config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.nezha(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = self.dropout(outputs[1])
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            # On first use the problem type is inferred from the label count and label dtype,
            # and stored back on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            problem_type = self.config.problem_type
            if problem_type == "regression":
                criterion = MSELoss()
                if self.num_labels == 1:
                    loss = criterion(logits.squeeze(), labels.squeeze())
                else:
                    loss = criterion(logits, labels)
            elif problem_type == "single_label_classification":
                criterion = CrossEntropyLoss()
                loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                criterion = BCEWithLogitsLoss()
                loss = criterion(logits, labels)

        if return_dict:
            return SequenceClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        # Tuple output: the loss (when available) is prepended to the logits and extra outputs.
        output = (logits,) + outputs[2:]
        if loss is None:
            return output
        return (loss,) + output
|
| 1417 |
+
|
| 1418 |
+
|
| 1419 |
+
@add_start_docstrings(
    """
    Nezha Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    NEZHA_START_DOCSTRING,
)
class NezhaForMultipleChoice(NezhaPreTrainedModel):
    def __init__(self, config):
        """Create the Nezha backbone, a dropout layer and a single-output linear scorer."""
        super().__init__(config)

        self.nezha = NezhaModel(config)
        # `classifier_dropout` takes precedence; `hidden_dropout_prob` is the fallback when it is None.
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        # Collapse every leading dimension so each choice is fed to the backbone as its own row.
        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        outputs = self.nezha(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        # One score per flattened row; regroup them so each row holds the scores of one example's choices.
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
| 1509 |
+
|
| 1510 |
+
|
| 1511 |
+
@add_start_docstrings(
    """
    Nezha Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    NEZHA_START_DOCSTRING,
)
class NezhaForTokenClassification(NezhaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # The head classifies every position, so the encoder is built without its pooling layer.
        self.nezha = NezhaModel(config, add_pooling_layer=False)

        # Prefer the head-specific dropout rate; fall back to the encoder-wide rate when it is unset.
        dropout_rate = config.classifier_dropout
        if dropout_rate is None:
            dropout_rate = config.hidden_dropout_prob
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_outputs = self.nezha(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # The first encoder output is regularised with dropout and projected to label scores.
        token_states = self.dropout(encoder_outputs[0])
        logits = self.classifier(token_states)

        loss = None
        if labels is not None:
            # Flatten logits and labels so every position counts as one classification example.
            flat_logits = logits.view(-1, self.num_labels)
            loss = CrossEntropyLoss()(flat_logits, labels.view(-1))

        if return_dict:
            return TokenClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        # Tuple form: the loss (when computed) first, then logits, then encoder outputs from index 2 on.
        output = (logits,) + encoder_outputs[2:]
        return output if loss is None else (loss,) + output
|
| 1588 |
+
|
| 1589 |
+
|
| 1590 |
+
@add_start_docstrings(
    """
    Nezha Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    NEZHA_START_DOCSTRING,
)
class NezhaForQuestionAnswering(NezhaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Span prediction works on per-position hidden states, so the encoder is built without its pooling layer.
        self.nezha = NezhaModel(config, add_pooling_layer=False)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_outputs = self.nezha(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Project the first encoder output, then split the last axis into the start and end score columns.
        span_logits = self.qa_outputs(encoder_outputs[0])
        start_logits, end_logits = span_logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # On multi-GPU runs the targets may carry an extra trailing dimension; remove it.
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)

            # Targets beyond the model input are clamped onto the index `size(1)` of the logits, which is
            # exactly the index the loss is told to ignore, so those terms do not contribute.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            span_loss = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = span_loss(start_logits, start_positions)
            end_loss = span_loss(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if return_dict:
            return QuestionAnsweringModelOutput(
                loss=total_loss,
                start_logits=start_logits,
                end_logits=end_logits,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        # Tuple form: the loss (when computed) first, then both logits, then encoder outputs from index 2 on.
        output = (start_logits, end_logits) + encoder_outputs[2:]
        return output if total_loss is None else (total_loss,) + output
|
| 1685 |
+
|
| 1686 |
+
|
| 1687 |
+
__all__ = [
|
| 1688 |
+
"NezhaForNextSentencePrediction",
|
| 1689 |
+
"NezhaForMaskedLM",
|
| 1690 |
+
"NezhaForPreTraining",
|
| 1691 |
+
"NezhaForMultipleChoice",
|
| 1692 |
+
"NezhaForQuestionAnswering",
|
| 1693 |
+
"NezhaForSequenceClassification",
|
| 1694 |
+
"NezhaForTokenClassification",
|
| 1695 |
+
"NezhaModel",
|
| 1696 |
+
"NezhaPreTrainedModel",
|
| 1697 |
+
]
|
docs/transformers/build/lib/transformers/models/deprecated/open_llama/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 EleutherAI and The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ....utils import _LazyModule
|
| 17 |
+
from ....utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_open_llama import *
|
| 22 |
+
from .modeling_open_llama import *
|
| 23 |
+
else:
|
| 24 |
+
import sys
|
| 25 |
+
|
| 26 |
+
_file = globals()["__file__"]
|
| 27 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
docs/transformers/build/lib/transformers/models/deprecated/open_llama/configuration_open_llama.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 5 |
+
# and OPT implementations in this library. It has been modified from its
|
| 6 |
+
# original forms to accommodate minor architectural differences compared
|
| 7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 8 |
+
#
|
| 9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 10 |
+
# you may not use this file except in compliance with the License.
|
| 11 |
+
# You may obtain a copy of the License at
|
| 12 |
+
#
|
| 13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 14 |
+
#
|
| 15 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 18 |
+
# See the License for the specific language governing permissions and
|
| 19 |
+
# limitations under the License.
|
| 20 |
+
"""Open-Llama model configuration"""
|
| 21 |
+
|
| 22 |
+
from ....configuration_utils import PretrainedConfig
|
| 23 |
+
from ....utils import logging
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
logger = logging.get_logger(__name__)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class OpenLlamaConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`OpenLlamaModel`]. It is used to instantiate an
    Open-Llama model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the
    [s-JoL/Open-Llama-V1](https://huggingface.co/s-JoL/Open-Llama-V1).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 100000):
            Vocabulary size of the Open-Llama model. Defines the number of different tokens that can be represented by
            the `input_ids` passed when calling [`OpenLlamaModel`]
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 11008):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer encoder.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 0):
            Padding token id, forwarded to [`PretrainedConfig`].
        bos_token_id (`int`, *optional*, defaults to 1):
            Beginning-of-sequence token id, forwarded to [`PretrainedConfig`].
        eos_token_id (`int`, *optional*, defaults to 2):
            End-of-sequence token id, forwarded to [`PretrainedConfig`].
        tie_word_embeddings(`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings
        use_memory_efficient_attention (`bool`, *optional*, defaults to `True`):
            Stored as `use_memory_efficient_attention`. For backward compatibility the misspelled keyword
            `use_memorry_efficient_attention` is still accepted through `kwargs` and takes precedence when given.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            Dropout probability stored as `hidden_dropout_prob`; where it is applied is defined by the modeling code.
        attention_dropout_prob (`float`, *optional*, defaults to 0.1):
            Dropout probability stored as `attention_dropout_prob`; where it is applied is defined by the modeling
            code.
        use_stable_embedding (`bool`, *optional*, defaults to `True`):
            Flag stored as `use_stable_embedding`; its effect is defined by the modeling code.
        shared_input_output_embedding (`bool`, *optional*, defaults to `True`):
            Flag stored as `shared_input_output_embedding`; its effect is defined by the modeling code.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
            `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
            `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
            these scaling strategies behave:
            https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
            experimental feature, subject to breaking API changes in future versions.

    Example:

    ```python
    >>> from transformers import OpenLlamaModel, OpenLlamaConfig

    >>> # Initializing an Open-Llama open_llama-7b style configuration
    >>> configuration = OpenLlamaConfig()

    >>> # Initializing a model from the open_llama-7b style configuration
    >>> model = OpenLlamaModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "open-llama"

    def __init__(
        self,
        vocab_size=100000,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        use_memory_efficient_attention=True,
        hidden_dropout_prob=0.1,
        attention_dropout_prob=0.1,
        use_stable_embedding=True,
        shared_input_output_embedding=True,
        rope_theta=10000.0,
        rope_scaling=None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        # The doubled "r" is deliberate: the misspelled legacy key is popped from `kwargs` (so it is not
        # forwarded to the base class) and, when present, overrides the correctly spelled argument.
        self.use_memory_efficient_attention = kwargs.pop(
            "use_memorry_efficient_attention", use_memory_efficient_attention
        )
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_dropout_prob = attention_dropout_prob
        self.use_stable_embedding = use_stable_embedding
        self.shared_input_output_embedding = shared_input_output_embedding
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        # Reject a malformed `rope_scaling` before handing the remaining arguments to the base class.
        self._rope_scaling_validation()

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    def _rope_scaling_validation(self):
        """
        Validate the `rope_scaling` configuration.

        `None` is accepted as "no scaling". Otherwise the value must be a two-entry dict whose `type` is
        `"linear"` or `"dynamic"` and whose `factor` is a `float` strictly greater than 1.

        Raises:
            ValueError: If `rope_scaling` is not a two-entry dict, names an unknown type, or carries a factor
                that is missing, not a `float` (integers are rejected), or not greater than 1.
        """
        if self.rope_scaling is None:
            return

        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
            raise ValueError(
                f"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, got {self.rope_scaling}"
            )
        rope_scaling_type = self.rope_scaling.get("type", None)
        rope_scaling_factor = self.rope_scaling.get("factor", None)
        if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
            raise ValueError(
                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
            )
        if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
            raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
__all__ = ["OpenLlamaConfig"]
|
docs/transformers/build/lib/transformers/models/deprecated/open_llama/modeling_open_llama.py
ADDED
|
@@ -0,0 +1,975 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 5 |
+
# and OPT implementations in this library. It has been modified from its
|
| 6 |
+
# original forms to accommodate minor architectural differences compared
|
| 7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 8 |
+
#
|
| 9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 10 |
+
# you may not use this file except in compliance with the License.
|
| 11 |
+
# You may obtain a copy of the License at
|
| 12 |
+
#
|
| 13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 14 |
+
#
|
| 15 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 18 |
+
# See the License for the specific language governing permissions and
|
| 19 |
+
# limitations under the License.
|
| 20 |
+
"""PyTorch Open-Llama model."""
|
| 21 |
+
|
| 22 |
+
import math
|
| 23 |
+
from typing import List, Optional, Tuple, Union
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
import torch.utils.checkpoint
|
| 27 |
+
from torch import nn
|
| 28 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 29 |
+
|
| 30 |
+
from ....activations import ACT2FN
|
| 31 |
+
from ....modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
| 32 |
+
from ....modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
| 33 |
+
from ....modeling_utils import PreTrainedModel
|
| 34 |
+
from ....utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
| 35 |
+
from .configuration_open_llama import OpenLlamaConfig
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
logger = logging.get_logger(__name__)
|
| 39 |
+
|
| 40 |
+
try:
|
| 41 |
+
from xformers import ops as xops
|
| 42 |
+
except ImportError:
|
| 43 |
+
xops = None
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
_CONFIG_FOR_DOC = "OpenLlamaConfig"
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class OpenLlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        OpenLlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.variance_epsilon = eps
        # Learned per-feature scale, initialised to ones.
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states):
        """Scale `hidden_states` by the reciprocal root-mean-square of its last axis, then by `weight`."""
        original_dtype = hidden_states.dtype
        # The statistic is computed in float32 and the result is cast back to the caller's dtype.
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        """Return the weight shape and epsilon for display in the module's `repr`."""
        weight_shape = tuple(self.weight.shape)
        return f"{weight_shape}, eps={self.variance_epsilon}"
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class OpenLlamaRotaryEmbedding(nn.Module):
    """Cached cosine/sine tables for rotary position embeddings.

    The tables are built once for `max_position_embeddings` positions and rebuilt whenever a longer sequence is
    requested.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # One inverse frequency per pair of channels: base ** -(2i / dim) for i in range(dim // 2).
        inv_freq = 1.0 / (
            self.base
            ** (torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the `cos_cached` / `sin_cached` buffers for positions `0 .. seq_len - 1`."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return the `(cos, sin)` tables for the first `seq_len` positions, cast to `x.dtype`.

        Args:
            x: Tensor laid out as `[bs, num_attention_heads, seq_len, head_size]`; only its dtype, device and
                (when `seq_len` is omitted) second-to-last dimension are used.
            seq_len: Number of positions to return. Defaults to `x.shape[-2]`.
        """
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is None:
            # The signature advertises `None` as the default, but comparing `None` with an int raises
            # TypeError; fall back to the sequence axis of `x` instead.
            seq_len = x.shape[-2]

        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class OpenLlamaLinearScalingRotaryEmbedding(OpenLlamaRotaryEmbedding):
    """OpenLlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Set the factor first: the parent constructor calls `_set_cos_sin_cache`, which reads it.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Rebuild the cos/sin buffers with every position index divided by `scaling_factor`."""
        self.max_seq_len_cached = seq_len
        positions = torch.arange(seq_len, device=device, dtype=torch.int64).type_as(self.inv_freq)
        positions = positions / self.scaling_factor

        angles = torch.outer(positions, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        table = torch.cat((angles, angles), dim=-1)
        for buffer_name, values in (("cos_cached", table.cos()), ("sin_cached", table.sin())):
            self.register_buffer(buffer_name, values.to(dtype), persistent=False)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class OpenLlamaDynamicNTKScalingRotaryEmbedding(OpenLlamaRotaryEmbedding):
    """OpenLlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Set the factor first: the parent constructor calls `_set_cos_sin_cache`, which reads it.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Rebuild the cos/sin buffers, enlarging the base once `seq_len` exceeds `max_position_embeddings`."""
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            # Recompute the inverse frequencies from a base that grows with the requested length.
            growth = (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            scaled_base = self.base * growth ** (self.dim / (self.dim - 2))
            exponents = (
                torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / self.dim
            )
            self.register_buffer("inv_freq", 1.0 / (scaled_base**exponents), persistent=False)

        positions = torch.arange(seq_len, device=device, dtype=torch.int64).type_as(self.inv_freq)

        angles = torch.outer(positions, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        table = torch.cat((angles, angles), dim=-1)
        for buffer_name, values in (("cos_cached", table.cos()), ("sin_cached", table.sin())):
            self.register_buffer(buffer_name, values.to(dtype), persistent=False)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the half that moves to the front.

    The split point is `x.shape[-1] // 2`, so for an odd-sized last dimension the
    second (negated) part is one element longer than the first.
    """
    split_at = x.shape[-1] // 2
    front = x[..., :split_at]
    back = x[..., split_at:]
    return torch.cat((-back, front), dim=-1)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors; they are used to index
            `cos` and `sin`. For example, this can be used to pass offsetted position ids when working with a
            KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos[position_ids]` and `sin[position_ids]` are unsqueezed so that they broadcast
            against `q` and `k`. For example, if `cos[position_ids]` has the shape [batch_size, seq_len, head_dim]
            and `q`/`k` have the shape [batch_size, heads, seq_len, head_dim], use `unsqueeze_dim=1`; if `q`/`k`
            have the shape [batch_size, seq_len, heads, head_dim], use `unsqueeze_dim=2`.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos_at_pos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin_at_pos = sin[position_ids].unsqueeze(unsqueeze_dim)

    def _rotate(states):
        # Same formula for queries and keys.
        return (states * cos_at_pos) + (rotate_half(states) * sin_at_pos)

    return _rotate(q), _rotate(k)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class OpenLlamaMLP(nn.Module):
    """Gated feed-forward block followed by dropout.

    Computes `dropout(down_proj(act_fn(gate_proj(x)) * up_proj(x)))`; none of the
    three linear layers has a bias.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        dropout_prob: float,
    ):
        """
        Args:
            hidden_size: input and output width of the block.
            intermediate_size: width of the gated inner projection.
            hidden_act: key looked up in `ACT2FN` to obtain the activation.
            dropout_prob: probability handed to `nn.Dropout` for the block output.
        """
        super().__init__()
        # Creation order is kept as-is so parameter registration order is unchanged.
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        """Apply the gated projection to `x` and return the result after dropout."""
        gate = self.act_fn(self.gate_proj(x))
        projected = self.down_proj(gate * self.up_proj(x))
        return self.dropout(projected)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class OpenLlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Queries and keys are rotated with a rotary position embedding before the
    attention scores are computed. While training, if the config enables
    `use_memory_efficient_attention` and the module-level `xops` is available
    (presumably `xformers.ops` -- confirm against the file's imports), the fused
    kernel is used instead of the explicit softmax path.
    """

    def __init__(self, config: OpenLlamaConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.dropout_prob = config.attention_dropout_prob
        self.rope_theta = config.rope_theta

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self._init_rope()

    def _init_rope(self):
        """Create `self.rotary_emb` according to `config.rope_scaling`.

        Raises:
            ValueError: if `rope_scaling["type"]` is neither `"linear"` nor `"dynamic"`.
        """
        if self.config.rope_scaling is None:
            self.rotary_emb = OpenLlamaRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = OpenLlamaLinearScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = OpenLlamaDynamicNTKScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """View `tensor` as `(bsz, seq_len, num_heads, head_dim)` and move the head axis to position 1."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run self-attention over `hidden_states`.

        Args:
            hidden_states: tensor of shape `(bsz, q_len, hidden_size)`.
            attention_mask: optional additive mask; on the explicit path it must have
                shape `(bsz, 1, q_len, kv_seq_len)`. It is not used by the fused kernel path.
            position_ids: indices used to select the rotary cos/sin entries.
            past_key_value: optional `(key, value)` pair from earlier steps, concatenated
                in front of the new keys/values along the sequence axis.
            output_attentions: when false, the returned attention weights are `None`.
            use_cache: when true, the concatenated `(key, value)` pair is returned.

        Returns:
            `(attn_output, attn_weights, past_key_value)`; the last two may be `None`.

        Raises:
            ValueError: if the attention weights, the mask or the attention output do not
                have the expected shape on the explicit path.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        # [bsz, nh, t, hd]

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        if self.config.use_memory_efficient_attention and xops is not None and self.training:
            # The fused kernel does not expose attention probabilities.
            attn_weights = None
            # The kernel is fed tensors with the sequence axis before the head axis.
            query_states = query_states.transpose(1, 2)
            key_states = key_states.transpose(1, 2)
            value_states = value_states.transpose(1, 2)
            attn_output = xops.memory_efficient_attention(
                query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask(), p=self.dropout_prob
            )
        else:
            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

            if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
                # Fixed: the message now reports the same 4-D shape the check compares against.
                raise ValueError(
                    f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                    f" {attn_weights.size()}"
                )

            if attention_mask is not None:
                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                    raise ValueError(
                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                    )
                attn_weights = attn_weights + attention_mask
                # Clamp from below at the dtype minimum so adding the mask cannot leave -inf scores.
                attn_weights = torch.max(
                    attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
                )

            # upcast attention to fp32
            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
            attn_output = torch.matmul(attn_weights, value_states)

            if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
                raise ValueError(
                    f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                    f" {attn_output.size()}"
                )

            attn_output = attn_output.transpose(1, 2)

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
class OpenLlamaDecoderLayer(nn.Module):
    """One decoder block: pre-norm self-attention and pre-norm MLP, each with a residual connection."""

    def __init__(self, config: OpenLlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        # Sub-module creation order is kept as-is so parameter registration order is unchanged.
        self.self_attn = OpenLlamaAttention(config=config)
        self.mlp = OpenLlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            dropout_prob=config.hidden_dropout_prob,
        )
        self.input_layernorm = OpenLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = OpenLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            position_ids (`torch.LongTensor`, *optional*): position indices forwarded to the attention module.
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).

        Returns:
            A tuple starting with the new hidden states, followed by the attention weights when
            `output_attentions` is truthy and by the present key/value pair when `use_cache` is truthy.
        """
        # Self Attention
        attn_out, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        after_attention = hidden_states + attn_out

        # Fully Connected
        mlp_out = self.mlp(self.post_attention_layernorm(after_attention))
        layer_output = after_attention + mlp_out

        outputs = (layer_output,)
        if output_attentions:
            outputs = outputs + (self_attn_weights,)
        if use_cache:
            outputs = outputs + (present_key_value,)
        return outputs
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
OPEN_LLAMA_START_DOCSTRING = r"""
|
| 412 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 413 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 414 |
+
etc.)
|
| 415 |
+
|
| 416 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 417 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 418 |
+
and behavior.
|
| 419 |
+
|
| 420 |
+
Parameters:
|
| 421 |
+
config ([`OpenLlamaConfig`]):
|
| 422 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
| 423 |
+
load the weights associated with the model, only the configuration. Check out the
|
| 424 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 425 |
+
"""
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
# NOTE: this class is documented with comments rather than a class docstring because
# `add_start_docstrings` builds the class `__doc__` and a new docstring would change it.
@add_start_docstrings(
    "The bare Open-Llama Model outputting raw hidden-states without any specific head on top.",
    OPEN_LLAMA_START_DOCSTRING,
)
class OpenLlamaPreTrainedModel(PreTrainedModel):
    config_class = OpenLlamaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["OpenLlamaDecoderLayer"]

    def _init_weights(self, module):
        """Initialise the weights of a single `nn.Linear` or `nn.Embedding`; other modules are left alone.

        Linear weights are drawn from a normal distribution with standard deviation
        `config.initializer_range` and biases are zeroed. Embedding weights use Xavier-normal
        initialisation when `config.use_stable_embedding` is set and the same normal
        distribution otherwise; the padding row, if any, is zeroed.
        """
        std = self.config.initializer_range

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if not isinstance(module, nn.Embedding):
            return

        if self.config.use_stable_embedding:
            torch.nn.init.xavier_normal_(module.weight.data)
        else:
            module.weight.data.normal_(mean=0.0, std=std)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
OPEN_LLAMA_INPUTS_DOCSTRING = r"""
|
| 454 |
+
Args:
|
| 455 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 456 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
| 457 |
+
it.
|
| 458 |
+
|
| 459 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 460 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 461 |
+
|
| 462 |
+
[What are input IDs?](../glossary#input-ids)
|
| 463 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 464 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 465 |
+
|
| 466 |
+
- 1 for tokens that are **not masked**,
|
| 467 |
+
- 0 for tokens that are **masked**.
|
| 468 |
+
|
| 469 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 470 |
+
|
| 471 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 472 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 473 |
+
|
| 474 |
+
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
| 475 |
+
`past_key_values`).
|
| 476 |
+
|
| 477 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
| 478 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
| 479 |
+
information on the default strategy.
|
| 480 |
+
|
| 481 |
+
- 1 indicates the head is **not masked**,
|
| 482 |
+
- 0 indicates the head is **masked**.
|
| 483 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 484 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 485 |
+
config.n_positions - 1]`.
|
| 486 |
+
|
| 487 |
+
[What are position IDs?](../glossary#position-ids)
|
| 488 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
| 489 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
| 490 |
+
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
| 491 |
+
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
| 492 |
+
|
| 493 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
| 494 |
+
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
| 495 |
+
|
| 496 |
+
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
| 497 |
+
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
| 498 |
+
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
| 499 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 500 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 501 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
| 502 |
+
model's internal embedding lookup matrix.
|
| 503 |
+
use_cache (`bool`, *optional*):
|
| 504 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
| 505 |
+
`past_key_values`).
|
| 506 |
+
output_attentions (`bool`, *optional*):
|
| 507 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 508 |
+
tensors for more detail.
|
| 509 |
+
output_hidden_states (`bool`, *optional*):
|
| 510 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 511 |
+
more detail.
|
| 512 |
+
return_dict (`bool`, *optional*):
|
| 513 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 514 |
+
"""
|
| 515 |
+
|
| 516 |
+
|
| 517 |
+
@add_start_docstrings(
    "The bare Open-Llama Model outputting raw hidden-states without any specific head on top.",
    OPEN_LLAMA_START_DOCSTRING,
)
class OpenLlamaModel(OpenLlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OpenLlamaDecoderLayer`]

    Args:
        config: OpenLlamaConfig
    """

    def __init__(self, config: OpenLlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        # Optional LayerNorm applied to the looked-up token embeddings.
        if config.use_stable_embedding:
            self.embed_layer_norm = nn.LayerNorm(config.hidden_size)
        else:
            self.embed_layer_norm = None
        self.layers = nn.ModuleList([OpenLlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = OpenLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Returns the token-embedding module.
        return self.embed_tokens

    def set_input_embeddings(self, value):
        # Replaces the token-embedding module.
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(OPEN_LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # Per-call flags fall back to the config when left as None.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        # Fixed: the messages now name this method's real arguments.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
            # Fixed: explicit None check instead of relying on module truthiness.
            if self.embed_layer_norm is not None:
                inputs_embeds = self.embed_layer_norm(inputs_embeds)
        # embed positions
        # Fixed: only drop the caller's mask when the attention layers really take the fused
        # kernel path (same condition as in `OpenLlamaAttention.forward`). Previously the
        # padding mask was also discarded when `xops` was unavailable and the explicit
        # softmax path, which relies on the mask, was used.
        if self.config.use_memory_efficient_attention and xops is not None and self.training:
            attention_mask = None
        elif attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )

        input_shape = (batch_size, seq_length)
        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
        )

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                # Positional arguments follow the decoder layer's `forward` signature;
                # no cache is passed in or requested.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                    output_attentions,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # The cache entry comes after the attention weights when those are returned.
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
|
| 677 |
+
|
| 678 |
+
|
| 679 |
+
class OpenLlamaForCausalLM(OpenLlamaPreTrainedModel):
|
| 680 |
+
def __init__(self, config):
    """Build the wrapped `OpenLlamaModel` and, unless embeddings are shared, a bias-free `lm_head`."""
    super().__init__(config)
    self.model = OpenLlamaModel(config)
    # With shared input/output embeddings no separate projection is created.
    self.lm_head = (
        None
        if config.shared_input_output_embedding
        else nn.Linear(config.hidden_size, config.vocab_size, bias=False)
    )

    # Initialize weights and apply final processing
    self.post_init()
|
| 690 |
+
|
| 691 |
+
def get_input_embeddings(self):
    """Return the `embed_tokens` attribute of the wrapped decoder."""
    decoder = self.model
    return decoder.embed_tokens
|
| 693 |
+
|
| 694 |
+
def set_input_embeddings(self, value):
    """Store `value` as the `embed_tokens` attribute of the wrapped decoder."""
    decoder = self.model
    decoder.embed_tokens = value
|
| 696 |
+
|
| 697 |
+
def get_output_embeddings(self):
    """Return `self.lm_head` (which is `None` when the constructor created no separate head)."""
    head = self.lm_head
    return head
|
| 699 |
+
|
| 700 |
+
def set_output_embeddings(self, new_embeddings):
    """Store `new_embeddings` as `self.lm_head`."""
    setattr(self, "lm_head", new_embeddings)
|
| 702 |
+
|
| 703 |
+
def set_decoder(self, decoder):
    """Store `decoder` as `self.model`."""
    setattr(self, "model", decoder)
|
| 705 |
+
|
| 706 |
+
def get_decoder(self):
    """Return `self.model`."""
    decoder = self.model
    return decoder
|
| 708 |
+
|
| 709 |
+
@add_start_docstrings_to_model_forward(OPEN_LLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, OpenLlamaForCausalLM

    >>> model = OpenLlamaForCausalLM.from_pretrained("openlm-research/open_llama_7b")
    >>> tokenizer = AutoTokenizer.from_pretrained("openlm-research/open_llama_7b")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""
    # NOTE: the docstring above is kept as-is; the decorators consume it (the empty
    # `Returns:` line is presumably the placeholder `replace_return_docstrings` fills in).

    # Per-call flags fall back to the config when left as None.
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if return_dict is None:
        return_dict = self.config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    hidden_states = outputs[0]
    if self.config.shared_input_output_embedding:
        # Reuse the input embedding matrix as the output projection.
        embedding_weight = self.model.embed_tokens.weight
        logits = torch.einsum("blh,vh->blv", hidden_states.to(embedding_weight.device), embedding_weight)
    else:
        logits = self.lm_head(hidden_states)

    loss = None
    if labels is not None:
        # move labels to correct device to enable model parallelism
        labels = labels.to(logits.device)
        # Shift so that tokens < n predict n, then flatten the tokens
        flat_logits = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
        flat_labels = labels[..., 1:].contiguous().view(-1)
        # Enable model parallelism
        flat_labels = flat_labels.to(flat_logits.device)
        loss = CrossEntropyLoss()(flat_logits, flat_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        if loss is not None:
            return (loss,) + output
        return output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
|
| 802 |
+
|
| 803 |
+
def prepare_inputs_for_generation(
    self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
    """Assemble the keyword arguments for one forward pass during generation.

    Args:
        input_ids: Token ids for the sequence generated so far. When a cache is
            supplied, only the part not already covered by the cache is kept.
        past_key_values: Optional cache from earlier steps. Its length is read
            from `past_key_values[0][0].shape[2]`.
        attention_mask: Optional mask; when given and no `position_ids` are
            passed, position ids are derived from it.
        inputs_embeds: Optional embeddings, used instead of `input_ids` only
            while no cache exists yet (i.e. on the first step).
        **kwargs: `position_ids` and `use_cache` are read from here if present.

    Returns:
        A dict with either `inputs_embeds` or `input_ids`, plus `position_ids`,
        `past_key_values`, `use_cache` and `attention_mask`.
    """
    has_cache = past_key_values is not None

    if has_cache:
        # Length already stored in the cache (dim 2 of the first cached tensor).
        past_length = past_key_values[0][0].shape[2]
        current_length = input_ids.shape[1]

        if current_length > past_length:
            # The caller passed the full sequence: drop what the cache covers.
            remove_prefix_length = past_length
        else:
            # Some generation methods already pass only the last input ID;
            # default to the old behavior and keep only the final ID.
            remove_prefix_length = current_length - 1

        input_ids = input_ids[:, remove_prefix_length:]

    position_ids = kwargs.get("position_ids", None)
    if attention_mask is not None and position_ids is None:
        # Create position_ids on the fly for batch generation: a running count
        # of attended tokens, with masked-out positions overwritten by 1.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if has_cache:
            # Keep only the positions of the tokens that are actually fed in.
            position_ids = position_ids[:, -input_ids.shape[1] :]

    # If `inputs_embeds` are passed, we only want to use them in the 1st generation step.
    if inputs_embeds is not None and not has_cache:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    model_inputs.update(
        {
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        }
    )
    return model_inputs
|
| 841 |
+
|
| 842 |
+
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
    """Reorder every cached tensor along dim 0 according to `beam_idx`.

    Args:
        past_key_values: Iterable of per-layer iterables of tensors.
        beam_idx: Index tensor passed to `index_select` on dim 0; it is moved
            to each cached tensor's device before use.

    Returns:
        A tuple with one inner tuple per layer, holding the re-indexed tensors.
    """
    # Build the result in one pass instead of growing a tuple by concatenation.
    return tuple(
        tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
        for layer_past in past_key_values
    )
|
| 850 |
+
|
| 851 |
+
|
| 852 |
+
@add_start_docstrings(
    """
    The LLaMa Model transformer with a sequence classification head on top (linear layer).

    [`OpenLlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal
    models (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    OPEN_LLAMA_START_DOCSTRING,
)
class OpenLlamaForSequenceClassification(OpenLlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = OpenLlamaModel(config)
        # Bias-free linear head mapping each hidden state to `num_labels` scores.
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module of the wrapped model."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the wrapped model."""
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(OPEN_LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        # Score every position; one position per row is selected below.
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        pad_token_id = self.config.pad_token_id
        if pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")

        if pad_token_id is None or input_ids is None:
            # Without a pad token, or with `inputs_embeds` only, padding cannot
            # be located: fall back to the last position of every row.
            sequence_lengths = -1
        else:
            # Index just before the first pad token. When a row has no pad token
            # the argmax is 0, so the subtraction gives -1; use modulo instead of
            # reverse indexing for ONNX compatibility.
            sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
            sequence_lengths = sequence_lengths % input_ids.shape[-1]
            sequence_lengths = sequence_lengths.to(logits.device)

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                # Infer the problem type once and store it on the config.
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)

        if not return_dict:
            # Tuple output: (loss?, pooled_logits, *remaining model outputs).
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
|
| 973 |
+
|
| 974 |
+
|
| 975 |
+
# Names exported by a star-import of this module.
__all__ = ["OpenLlamaPreTrainedModel", "OpenLlamaModel", "OpenLlamaForCausalLM", "OpenLlamaForSequenceClassification"]
|
docs/transformers/build/lib/transformers/models/deprecated/qdqbert/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 NVIDIA Corporation and The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ....utils import _LazyModule
|
| 17 |
+
from ....utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_qdqbert import *
|
| 22 |
+
from .modeling_qdqbert import *
|
| 23 |
+
else:
|
| 24 |
+
import sys
|
| 25 |
+
|
| 26 |
+
_file = globals()["__file__"]
|
| 27 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
docs/transformers/build/lib/transformers/models/deprecated/qdqbert/configuration_qdqbert.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2021 NVIDIA Corporation and The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""QDQBERT model configuration"""
|
| 16 |
+
|
| 17 |
+
from ....configuration_utils import PretrainedConfig
|
| 18 |
+
from ....utils import logging
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
logger = logging.get_logger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class QDQBertConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`QDQBertModel`]. It is used to instantiate a
    QDQBERT model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the BERT
    [google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 30522):
            Vocabulary size of the QDQBERT model. Defines the number of different tokens that can be represented by the
            `input_ids` passed when calling [`QDQBertModel`].
        hidden_size (`int`, *optional*, defaults to 768):
            Dimension of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (`int`, *optional*, defaults to 512):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        type_vocab_size (`int`, *optional*, defaults to 2):
            The vocabulary size of the `token_type_ids` passed when calling [`QDQBertModel`].
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        is_decoder (`bool`, *optional*, defaults to `False`):
            Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 1):
            Padding token id, forwarded to [`PretrainedConfig`].
        bos_token_id (`int`, *optional*, defaults to 0):
            Beginning-of-sequence token id, forwarded to [`PretrainedConfig`].
        eos_token_id (`int`, *optional*, defaults to 2):
            End-of-sequence token id, forwarded to [`PretrainedConfig`].

    Examples:

    ```python
    >>> from transformers import QDQBertModel, QDQBertConfig

    >>> # Initializing a QDQBERT google-bert/bert-base-uncased style configuration
    >>> configuration = QDQBertConfig()

    >>> # Initializing a model from the google-bert/bert-base-uncased style configuration
    >>> model = QDQBertModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "qdqbert"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        use_cache=True,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        **kwargs,
    ):
        # Special-token ids and any extra keyword arguments are handled by the base class.
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

        # Architecture hyper-parameters are stored verbatim as attributes.
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.type_vocab_size = type_vocab_size
        self.layer_norm_eps = layer_norm_eps
        self.use_cache = use_cache
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
# Names exported by a star-import of this module.
__all__ = ["QDQBertConfig"]
|
docs/transformers/build/lib/transformers/models/deprecated/qdqbert/modeling_qdqbert.py
ADDED
|
@@ -0,0 +1,1749 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2021 NVIDIA Corporation and The HuggingFace Team.
|
| 3 |
+
# Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""PyTorch QDQBERT model."""
|
| 17 |
+
|
| 18 |
+
import math
|
| 19 |
+
import os
|
| 20 |
+
import warnings
|
| 21 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
import torch.utils.checkpoint
|
| 25 |
+
from torch import nn
|
| 26 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 27 |
+
|
| 28 |
+
from ....activations import ACT2FN
|
| 29 |
+
from ....modeling_outputs import (
|
| 30 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
| 31 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
| 32 |
+
CausalLMOutputWithCrossAttentions,
|
| 33 |
+
MaskedLMOutput,
|
| 34 |
+
MultipleChoiceModelOutput,
|
| 35 |
+
NextSentencePredictorOutput,
|
| 36 |
+
QuestionAnsweringModelOutput,
|
| 37 |
+
SequenceClassifierOutput,
|
| 38 |
+
TokenClassifierOutput,
|
| 39 |
+
)
|
| 40 |
+
from ....modeling_utils import PreTrainedModel
|
| 41 |
+
from ....pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
| 42 |
+
from ....utils import (
|
| 43 |
+
add_code_sample_docstrings,
|
| 44 |
+
add_start_docstrings,
|
| 45 |
+
add_start_docstrings_to_model_forward,
|
| 46 |
+
is_pytorch_quantization_available,
|
| 47 |
+
logging,
|
| 48 |
+
replace_return_docstrings,
|
| 49 |
+
requires_backends,
|
| 50 |
+
)
|
| 51 |
+
from .configuration_qdqbert import QDQBertConfig
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
logger = logging.get_logger(__name__)
|
| 55 |
+
|
| 56 |
+
# soft dependency
|
| 57 |
+
if is_pytorch_quantization_available():
|
| 58 |
+
try:
|
| 59 |
+
from pytorch_quantization import nn as quant_nn
|
| 60 |
+
from pytorch_quantization.nn.modules.tensor_quantizer import TensorQuantizer
|
| 61 |
+
except OSError:
|
| 62 |
+
logger.error(
|
| 63 |
+
"QDQBERT model are not usable since `pytorch_quantization` can't be loaded. Please try to reinstall it"
|
| 64 |
+
" following the instructions here:"
|
| 65 |
+
" https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization."
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
_CHECKPOINT_FOR_DOC = "google-bert/bert-base-uncased"
|
| 69 |
+
_CONFIG_FOR_DOC = "QDQBertConfig"
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def load_tf_weights_in_qdqbert(model, tf_checkpoint_path):
    """Load the variables of a TensorFlow checkpoint into a PyTorch model.

    Each TensorFlow variable name is split on ``/`` and the resulting scope names are used to walk the
    attribute tree of ``model``; the array found in the checkpoint then replaces the ``.data`` of the
    object the walk ends on.

    Args:
        model: PyTorch module that receives the weights.
        tf_checkpoint_path: Path of the TensorFlow checkpoint (made absolute before use).

    Returns:
        The same ``model`` object.

    Raises:
        ImportError: If ``numpy`` or ``tensorflow`` cannot be imported (logged, then re-raised).
        ValueError: If a checkpoint array and the object it maps to have different shapes.
    """
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
    # which are not required for using pretrained model
    optimizer_scopes = {"adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"}

    for name, array in zip(names, arrays):
        name = name.split("/")
        if not optimizer_scopes.isdisjoint(name):
            logger.info(f"Skipping {'/'.join(name)}")
            continue
        pointer = model
        for m_name in name:
            # "layer_3" -> ["layer", "3", ""]: attribute name followed by an index into it
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
            scope = scope_names[0]
            if scope in ("kernel", "gamma", "output_weights"):
                pointer = getattr(pointer, "weight")
            elif scope in ("output_bias", "beta"):
                pointer = getattr(pointer, "bias")
            elif scope == "squad":
                pointer = getattr(pointer, "classifier")
            else:
                try:
                    pointer = getattr(pointer, scope)
                except AttributeError:
                    # Only this scope component is skipped: the walk carries on with the next
                    # component of the name, starting from the same pointer.
                    logger.info(f"Skipping {'/'.join(name)}")
                    continue
            if len(scope_names) >= 2:
                pointer = pointer[int(scope_names[1])]
        # `m_name` is the last component of the variable name at this point
        if m_name.endswith("_embeddings"):
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            array = np.transpose(array)
        if pointer.shape != array.shape:
            raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
        logger.info(f"Initialize PyTorch weight {name}")
        pointer.data = torch.from_numpy(array)
    return model
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class QDQBertEmbeddings(nn.Module):
    """Embedding layer summing word, token-type and (for absolute positions) position embeddings.

    The sum goes through ``LayerNorm`` and then dropout.
    """

    def __init__(self, config):
        super().__init__()
        hidden_size = config.hidden_size
        self.word_embeddings = nn.Embedding(config.vocab_size, hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, hidden_size)

        # The attribute is called `LayerNorm` (not snake-cased) so that it matches the TensorFlow
        # variable name and TensorFlow checkpoints can be loaded.
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

        # Both buffers have shape (1, max_position_embeddings) and are registered as non-persistent.
        all_positions = torch.arange(config.max_position_embeddings).expand((1, -1))
        self.register_buffer("position_ids", all_positions, persistent=False)
        default_token_types = torch.zeros(self.position_ids.size(), dtype=torch.long)
        self.register_buffer("token_type_ids", default_token_types, persistent=False)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        """Return the embeddings for ``input_ids`` (or for precomputed ``inputs_embeds``).

        When ``position_ids`` is omitted, a slice of the ``position_ids`` buffer starting at
        ``past_key_values_length`` is used. When ``token_type_ids`` is omitted, the all-zero buffer
        (or a fresh zero tensor if the buffer is missing) is used.
        """
        input_shape = input_ids.size() if input_ids is not None else inputs_embeds.size()[:-1]
        seq_length = input_shape[1]

        if position_ids is None:
            first = past_key_values_length
            position_ids = self.position_ids[:, first : seq_length + first]

        # Falling back to the registered all-zero buffer lets users trace the model without passing
        # token_type_ids (issue #5664).
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                token_type_ids = self.token_type_ids[:, :seq_length].expand(input_shape[0], seq_length)
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds + self.token_type_embeddings(token_type_ids)
        if self.position_embedding_type == "absolute":
            embeddings += self.position_embeddings(position_ids)
        return self.dropout(self.LayerNorm(embeddings))
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class QDQBertSelfAttention(nn.Module):
    """Multi-head attention whose projections are ``quant_nn.QuantLinear`` layers.

    Both operands of the two matrix products (query x key and probabilities x value) are passed
    through their own ``TensorQuantizer`` before the product is taken. The module serves as
    self-attention, or as cross-attention when ``encoder_hidden_states`` is given to ``forward``.
    """

    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = quant_nn.QuantLinear(config.hidden_size, self.all_head_size)
        self.key = quant_nn.QuantLinear(config.hidden_size, self.all_head_size)
        self.value = quant_nn.QuantLinear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            # One embedding per signed distance in [-(max - 1), max - 1]
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

        self.matmul_q_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input)
        self.matmul_k_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input)
        self.matmul_v_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input)
        self.matmul_a_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input)

    def transpose_for_scores(self, x):
        """Split the last axis of ``x`` into ``(heads, head_size)`` and swap axes 1 and 2.

        The 4-index ``permute`` means ``x`` has to be 3-D on entry.
        """
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Compute attention over ``hidden_states``.

        Args:
            hidden_states: Input the queries are projected from; 3-D (see ``transpose_for_scores``).
            attention_mask: Optional tensor added to the scaled attention scores.
            head_mask: Optional tensor multiplied into the attention probabilities.
            encoder_hidden_states: If given, keys and values are projected from it (cross-attention)
                and ``encoder_attention_mask`` replaces ``attention_mask``.
            encoder_attention_mask: Mask used in the cross-attention case.
            past_key_value: Optional pair ``(key, value)``. In cross-attention it is reused as is;
                otherwise the new keys/values are concatenated to it along axis 2.
            output_attentions: Whether to include the attention probabilities in the result.

        Returns:
            Tuple ``(context_layer[, attention_probs][, (key_layer, value_layer)])``; the last element
            is present only when ``self.is_decoder`` is true.
        """
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        if self.is_decoder:
            # Cross-attention: the saved pair is reused unchanged by later calls (first branch above).
            # Self-attention: the saved pair holds every key/value so far, and later calls concatenate
            # their new keys/values to it (third branch above).
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(
            self.matmul_q_input_quantizer(query_layer), self.matmul_k_input_quantizer(key_layer.transpose(-1, -2))
        )

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            # Query and key lengths differ when cached keys are prepended or in cross-attention, so
            # they are read from the projected layers rather than from `hidden_states`.
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            # With a self-attention cache the queries are the last `query_length` key positions.
            query_offset = 0 if is_cross_attention else key_length - query_length
            position_ids_l = torch.arange(
                query_offset, query_offset + query_length, dtype=torch.long, device=hidden_states.device
            ).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # The mask is additive
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(
            self.matmul_a_input_quantizer(attention_probs), self.matmul_v_input_quantizer(value_layer)
        )

        # Merge the head axis back into the feature axis
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
class QDQBertSelfOutput(nn.Module):
    """Output projection of the attention block followed by a residual add and ``LayerNorm``.

    Both operands of the residual add pass through their own ``TensorQuantizer``.
    """

    def __init__(self, config):
        super().__init__()
        hidden_size = config.hidden_size
        # Quantize Linear layer
        self.dense = quant_nn.QuantLinear(hidden_size, hidden_size)

        self.LayerNorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Quantize the inputs to the residual add
        input_desc = quant_nn.QuantLinear.default_quant_desc_input
        self.add_local_input_quantizer = TensorQuantizer(input_desc)
        self.add_residual_input_quantizer = TensorQuantizer(input_desc)

    def forward(self, hidden_states, input_tensor):
        """Project ``hidden_states``, apply dropout, add ``input_tensor`` and normalise the sum."""
        projected = self.dropout(self.dense(hidden_states))
        # Quantize the inputs to the residual add
        local_term = self.add_local_input_quantizer(projected)
        residual_term = self.add_residual_input_quantizer(input_tensor)
        return self.LayerNorm(local_term + residual_term)
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
# Based on transformers.models.bert.modeling_bert.BertAttention with Bert -> QDQBert
|
| 367 |
+
class QDQBertAttention(nn.Module):
    """Pairs a ``QDQBertSelfAttention`` (``self.self``) with its ``QDQBertSelfOutput`` (``self.output``)."""

    def __init__(self, config):
        super().__init__()
        self.self = QDQBertSelfAttention(config)
        self.output = QDQBertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads from the projections and record them in ``pruned_heads``."""
        if len(heads) == 0:
            return
        attention = self.self
        heads, index = find_pruneable_heads_and_indices(
            heads, attention.num_attention_heads, attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        for projection in ("query", "key", "value"):
            setattr(attention, projection, prune_linear_layer(getattr(attention, projection), index))
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        attention.num_attention_heads = attention.num_attention_heads - len(heads)
        attention.all_head_size = attention.attention_head_size * attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Run attention, then the output sub-layer with ``hidden_states`` as the residual input.

        Returns the output sub-layer result followed by every extra element the attention returned.
        """
        attention_results = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        context, extras = attention_results[0], attention_results[1:]
        # `extras` carries the attentions if we output them
        return (self.output(context, hidden_states),) + extras
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
class QDQBertIntermediate(nn.Module):
    """Feed-forward expansion: a ``quant_nn.QuantLinear`` to ``intermediate_size`` plus an activation."""

    def __init__(self, config):
        super().__init__()
        # Quantize Linear layer
        self.dense = quant_nn.QuantLinear(config.hidden_size, config.intermediate_size)
        # `hidden_act` is either a key of ACT2FN or the activation callable itself
        activation = config.hidden_act
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states):
        """Return ``intermediate_act_fn(dense(hidden_states))``."""
        return self.intermediate_act_fn(self.dense(hidden_states))
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
class QDQBertOutput(nn.Module):
    """Feed-forward contraction back to ``hidden_size``, followed by a residual add and ``LayerNorm``.

    Both operands of the residual add pass through their own ``TensorQuantizer``.
    """

    def __init__(self, config):
        super().__init__()
        # Quantize Linear layer
        self.dense = quant_nn.QuantLinear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Quantize the inputs to the residual add
        input_desc = quant_nn.QuantLinear.default_quant_desc_input
        self.add_local_input_quantizer = TensorQuantizer(input_desc)
        self.add_residual_input_quantizer = TensorQuantizer(input_desc)

    def forward(self, hidden_states, input_tensor):
        """Project ``hidden_states``, apply dropout, add ``input_tensor`` and normalise the sum."""
        contracted = self.dense(hidden_states)
        contracted = self.dropout(contracted)
        # Quantize the inputs to the residual add
        quantized_local = self.add_local_input_quantizer(contracted)
        quantized_residual = self.add_residual_input_quantizer(input_tensor)
        return self.LayerNorm(quantized_local + quantized_residual)
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
# Based on transformers.models.bert.modeling_bert.BertLayer with Bert -> QDQBert
|
| 455 |
+
class QDQBertLayer(nn.Module):
    """One transformer layer: self-attention, optional cross-attention, then the feed-forward block."""

    def __init__(self, config):
        super().__init__()
        self.seq_len_dim = 1
        self.attention = QDQBertAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            self.crossattention = QDQBertAttention(config)
        self.intermediate = QDQBertIntermediate(config)
        self.output = QDQBertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Apply the layer.

        Returns a tuple starting with the layer output, followed by whatever extra elements the
        attention modules returned, and (only when ``self.is_decoder``) ending with the present
        key/value tuple.

        Raises:
            ValueError: If ``encoder_hidden_states`` is passed to a decoder layer that has no
                ``crossattention`` module.
        """
        has_cache = past_key_value is not None

        # the first two entries of the cache belong to the decoder's uni-directional self-attention
        self_attn_past_key_value = past_key_value[:2] if has_cache else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]

        if self.is_decoder:
            # if decoder, the last output is tuple of self-attn cache
            extras = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            # add self attentions if we output attention weights
            extras = self_attention_outputs[1:]

        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )

            # the last two entries of the cache belong to the cross-attention
            cross_attn_past_key_value = past_key_value[-2:] if has_cache else None
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            # add cross attentions if we output attention weights
            extras = extras + cross_attention_outputs[1:-1]

            # append the cross-attn cache after the self-attn cache
            present_key_value = present_key_value + cross_attention_outputs[-1]

        result = (self.feed_forward_chunk(attention_output),) + extras

        # if decoder, return the attn key/values as the last output
        if self.is_decoder:
            result = result + (present_key_value,)

        return result

    def feed_forward_chunk(self, attention_output):
        """Run the intermediate and output sub-layers, with ``attention_output`` as the residual input."""
        return self.output(self.intermediate(attention_output), attention_output)
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
# Based on transformers.models.bert.modeling_bert.BertEncoder with Bert -> QDQBert
|
| 539 |
+
class QDQBertEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``QDQBertLayer`` modules applied one after another."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([QDQBertLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        """Run every layer in order, optionally collecting hidden states, attentions and caches.

        Returns a ``BaseModelOutputWithPastAndCrossAttentions`` when ``return_dict`` is true,
        otherwise a tuple holding only the collected values that are not ``None``.
        """
        collect_cross = output_attentions and self.config.add_cross_attention
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if collect_cross else None
        next_decoder_cache = () if use_cache else None

        for index, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_args = (
                hidden_states,
                attention_mask,
                head_mask[index] if head_mask is not None else None,
                encoder_hidden_states,
                encoder_attention_mask,
                past_key_values[index] if past_key_values is not None else None,
                output_attentions,
            )

            if self.gradient_checkpointing and self.training:
                if use_cache:
                    logger.warning_once(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False
                layer_outputs = self._gradient_checkpointing_func(layer_module.__call__, *layer_args)
            else:
                layer_outputs = layer_module(*layer_args)

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            # the output of the last layer
            all_hidden_states = all_hidden_states + (hidden_states,)

        if return_dict:
            return BaseModelOutputWithPastAndCrossAttentions(
                last_hidden_state=hidden_states,
                past_key_values=next_decoder_cache,
                hidden_states=all_hidden_states,
                attentions=all_self_attentions,
                cross_attentions=all_cross_attentions,
            )

        collected = [
            hidden_states,
            next_decoder_cache,
            all_hidden_states,
            all_self_attentions,
            all_cross_attentions,
        ]
        return tuple(item for item in collected if item is not None)
|
| 628 |
+
|
| 629 |
+
|
| 630 |
+
class QDQBertPooler(nn.Module):
    """Pool a sequence by passing the hidden state at index 0 through a linear layer and ``tanh``."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``activation(dense(hidden_states[:, 0]))``."""
        # "Pooling" here simply means picking the hidden state of the first token.
        return self.activation(self.dense(hidden_states[:, 0]))
|
| 643 |
+
|
| 644 |
+
|
| 645 |
+
class QDQBertPredictionHeadTransform(nn.Module):
    """Linear layer, activation and ``LayerNorm`` applied in front of the LM decoder."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # `hidden_act` is either a key of ACT2FN or the activation callable itself
        activation = config.hidden_act
        self.transform_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``LayerNorm(transform_act_fn(dense(hidden_states)))``."""
        activated = self.transform_act_fn(self.dense(hidden_states))
        return self.LayerNorm(activated)
|
| 660 |
+
|
| 661 |
+
|
| 662 |
+
# Based on transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert -> QDQBert
|
| 663 |
+
class QDQBertLMPredictionHead(nn.Module):
    """Language-modeling head: the prediction-head transform followed by a vocabulary-sized decoder."""

    def __init__(self, config):
        super().__init__()
        self.transform = QDQBertPredictionHeadTransform(config)

        # The decoder is created without a bias of its own (its weights are the same as the input
        # embeddings); the per-token, output-only bias is the separate parameter below.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Link the two so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def _tie_weights(self):
        """Point ``decoder.bias`` at ``self.bias`` again."""
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        """Return ``decoder(transform(hidden_states))``."""
        return self.decoder(self.transform(hidden_states))
|
| 684 |
+
|
| 685 |
+
|
| 686 |
+
# Based on transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert -> QDQBert
|
| 687 |
+
class QDQBertOnlyMLMHead(nn.Module):
    """Wrapper holding a single ``QDQBertLMPredictionHead`` under the attribute ``predictions``."""

    def __init__(self, config):
        super().__init__()
        self.predictions = QDQBertLMPredictionHead(config)

    def forward(self, sequence_output):
        """Return the output of ``self.predictions`` for ``sequence_output``."""
        return self.predictions(sequence_output)
|
| 695 |
+
|
| 696 |
+
|
| 697 |
+
class QDQBertOnlyNSPHead(nn.Module):
    """Head mapping a pooled output to two scores through one linear layer."""

    def __init__(self, config):
        super().__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        """Return ``seq_relationship(pooled_output)``."""
        return self.seq_relationship(pooled_output)
|
| 705 |
+
|
| 706 |
+
|
| 707 |
+
# Based on transformers.models.bert.modeling_bert.BertPreTrainingHeads with Bert -> QDQBert
|
| 708 |
+
class QDQBertPreTrainingHeads(nn.Module):
    """Holds both pre-training heads: the LM prediction head and a two-way linear classifier."""

    def __init__(self, config):
        super().__init__()
        self.predictions = QDQBertLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        """Return ``(predictions(sequence_output), seq_relationship(pooled_output))``."""
        token_scores = self.predictions(sequence_output)
        pair_scores = self.seq_relationship(pooled_output)
        return token_scores, pair_scores
|
| 718 |
+
|
| 719 |
+
|
| 720 |
+
# Based on transformers.models.bert.modeling_bert.BertPreTrainedModel with Bert -> QDQBert
|
| 721 |
+
class QDQBertPreTrainedModel(PreTrainedModel):
    """
    Base class of the QDQBERT models: declares the configuration class, the TensorFlow weight loader and the
    base-model prefix, and provides the weight initialization.
    """

    config_class = QDQBertConfig
    load_tf_weights = load_tf_weights_in_qdqbert
    base_model_prefix = "bert"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights of one ``nn.Linear``, ``nn.Embedding`` or ``nn.LayerNorm`` module.

        Any other module type is left untouched.
        """
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            padding_idx = module.padding_idx
            if padding_idx is not None:
                # the row of the padding token is zeroed
                module.weight.data[padding_idx].zero_()
            return

        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
|
| 747 |
+
|
| 748 |
+
|
| 749 |
+
QDQBERT_START_DOCSTRING = r"""
|
| 750 |
+
|
| 751 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 752 |
+
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 753 |
+
etc.)
|
| 754 |
+
|
| 755 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 756 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
|
| 757 |
+
and behavior.
|
| 758 |
+
|
| 759 |
+
Parameters:
|
| 760 |
+
config ([`QDQBertConfig`]): Model configuration class with all the parameters of the model.
|
| 761 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 762 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 763 |
+
"""
|
| 764 |
+
|
| 765 |
+
QDQBERT_INPUTS_DOCSTRING = r"""
|
| 766 |
+
Args:
|
| 767 |
+
input_ids (`torch.LongTensor` of shape `({0})`):
|
| 768 |
+
Indices of input sequence tokens in the vocabulary.
|
| 769 |
+
|
| 770 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 771 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 772 |
+
|
| 773 |
+
[What are input IDs?](../glossary#input-ids)
|
| 774 |
+
attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
|
| 775 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 776 |
+
|
| 777 |
+
- 1 for tokens that are **not masked**,
|
| 778 |
+
- 0 for tokens that are **masked**.
|
| 779 |
+
|
| 780 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 781 |
+
token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
|
| 782 |
+
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
|
| 783 |
+
1]`:
|
| 784 |
+
|
| 785 |
+
- 0 corresponds to a *sentence A* token,
|
| 786 |
+
- 1 corresponds to a *sentence B* token.
|
| 787 |
+
|
| 788 |
+
[What are token type IDs?](../glossary#token-type-ids)
|
| 789 |
+
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
|
| 790 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 791 |
+
config.max_position_embeddings - 1]`.
|
| 792 |
+
|
| 793 |
+
[What are position IDs?](../glossary#position-ids)
|
| 794 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
| 795 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
| 796 |
+
|
| 797 |
+
- 1 indicates the head is **not masked**,
|
| 798 |
+
- 0 indicates the head is **masked**.
|
| 799 |
+
|
| 800 |
+
inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
|
| 801 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 802 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
| 803 |
+
model's internal embedding lookup matrix.
|
| 804 |
+
output_attentions (`bool`, *optional*):
|
| 805 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 806 |
+
tensors for more detail.
|
| 807 |
+
output_hidden_states (`bool`, *optional*):
|
| 808 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 809 |
+
more detail.
|
| 810 |
+
return_dict (`bool`, *optional*):
|
| 811 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 812 |
+
"""
|
| 813 |
+
|
| 814 |
+
|
| 815 |
+
@add_start_docstrings(
    "The bare QDQBERT Model transformer outputting raw hidden-states without any specific head on top.",
    QDQBERT_START_DOCSTRING,
)
class QDQBertModel(QDQBertPreTrainedModel):
    """

    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in [Attention is
    all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
    to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both `is_decoder` argument and
    `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
    """

    def __init__(self, config, add_pooling_layer: bool = True):
        requires_backends(self, "pytorch_quantization")
        super().__init__(config)
        self.config = config

        self.embeddings = QDQBertEmbeddings(config)
        self.encoder = QDQBertEncoder(config)

        # The pooler is optional; heads that only consume the sequence output disable it.
        if add_pooling_layer:
            self.pooler = QDQBertPooler(config)
        else:
            self.pooler = None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the word-embedding module held by `self.embeddings`."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word-embedding module held by `self.embeddings` with `value`."""
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer_index, head_indices in heads_to_prune.items():
            self.encoder.layer[layer_index].attention.prune_heads(head_indices)

    @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPoolingAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
        r"""
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        """
        # Fall back to the configuration for every output flag the caller left unset.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Caching is only honoured when the model is configured as a decoder.
        if not self.config.is_decoder:
            use_cache = False
        elif use_cache is None:
            use_cache = self.config.use_cache

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            device = input_ids.device
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            device = inputs_embeds.device
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        batch_size, seq_length = input_shape

        # past_key_values_length
        past_key_values_length = 0 if past_key_values is None else past_key_values[0][0].shape[2]

        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)

        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                # Reuse the buffer registered on the embeddings, cut to the current length.
                token_type_ids = self.embeddings.token_type_ids[:, :seq_length].expand(batch_size, seq_length)
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        encoder_extended_attention_mask = None
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones((encoder_batch_size, encoder_sequence_length), device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = None if self.pooler is None else self.pooler(sequence_output)

        if return_dict:
            return BaseModelOutputWithPoolingAndCrossAttentions(
                last_hidden_state=sequence_output,
                pooler_output=pooled_output,
                past_key_values=encoder_outputs.past_key_values,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
                cross_attentions=encoder_outputs.cross_attentions,
            )

        return (sequence_output, pooled_output) + encoder_outputs[1:]
|
| 995 |
+
|
| 996 |
+
|
| 997 |
+
@add_start_docstrings(
    """QDQBERT Model with a `language modeling` head on top for CLM fine-tuning.""", QDQBERT_START_DOCSTRING
)
class QDQBertLMHeadModel(QDQBertPreTrainedModel):
    _tied_weights_keys = ["predictions.decoder.weight", "predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        if not config.is_decoder:
            logger.warning("If you want to use `QDQBertLMHeadModel` as a standalone, add `is_decoder=True.`")

        self.bert = QDQBertModel(config, add_pooling_layer=False)
        self.cls = QDQBertOnlyMLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        """Return the `decoder` module of the prediction head."""
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        """Install `new_embeddings` as the head's decoder and point the head's `bias` at its bias."""
        prediction_head = self.cls.predictions
        prediction_head.decoder = new_embeddings
        prediction_head.bias = new_embeddings.bias

    @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.LongTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        r"""
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
            ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, QDQBertLMHeadModel, QDQBertConfig
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")
        >>> config = QDQBertConfig.from_pretrained("google-bert/bert-base-cased")
        >>> config.is_decoder = True
        >>> model = QDQBertLMHeadModel.from_pretrained("google-bert/bert-base-cased", config=config)

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> prediction_logits = outputs.logits
        ```"""
        if return_dict is None:
            return_dict = self.config.use_return_dict
        # Whenever labels are supplied the cache is switched off.
        if labels is not None:
            use_cache = False

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        prediction_scores = self.cls(outputs[0])

        lm_loss = None
        if labels is not None:
            # we are doing next-token prediction; shift prediction scores and input ids by one
            shifted_scores = prediction_scores[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            criterion = CrossEntropyLoss()
            lm_loss = criterion(shifted_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if return_dict:
            return CausalLMOutputWithCrossAttentions(
                loss=lm_loss,
                logits=prediction_scores,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
                cross_attentions=outputs.cross_attentions,
            )

        # Tuple output: the element at index 1 of the base-model output is skipped.
        output = (prediction_scores,) + outputs[2:]
        if lm_loss is None:
            return output
        return (lm_loss,) + output

    def prepare_inputs_for_generation(
        self,
        input_ids: Optional[torch.LongTensor],
        past_key_values=None,
        attention_mask: Optional[torch.Tensor] = None,
        **model_kwargs,
    ):
        """Build the `input_ids` / `attention_mask` / `past_key_values` dict for one generation step."""
        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)

        # cut decoder_input_ids if past_key_values is used
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]
            current_length = input_ids.shape[1]

            # Some generation methods already pass only the last input ID;
            # otherwise default to the old behavior and keep only the final ID.
            remove_prefix_length = past_length if current_length > past_length else current_length - 1
            input_ids = input_ids[:, remove_prefix_length:]

        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

    def _reorder_cache(self, past_key_values, beam_idx):
        """Return the cache with every state re-indexed along dim 0 by `beam_idx`."""
        return tuple(
            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
            for layer_past in past_key_values
        )
|
| 1161 |
+
|
| 1162 |
+
|
| 1163 |
+
@add_start_docstrings("""QDQBERT Model with a `language modeling` head on top.""", QDQBERT_START_DOCSTRING)
class QDQBertForMaskedLM(QDQBertPreTrainedModel):
    _tied_weights_keys = ["predictions.decoder.weight", "predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        if config.is_decoder:
            logger.warning(
                "If you want to use `QDQBertForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        self.bert = QDQBertModel(config, add_pooling_layer=False)
        self.cls = QDQBertOnlyMLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        """Return the `decoder` module of the prediction head."""
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        """Install `new_embeddings` as the head's decoder and point the head's `bias` at its bias."""
        self.cls.predictions.decoder = new_embeddings
        self.cls.predictions.bias = new_embeddings.bias

    @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        masked_lm_loss = None
        if labels is not None:
            # CrossEntropyLoss skips targets equal to its default `ignore_index` of -100.
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            # Tuple output: the element at index 1 of the base-model output is skipped.
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids: torch.LongTensor, attention_mask: Optional[torch.FloatTensor] = None, **model_kwargs
    ):
        """
        Append one PAD token to `input_ids` and one zero column to `attention_mask`.

        Args:
            input_ids: Token ids; a column filled with `config.pad_token_id` is concatenated along dim 1.
            attention_mask: Optional mask. When omitted, an all-ones mask with the shape of `input_ids` is used
                before the zero column is appended.
            **model_kwargs: Accepted for signature compatibility and not used here.

        Returns:
            A dict with the extended `input_ids` and `attention_mask`.

        Raises:
            ValueError: If `config.pad_token_id` is `None`.
        """
        input_shape = input_ids.shape
        effective_batch_size = input_shape[0]

        # add a dummy token
        if self.config.pad_token_id is None:
            raise ValueError("The PAD token should be defined for generation")

        # `attention_mask` defaults to None in the signature, so build a default mask instead of
        # failing on `None.new_zeros`.
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
        dummy_token = torch.full(
            (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
        )
        input_ids = torch.cat([input_ids, dummy_token], dim=1)

        return {"input_ids": input_ids, "attention_mask": attention_mask}
|
| 1269 |
+
|
| 1270 |
+
|
| 1271 |
+
@add_start_docstrings(
    """Bert Model with a `next sentence prediction (classification)` head on top.""",
    QDQBERT_START_DOCSTRING,
)
class QDQBertForNextSentencePrediction(QDQBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.bert = QDQBertModel(config)
        self.cls = QDQBertOnlyNSPHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, NextSentencePredictorOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
            (see `input_ids` docstring). Indices should be in `[0, 1]`:

            - 0 indicates sequence B is a continuation of sequence A,
            - 1 indicates sequence B is a random sequence.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, QDQBertForNextSentencePrediction
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
        >>> model = QDQBertForNextSentencePrediction.from_pretrained("google-bert/bert-base-uncased")

        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
        >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
        >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")

        >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
        >>> logits = outputs.logits
        >>> assert logits[0, 0] < logits[0, 1]  # next sentence was random
        ```"""

        # The deprecated keyword, when present, replaces whatever was passed as `labels`.
        if "next_sentence_label" in kwargs:
            warnings.warn(
                "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
                " `labels` instead.",
                FutureWarning,
            )
            labels = kwargs.pop("next_sentence_label")

        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 1 of the base-model output is the pooled output consumed by the head.
        seq_relationship_scores = self.cls(outputs[1])

        next_sentence_loss = None
        if labels is not None:
            criterion = CrossEntropyLoss()
            next_sentence_loss = criterion(seq_relationship_scores.view(-1, 2), labels.view(-1))

        if return_dict:
            return NextSentencePredictorOutput(
                loss=next_sentence_loss,
                logits=seq_relationship_scores,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        output = (seq_relationship_scores,) + outputs[2:]
        if next_sentence_loss is None:
            return output
        return (next_sentence_loss,) + output
|
| 1370 |
+
|
| 1371 |
+
|
| 1372 |
+
@add_start_docstrings(
    """
    Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    QDQBERT_START_DOCSTRING,
)
class QDQBertForSequenceClassification(QDQBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        # Backbone, then dropout + a linear head applied to the backbone's pooled output.
        self.bert = QDQBertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), if
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Element 1 of the backbone outputs is the pooled representation fed to the head.
        pooled = self.dropout(encoder_outputs[1])
        logits = self.classifier(pooled)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                # No explicit problem type: infer it from the head width and the label dtype,
                # and store the result on the config so later calls reuse it.
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            problem_type = self.config.problem_type
            if problem_type == "regression":
                criterion = MSELoss()
                if self.num_labels == 1:
                    # Single-output regression: compare squeezed predictions and targets.
                    loss = criterion(logits.squeeze(), labels.squeeze())
                else:
                    loss = criterion(logits, labels)
            elif problem_type == "single_label_classification":
                criterion = CrossEntropyLoss()
                loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                criterion = BCEWithLogitsLoss()
                loss = criterion(logits, labels)

        if return_dict:
            return SequenceClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        # Tuple output: logits followed by whatever the backbone returned after its first two elements,
        # with the loss prepended only when labels were supplied.
        tuple_output = (logits,) + encoder_outputs[2:]
        return tuple_output if loss is None else (loss,) + tuple_output
|
| 1467 |
+
|
| 1468 |
+
|
| 1469 |
+
@add_start_docstrings(
    """
    Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    QDQBERT_START_DOCSTRING,
)
class QDQBertForMultipleChoice(QDQBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # Backbone, then dropout + a single-logit linear head scoring each choice.
        self.bert = QDQBertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MultipleChoiceModelOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # The choice count is read from dimension 1 of whichever input form was provided.
        num_choices = (inputs_embeds if input_ids is None else input_ids).shape[1]

        def _fold_choices(tensor):
            # Collapse every leading dimension into one, keeping the last dimension; None passes through.
            return None if tensor is None else tensor.view(-1, tensor.size(-1))

        input_ids = _fold_choices(input_ids)
        attention_mask = _fold_choices(attention_mask)
        token_type_ids = _fold_choices(token_type_ids)
        position_ids = _fold_choices(position_ids)
        if inputs_embeds is not None:
            # Embeddings keep their last two dimensions; only the leading ones are merged.
            inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))

        encoder_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Element 1 of the backbone outputs is the pooled representation fed to the head.
        pooled = self.dropout(encoder_outputs[1])
        choice_scores = self.classifier(pooled)
        # Regroup the per-row scores so each row holds the scores of one example's choices.
        reshaped_logits = choice_scores.view(-1, num_choices)

        loss = None
        if labels is not None:
            criterion = CrossEntropyLoss()
            loss = criterion(reshaped_logits, labels)

        if return_dict:
            return MultipleChoiceModelOutput(
                loss=loss,
                logits=reshaped_logits,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        # Tuple output: logits plus the backbone extras, with the loss prepended only when labels were supplied.
        tuple_output = (reshaped_logits,) + encoder_outputs[2:]
        return tuple_output if loss is None else (loss,) + tuple_output
|
| 1558 |
+
|
| 1559 |
+
|
| 1560 |
+
@add_start_docstrings(
    """
    QDQBERT Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    QDQBERT_START_DOCSTRING,
)
class QDQBertForTokenClassification(QDQBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # The head works on per-token hidden states, so the backbone is built without its pooling layer.
        self.bert = QDQBertModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Element 0 of the backbone outputs is the per-token sequence output.
        token_states = self.dropout(encoder_outputs[0])
        logits = self.classifier(token_states)

        loss = None
        if labels is not None:
            criterion = CrossEntropyLoss()
            # Flatten logits to rows of `num_labels` scores and labels to a 1-D vector before scoring.
            loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))

        if return_dict:
            return TokenClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        # Tuple output: logits plus the backbone extras, with the loss prepended only when labels were supplied.
        tuple_output = (logits,) + encoder_outputs[2:]
        return tuple_output if loss is None else (loss,) + tuple_output
|
| 1636 |
+
|
| 1637 |
+
|
| 1638 |
+
@add_start_docstrings(
    """
    QDQBERT Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    QDQBERT_START_DOCSTRING,
)
class QDQBertForQuestionAnswering(QDQBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # The head works on per-token hidden states, so the backbone is built without its pooling layer.
        self.bert = QDQBertModel(config, add_pooling_layer=False)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Element 0 of the backbone outputs is the per-token sequence output.
        span_logits = self.qa_outputs(encoder_outputs[0])
        # Split the last dimension into width-1 slices (start, end) and drop that dimension from each.
        start_logits, end_logits = (
            piece.squeeze(-1).contiguous() for piece in span_logits.split(1, dim=-1)
        )

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # Under multi-GPU splitting the positions can carry an extra trailing dimension; remove it.
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)

            # Positions beyond the model input are clamped onto `ignored_index`, which the loss
            # is configured to skip, so those terms do not contribute.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(min=0, max=ignored_index)
            end_positions = end_positions.clamp(min=0, max=ignored_index)

            criterion = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = criterion(start_logits, start_positions)
            end_loss = criterion(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if return_dict:
            return QuestionAnsweringModelOutput(
                loss=total_loss,
                start_logits=start_logits,
                end_logits=end_logits,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        # Tuple output: both logit tensors plus the backbone extras, with the loss prepended
        # only when both position tensors were supplied.
        tuple_output = (start_logits, end_logits) + encoder_outputs[2:]
        return tuple_output if total_loss is None else (total_loss,) + tuple_output
|
| 1735 |
+
|
| 1736 |
+
|
| 1737 |
+
# Names exported by `from <module> import *`.
__all__ = [
    "QDQBertForMaskedLM",
    "QDQBertForMultipleChoice",
    "QDQBertForNextSentencePrediction",
    "QDQBertForQuestionAnswering",
    "QDQBertForSequenceClassification",
    "QDQBertForTokenClassification",
    "QDQBertLayer",
    "QDQBertLMHeadModel",
    "QDQBertModel",
    "QDQBertPreTrainedModel",
    "load_tf_weights_in_qdqbert",
]
|