Student0809 committed on
Commit
33b613a
·
verified ·
1 Parent(s): fd421e2

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. docs/resources/grpo_clevr_count.png +3 -0
  3. docs/resources/grpo_countdown_1.png +3 -0
  4. docs/transformers/build/lib/transformers/models/deprecated/efficientformer/configuration_efficientformer.py +172 -0
  5. docs/transformers/build/lib/transformers/models/deprecated/efficientformer/image_processing_efficientformer.py +324 -0
  6. docs/transformers/build/lib/transformers/models/deprecated/efficientformer/modeling_efficientformer.py +807 -0
  7. docs/transformers/build/lib/transformers/models/deprecated/efficientformer/modeling_tf_efficientformer.py +1198 -0
  8. docs/transformers/build/lib/transformers/models/deprecated/ernie_m/__init__.py +28 -0
  9. docs/transformers/build/lib/transformers/models/deprecated/ernie_m/configuration_ernie_m.py +114 -0
  10. docs/transformers/build/lib/transformers/models/deprecated/ernie_m/modeling_ernie_m.py +1058 -0
  11. docs/transformers/build/lib/transformers/models/deprecated/ernie_m/tokenization_ernie_m.py +410 -0
  12. docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/__init__.py +28 -0
  13. docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/configuration_gptsan_japanese.py +157 -0
  14. docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/convert_gptsan_tf_checkpoint_to_pytorch.py +181 -0
  15. docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py +1337 -0
  16. docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/tokenization_gptsan_japanese.py +518 -0
  17. docs/transformers/build/lib/transformers/models/deprecated/graphormer/__init__.py +27 -0
  18. docs/transformers/build/lib/transformers/models/deprecated/graphormer/algos_graphormer.pyx +107 -0
  19. docs/transformers/build/lib/transformers/models/deprecated/graphormer/collating_graphormer.py +134 -0
  20. docs/transformers/build/lib/transformers/models/deprecated/graphormer/configuration_graphormer.py +220 -0
  21. docs/transformers/build/lib/transformers/models/deprecated/graphormer/modeling_graphormer.py +911 -0
  22. docs/transformers/build/lib/transformers/models/deprecated/jukebox/__init__.py +28 -0
  23. docs/transformers/build/lib/transformers/models/deprecated/jukebox/configuration_jukebox.py +613 -0
  24. docs/transformers/build/lib/transformers/models/deprecated/jukebox/convert_jukebox.py +279 -0
  25. docs/transformers/build/lib/transformers/models/deprecated/jukebox/modeling_jukebox.py +0 -0
  26. docs/transformers/build/lib/transformers/models/deprecated/jukebox/tokenization_jukebox.py +407 -0
  27. docs/transformers/build/lib/transformers/models/deprecated/mctct/__init__.py +29 -0
  28. docs/transformers/build/lib/transformers/models/deprecated/mctct/configuration_mctct.py +184 -0
  29. docs/transformers/build/lib/transformers/models/deprecated/mctct/feature_extraction_mctct.py +291 -0
  30. docs/transformers/build/lib/transformers/models/deprecated/mctct/modeling_mctct.py +791 -0
  31. docs/transformers/build/lib/transformers/models/deprecated/mctct/processing_mctct.py +146 -0
  32. docs/transformers/build/lib/transformers/models/deprecated/mega/__init__.py +27 -0
  33. docs/transformers/build/lib/transformers/models/deprecated/mega/configuration_mega.py +243 -0
  34. docs/transformers/build/lib/transformers/models/deprecated/mega/convert_mega_original_pytorch_checkpoint_to_pytorch.py +298 -0
  35. docs/transformers/build/lib/transformers/models/deprecated/mega/modeling_mega.py +0 -0
  36. docs/transformers/build/lib/transformers/models/deprecated/mmbt/__init__.py +27 -0
  37. docs/transformers/build/lib/transformers/models/deprecated/mmbt/configuration_mmbt.py +45 -0
  38. docs/transformers/build/lib/transformers/models/deprecated/mmbt/modeling_mmbt.py +410 -0
  39. docs/transformers/build/lib/transformers/models/deprecated/nat/__init__.py +27 -0
  40. docs/transformers/build/lib/transformers/models/deprecated/nat/configuration_nat.py +148 -0
  41. docs/transformers/build/lib/transformers/models/deprecated/nat/modeling_nat.py +953 -0
  42. docs/transformers/build/lib/transformers/models/deprecated/nezha/__init__.py +27 -0
  43. docs/transformers/build/lib/transformers/models/deprecated/nezha/configuration_nezha.py +105 -0
  44. docs/transformers/build/lib/transformers/models/deprecated/nezha/modeling_nezha.py +1697 -0
  45. docs/transformers/build/lib/transformers/models/deprecated/open_llama/__init__.py +27 -0
  46. docs/transformers/build/lib/transformers/models/deprecated/open_llama/configuration_open_llama.py +169 -0
  47. docs/transformers/build/lib/transformers/models/deprecated/open_llama/modeling_open_llama.py +975 -0
  48. docs/transformers/build/lib/transformers/models/deprecated/qdqbert/__init__.py +27 -0
  49. docs/transformers/build/lib/transformers/models/deprecated/qdqbert/configuration_qdqbert.py +123 -0
  50. docs/transformers/build/lib/transformers/models/deprecated/qdqbert/modeling_qdqbert.py +1749 -0
.gitattributes CHANGED
@@ -53,3 +53,5 @@ docs/resources/grpo_geoqa.png filter=lfs diff=lfs merge=lfs -text
53
  docs/resources/grpo_openr1_multimodal.png filter=lfs diff=lfs merge=lfs -text
54
  docs/resources/web-ui-en.jpg filter=lfs diff=lfs merge=lfs -text
55
  docs/resources/kto_data.png filter=lfs diff=lfs merge=lfs -text
 
 
 
53
  docs/resources/grpo_openr1_multimodal.png filter=lfs diff=lfs merge=lfs -text
54
  docs/resources/web-ui-en.jpg filter=lfs diff=lfs merge=lfs -text
55
  docs/resources/kto_data.png filter=lfs diff=lfs merge=lfs -text
56
+ docs/resources/grpo_countdown_1.png filter=lfs diff=lfs merge=lfs -text
57
+ docs/resources/grpo_clevr_count.png filter=lfs diff=lfs merge=lfs -text
docs/resources/grpo_clevr_count.png ADDED

Git LFS Details

  • SHA256: 7192dc4f04801dbdff30bed098a16a7e21212a773ba7b6dc1424b261feca366f
  • Pointer size: 131 Bytes
  • Size of remote file: 671 kB
docs/resources/grpo_countdown_1.png ADDED

Git LFS Details

  • SHA256: b78dc3ce1cd541e76f2c557dea3aff06b278bb3b5413946a92c584cf42c1369f
  • Pointer size: 131 Bytes
  • Size of remote file: 785 kB
docs/transformers/build/lib/transformers/models/deprecated/efficientformer/configuration_efficientformer.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """EfficientFormer model configuration"""
16
+
17
+ from typing import List
18
+
19
+ from ....configuration_utils import PretrainedConfig
20
+ from ....utils import logging
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
+ class EfficientFormerConfig(PretrainedConfig):
27
+ r"""
28
+ This is the configuration class to store the configuration of an [`EfficientFormerModel`]. It is used to
29
+ instantiate an EfficientFormer model according to the specified arguments, defining the model architecture.
30
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the EfficientFormer
31
+ [snap-research/efficientformer-l1](https://huggingface.co/snap-research/efficientformer-l1) architecture.
32
+
33
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
34
+ documentation from [`PretrainedConfig`] for more information.
35
+
36
+ Args:
37
+ depths (`List[int]`, *optional*, defaults to `[3, 2, 6, 4]`):
38
+ Depth of each stage.
39
+ hidden_sizes (`List[int]`, *optional*, defaults to `[48, 96, 224, 448]`):
40
+ Dimensionality of each stage.
41
+ downsamples (`List[bool]`, *optional*, defaults to `[True, True, True, True]`):
42
+ Whether or not to downsample inputs between two stages.
43
+ dim (`int`, *optional*, defaults to 448):
44
+ Number of channels in Meta3D layers
45
+ key_dim (`int`, *optional*, defaults to 32):
46
+ The size of the key in meta3D block.
47
+ attention_ratio (`int`, *optional*, defaults to 4):
48
+ Ratio of the dimension of the query and value to the dimension of the key in MSHA block
49
+ resolution (`int`, *optional*, defaults to 7):
50
+ Size of each patch
51
+ num_hidden_layers (`int`, *optional*, defaults to 5):
52
+ Number of hidden layers in the Transformer encoder.
53
+ num_attention_heads (`int`, *optional*, defaults to 8):
54
+ Number of attention heads for each attention layer in the 3D MetaBlock.
55
+ mlp_expansion_ratio (`int`, *optional*, defaults to 4):
56
+ Ratio of size of the hidden dimensionality of an MLP to the dimensionality of its input.
57
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
58
+ The dropout probability for all fully connected layers in the embeddings and encoder.
59
+ patch_size (`int`, *optional*, defaults to 16):
60
+ The size (resolution) of each patch.
61
+ num_channels (`int`, *optional*, defaults to 3):
62
+ The number of input channels.
63
+ pool_size (`int`, *optional*, defaults to 3):
64
+ Kernel size of pooling layers.
65
+ downsample_patch_size (`int`, *optional*, defaults to 3):
66
+ The size of patches in downsampling layers.
67
+ downsample_stride (`int`, *optional*, defaults to 2):
68
+ The stride of convolution kernels in downsampling layers.
69
+ downsample_pad (`int`, *optional*, defaults to 1):
70
+ Padding in downsampling layers.
71
+ drop_path_rate (`float`, *optional*, defaults to 0.0):
72
+ Rate at which to increase dropout probability in DropPath.
73
+ num_meta3d_blocks (`int`, *optional*, defaults to 1):
74
+ The number of 3D MetaBlocks in the last stage.
75
+ distillation (`bool`, *optional*, defaults to `True`):
76
+ Whether to add a distillation head.
77
+ use_layer_scale (`bool`, *optional*, defaults to `True`):
78
+ Whether to scale outputs from token mixers.
79
+ layer_scale_init_value (`float`, *optional*, defaults to 1e-5):
80
+ Factor by which outputs from token mixers are scaled.
81
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
82
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
83
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
84
+ initializer_range (`float`, *optional*, defaults to 0.02):
85
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
86
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
87
+ The epsilon used by the layer normalization layers.
88
+ image_size (`int`, *optional*, defaults to `224`):
89
+ The size (resolution) of each image.
90
+
91
+ Example:
92
+
93
+ ```python
94
+ >>> from transformers import EfficientFormerConfig, EfficientFormerModel
95
+
96
+ >>> # Initializing an EfficientFormer efficientformer-l1 style configuration
97
+ >>> configuration = EfficientFormerConfig()
98
+
99
+ >>> # Initializing an EfficientFormerModel (with random weights) from the efficientformer-l1 style configuration
100
+ >>> model = EfficientFormerModel(configuration)
101
+
102
+ >>> # Accessing the model configuration
103
+ >>> configuration = model.config
104
+ ```"""
105
+
106
+ model_type = "efficientformer"
107
+
108
+ def __init__(
109
+ self,
110
+ depths: List[int] = [3, 2, 6, 4],
111
+ hidden_sizes: List[int] = [48, 96, 224, 448],
112
+ downsamples: List[bool] = [True, True, True, True],
113
+ dim: int = 448,
114
+ key_dim: int = 32,
115
+ attention_ratio: int = 4,
116
+ resolution: int = 7,
117
+ num_hidden_layers: int = 5,
118
+ num_attention_heads: int = 8,
119
+ mlp_expansion_ratio: int = 4,
120
+ hidden_dropout_prob: float = 0.0,
121
+ patch_size: int = 16,
122
+ num_channels: int = 3,
123
+ pool_size: int = 3,
124
+ downsample_patch_size: int = 3,
125
+ downsample_stride: int = 2,
126
+ downsample_pad: int = 1,
127
+ drop_path_rate: float = 0.0,
128
+ num_meta3d_blocks: int = 1,
129
+ distillation: bool = True,
130
+ use_layer_scale: bool = True,
131
+ layer_scale_init_value: float = 1e-5,
132
+ hidden_act: str = "gelu",
133
+ initializer_range: float = 0.02,
134
+ layer_norm_eps: float = 1e-12,
135
+ image_size: int = 224,
136
+ batch_norm_eps: float = 1e-05,
137
+ **kwargs,
138
+ ) -> None:
139
+ super().__init__(**kwargs)
140
+
141
+ self.hidden_act = hidden_act
142
+ self.hidden_dropout_prob = hidden_dropout_prob
143
+ self.hidden_sizes = hidden_sizes
144
+ self.num_hidden_layers = num_hidden_layers
145
+ self.num_attention_heads = num_attention_heads
146
+ self.initializer_range = initializer_range
147
+ self.layer_norm_eps = layer_norm_eps
148
+ self.patch_size = patch_size
149
+ self.num_channels = num_channels
150
+ self.depths = depths
151
+ self.mlp_expansion_ratio = mlp_expansion_ratio
152
+ self.downsamples = downsamples
153
+ self.dim = dim
154
+ self.key_dim = key_dim
155
+ self.attention_ratio = attention_ratio
156
+ self.resolution = resolution
157
+ self.pool_size = pool_size
158
+ self.downsample_patch_size = downsample_patch_size
159
+ self.downsample_stride = downsample_stride
160
+ self.downsample_pad = downsample_pad
161
+ self.drop_path_rate = drop_path_rate
162
+ self.num_meta3d_blocks = num_meta3d_blocks
163
+ self.distillation = distillation
164
+ self.use_layer_scale = use_layer_scale
165
+ self.layer_scale_init_value = layer_scale_init_value
166
+ self.image_size = image_size
167
+ self.batch_norm_eps = batch_norm_eps
168
+
169
+
170
+ __all__ = [
171
+ "EfficientFormerConfig",
172
+ ]
docs/transformers/build/lib/transformers/models/deprecated/efficientformer/image_processing_efficientformer.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Image processor class for EfficientFormer."""
16
+
17
+ from typing import Dict, List, Optional, Union
18
+
19
+ import numpy as np
20
+
21
+ from ....image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
22
+ from ....image_transforms import (
23
+ get_resize_output_image_size,
24
+ resize,
25
+ to_channel_dimension_format,
26
+ )
27
+ from ....image_utils import (
28
+ IMAGENET_DEFAULT_MEAN,
29
+ IMAGENET_DEFAULT_STD,
30
+ ChannelDimension,
31
+ ImageInput,
32
+ PILImageResampling,
33
+ infer_channel_dimension_format,
34
+ is_batched,
35
+ is_scaled_image,
36
+ to_numpy_array,
37
+ valid_images,
38
+ validate_kwargs,
39
+ validate_preprocess_arguments,
40
+ )
41
+ from ....utils import TensorType, logging
42
+
43
+
44
+ logger = logging.get_logger(__name__)
45
+
46
+
47
+ class EfficientFormerImageProcessor(BaseImageProcessor):
48
+ r"""
49
+ Constructs an EfficientFormer image processor.
50
+
51
+ Args:
52
+ do_resize (`bool`, *optional*, defaults to `True`):
53
+ Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
54
+ size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
55
+ size (`dict`, *optional*, defaults to `{"height": 224, "width": 224}`):
56
+ Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
57
+ method.
58
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
59
+ Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
60
+ `preprocess` method.
61
+ do_center_crop (`bool`, *optional*, defaults to `True`):
62
+ Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
63
+ `preprocess` method.
64
+ crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
65
+ Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
66
+ method.
67
+ do_rescale (`bool`, *optional*, defaults to `True`):
68
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
69
+ parameter in the `preprocess` method.
70
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
71
+ Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
72
+ `preprocess` method.
73
+ do_normalize (`bool`, *optional*, defaults to `True`):
74
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
75
+ method.
76
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
77
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
78
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
79
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
80
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
81
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
82
+ """
83
+
84
+ model_input_names = ["pixel_values"]
85
+
86
+ def __init__(
87
+ self,
88
+ do_resize: bool = True,
89
+ size: Optional[Dict[str, int]] = None,
90
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
91
+ do_center_crop: bool = True,
92
+ do_rescale: bool = True,
93
+ rescale_factor: Union[int, float] = 1 / 255,
94
+ crop_size: Dict[str, int] = None,
95
+ do_normalize: bool = True,
96
+ image_mean: Optional[Union[float, List[float]]] = None,
97
+ image_std: Optional[Union[float, List[float]]] = None,
98
+ **kwargs,
99
+ ) -> None:
100
+ super().__init__(**kwargs)
101
+ size = size if size is not None else {"height": 224, "width": 224}
102
+ size = get_size_dict(size)
103
+ crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
104
+ crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
105
+
106
+ self.do_resize = do_resize
107
+ self.do_rescale = do_rescale
108
+ self.do_normalize = do_normalize
109
+ self.do_center_crop = do_center_crop
110
+ self.crop_size = crop_size
111
+ self.size = size
112
+ self.resample = resample
113
+ self.rescale_factor = rescale_factor
114
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
115
+ self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
116
+ self._valid_processor_keys = [
117
+ "images",
118
+ "do_resize",
119
+ "size",
120
+ "resample",
121
+ "do_center_crop",
122
+ "crop_size",
123
+ "do_rescale",
124
+ "rescale_factor",
125
+ "do_normalize",
126
+ "image_mean",
127
+ "image_std",
128
+ "return_tensors",
129
+ "data_format",
130
+ "input_data_format",
131
+ ]
132
+
133
+ def resize(
134
+ self,
135
+ image: np.ndarray,
136
+ size: Dict[str, int],
137
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
138
+ data_format: Optional[Union[str, ChannelDimension]] = None,
139
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
140
+ **kwargs,
141
+ ) -> np.ndarray:
142
+ """
143
+ Resize an image to `(size["height"], size["width"])`.
144
+
145
+ Args:
146
+ image (`np.ndarray`):
147
+ Image to resize.
148
+ size (`Dict[str, int]`):
149
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
150
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
151
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
152
+ data_format (`ChannelDimension` or `str`, *optional*):
153
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
154
+ image is used. Can be one of:
155
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
156
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
157
+ input_data_format (`ChannelDimension` or `str`, *optional*):
158
+ The channel dimension format of the input image. If not provided, it will be inferred.
159
+
160
+ Returns:
161
+ `np.ndarray`: The resized image.
162
+ """
163
+ size = get_size_dict(size)
164
+
165
+ if "shortest_edge" in size:
166
+ size = get_resize_output_image_size(
167
+ image, size=size["shortest_edge"], default_to_square=False, input_data_format=input_data_format
168
+ )
169
+ # size = get_resize_output_image_size(image, size["shortest_edge"], size["longest_edge"])
170
+ elif "height" in size and "width" in size:
171
+ size = (size["height"], size["width"])
172
+ else:
173
+ raise ValueError(f"Size must contain 'height' and 'width' keys or 'shortest_edge' key. Got {size.keys()}")
174
+ return resize(
175
+ image, size=size, resample=resample, data_format=data_format, input_data_format=input_data_format, **kwargs
176
+ )
177
+
178
+ def preprocess(
179
+ self,
180
+ images: ImageInput,
181
+ do_resize: Optional[bool] = None,
182
+ size: Dict[str, int] = None,
183
+ resample: PILImageResampling = None,
184
+ do_center_crop: Optional[bool] = None,
185
+ crop_size: Optional[int] = None,
186
+ do_rescale: Optional[bool] = None,
187
+ rescale_factor: Optional[float] = None,
188
+ do_normalize: Optional[bool] = None,
189
+ image_mean: Optional[Union[float, List[float]]] = None,
190
+ image_std: Optional[Union[float, List[float]]] = None,
191
+ return_tensors: Optional[Union[str, TensorType]] = None,
192
+ data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
193
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
194
+ **kwargs,
195
+ ) -> BatchFeature:
196
+ """
197
+ Preprocess an image or batch of images.
198
+
199
+ Args:
200
+ images (`ImageInput`):
201
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
202
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
203
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
204
+ Whether to resize the image.
205
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
206
+ Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after
207
+ resizing.
208
+ resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`):
209
+ `PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BILINEAR`. Only has
210
+ an effect if `do_resize` is set to `True`.
211
+ do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
212
+ Whether to center crop the image.
213
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
214
+ Whether to rescale the image values between [0 - 1].
215
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
216
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
217
+ crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
218
+ Size of the center crop. Only has an effect if `do_center_crop` is set to `True`.
219
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
220
+ Whether to normalize the image.
221
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
222
+ Image mean to use if `do_normalize` is set to `True`.
223
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
224
+ Image standard deviation to use if `do_normalize` is set to `True`.
225
+ return_tensors (`str` or `TensorType`, *optional*):
226
+ The type of tensors to return. Can be one of:
227
+ - Unset: Return a list of `np.ndarray`.
228
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
229
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
230
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
231
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
232
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
233
+ The channel dimension format for the output image. Can be one of:
234
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
235
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
236
+ - Unset: Use the channel dimension format of the input image.
237
+ input_data_format (`ChannelDimension` or `str`, *optional*):
238
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
239
+ from the input image. Can be one of:
240
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
241
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
242
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
243
+ """
244
+ do_resize = do_resize if do_resize is not None else self.do_resize
245
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
246
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
247
+ do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
248
+ crop_size = crop_size if crop_size is not None else self.crop_size
249
+ crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True)
250
+ resample = resample if resample is not None else self.resample
251
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
252
+ image_mean = image_mean if image_mean is not None else self.image_mean
253
+ image_std = image_std if image_std is not None else self.image_std
254
+
255
+ size = size if size is not None else self.size
256
+ size_dict = get_size_dict(size)
257
+
258
+ validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
259
+
260
+ if not is_batched(images):
261
+ images = [images]
262
+
263
+ if not valid_images(images):
264
+ raise ValueError(
265
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
266
+ "torch.Tensor, tf.Tensor or jax.ndarray."
267
+ )
268
+ validate_preprocess_arguments(
269
+ do_rescale=do_rescale,
270
+ rescale_factor=rescale_factor,
271
+ do_normalize=do_normalize,
272
+ image_mean=image_mean,
273
+ image_std=image_std,
274
+ do_center_crop=do_center_crop,
275
+ crop_size=crop_size,
276
+ do_resize=do_resize,
277
+ size=size,
278
+ resample=resample,
279
+ )
280
+ # All transformations expect numpy arrays.
281
+ images = [to_numpy_array(image) for image in images]
282
+
283
+ if do_rescale and is_scaled_image(images[0]):
284
+ logger.warning_once(
285
+ "It looks like you are trying to rescale already rescaled images. If the input"
286
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
287
+ )
288
+
289
+ if input_data_format is None:
290
+ # We assume that all images have the same channel dimension format.
291
+ input_data_format = infer_channel_dimension_format(images[0])
292
+
293
+ if do_resize:
294
+ images = [
295
+ self.resize(image=image, size=size_dict, resample=resample, input_data_format=input_data_format)
296
+ for image in images
297
+ ]
298
+
299
+ if do_center_crop:
300
+ images = [
301
+ self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
302
+ ]
303
+
304
+ if do_rescale:
305
+ images = [
306
+ self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
307
+ for image in images
308
+ ]
309
+
310
+ if do_normalize:
311
+ images = [
312
+ self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
313
+ for image in images
314
+ ]
315
+
316
+ images = [
317
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
318
+ ]
319
+
320
+ data = {"pixel_values": images}
321
+ return BatchFeature(data=data, tensor_type=return_tensors)
322
+
323
+
324
+ __all__ = ["EfficientFormerImageProcessor"]
docs/transformers/build/lib/transformers/models/deprecated/efficientformer/modeling_efficientformer.py ADDED
@@ -0,0 +1,807 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 Snapchat Research and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch EfficientFormer model."""
16
+
17
+ import itertools
18
+ from dataclasses import dataclass
19
+ from typing import Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from torch import nn
24
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
25
+
26
+ from ....activations import ACT2FN
27
+ from ....modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
28
+ from ....modeling_utils import PreTrainedModel
29
+ from ....utils import (
30
+ ModelOutput,
31
+ add_code_sample_docstrings,
32
+ add_start_docstrings,
33
+ add_start_docstrings_to_model_forward,
34
+ logging,
35
+ )
36
+ from .configuration_efficientformer import EfficientFormerConfig
37
+
38
+
39
# Module-level logger obtained from the library's logging utilities.
logger = logging.get_logger(__name__)

# General docstring: configuration class name handed to the code-sample docstring decorators below.
_CONFIG_FOR_DOC = "EfficientFormerConfig"

# Base docstring: checkpoint and expected output shape used for the bare model's code sample.
_CHECKPOINT_FOR_DOC = "snap-research/efficientformer-l1-300"
_EXPECTED_OUTPUT_SHAPE = [1, 49, 448]

# Image classification docstring: checkpoint and expected label used for the classification heads' code samples.
_IMAGE_CLASS_CHECKPOINT = "snap-research/efficientformer-l1-300"
_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"
51
+
52
+
53
class EfficientFormerPatchEmbeddings(nn.Module):
    """
    Downsampling layer used between two encoder stages.

    A convolution configured from `config.downsample_patch_size`, `config.downsample_stride` and
    `config.downsample_pad` maps a `[batch_size, num_channels, height, width]` tensor to `embed_dim` channels.
    Batch normalization follows unless `apply_norm` is `False`, in which case the normalization is an identity.
    """

    def __init__(self, config: EfficientFormerConfig, num_channels: int, embed_dim: int, apply_norm: bool = True):
        super().__init__()
        # Remembered so that `forward` can validate the channel dimension of its input.
        self.num_channels = num_channels

        self.projection = nn.Conv2d(
            num_channels,
            embed_dim,
            kernel_size=config.downsample_patch_size,
            stride=config.downsample_stride,
            padding=config.downsample_pad,
        )
        if apply_norm:
            self.norm = nn.BatchNorm2d(embed_dim, eps=config.batch_norm_eps)
        else:
            self.norm = nn.Identity()

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Project and normalize `pixel_values`; raises `ValueError` on a channel-count mismatch."""
        # Unpacking also enforces a 4D input.
        _, input_channels, _, _ = pixel_values.shape
        if input_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )

        return self.norm(self.projection(pixel_values))
83
+
84
+
85
class EfficientFormerSelfAttention(nn.Module):
    """
    Multi-head self-attention over a `resolution x resolution` grid of tokens, with a learned per-head bias for every
    absolute 2D offset between two grid positions.

    Args:
        dim: Size of the last dimension of the input and output hidden states.
        key_dim: Per-head size of queries and keys.
        num_heads: Number of attention heads.
        attention_ratio: Multiplier applied to `key_dim` to obtain the per-head size of values.
        resolution: Side length of the token grid. The bias table has `resolution**2` rows and columns, so the
            sequence length given to `forward` is expected to equal `resolution**2`.
    """

    def __init__(self, dim: int, key_dim: int, num_heads: int, attention_ratio: int, resolution: int):
        super().__init__()

        self.num_heads = num_heads
        self.key_dim = key_dim
        self.attention_ratio = attention_ratio
        self.scale = key_dim**-0.5
        self.total_key_dim = key_dim * num_heads
        self.expanded_key_dim = int(attention_ratio * key_dim)
        self.total_expanded_key_dim = int(self.expanded_key_dim * num_heads)
        # A single linear layer produces queries, keys and values for all heads at once.
        hidden_size = self.total_expanded_key_dim + self.total_key_dim * 2
        self.qkv = nn.Linear(dim, hidden_size)
        self.projection = nn.Linear(self.total_expanded_key_dim, dim)

        # Give every distinct absolute (row, column) offset between two grid points its own bias slot and remember,
        # for each ordered pair of points, which slot it uses.
        points = list(itertools.product(range(resolution), range(resolution)))
        num_points = len(points)
        attention_offsets = {}
        idxs = []
        for point_1 in points:
            for point_2 in points:
                offset = (abs(point_1[0] - point_2[0]), abs(point_1[1] - point_2[1]))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])
        self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
        self.register_buffer("attention_bias_idxs", torch.LongTensor(idxs).view(num_points, num_points))

    @torch.no_grad()
    def train(self, mode=True):
        """
        Switch between training and evaluation mode and maintain the cached bias table `self.ab`.

        The gathered bias table is cached only in evaluation mode; entering training mode drops the cache because
        `forward` then gathers from the trainable parameter directly. Returns `self`, like `nn.Module.train`, so
        that `module.eval()` / `module.train()` can be chained.
        """
        super().train(mode)
        if mode:
            if hasattr(self, "ab"):
                del self.ab
        else:
            self.ab = self.attention_biases[:, self.attention_bias_idxs]
        return self

    def forward(self, hidden_states: torch.Tensor, output_attentions: bool = False) -> Tuple[torch.Tensor]:
        """
        Apply self-attention to `hidden_states` of shape `(batch_size, sequence_length, dim)`.

        Returns a tuple `(context_layer,)`, or `(context_layer, attention_probs)` when `output_attentions` is true.
        """
        batch_size, sequence_length, num_channels = hidden_states.shape
        qkv = self.qkv(hidden_states)
        query_layer, key_layer, value_layer = qkv.reshape(batch_size, sequence_length, self.num_heads, -1).split(
            [self.key_dim, self.key_dim, self.expanded_key_dim], dim=3
        )
        query_layer = query_layer.permute(0, 2, 1, 3)
        key_layer = key_layer.permute(0, 2, 1, 3)
        value_layer = value_layer.permute(0, 2, 1, 3)

        if self.training:
            attention_bias = self.attention_biases[:, self.attention_bias_idxs]
        else:
            # The cache is normally created by `train(False)`; rebuild it if evaluation mode was entered some other
            # way (e.g. by setting `training` directly) instead of failing on a missing attribute.
            if not hasattr(self, "ab"):
                self.ab = self.attention_biases[:, self.attention_bias_idxs].detach()
            # `model.to(device)` does not move the plain-tensor cache, so realign it with the parameter here.
            self.ab = self.ab.to(self.attention_biases.device)
            attention_bias = self.ab

        attention_probs = torch.matmul(query_layer, key_layer.transpose(-2, -1)) * self.scale + attention_bias
        attention_probs = attention_probs.softmax(dim=-1)

        context_layer = torch.matmul(attention_probs, value_layer).transpose(1, 2)
        context_layer = context_layer.reshape(batch_size, sequence_length, self.total_expanded_key_dim)
        context_layer = self.projection(context_layer)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
147
+
148
+
149
class EfficientFormerConvStem(nn.Module):
    """
    Convolutional stem: two 3x3, stride-2 convolutions, each followed by batch normalization and ReLU.

    The first convolution maps `config.num_channels` to `out_channels // 2` channels, the second maps those to
    `out_channels`.
    """

    def __init__(self, config: EfficientFormerConfig, out_channels: int):
        super().__init__()
        half_channels = out_channels // 2

        self.convolution1 = nn.Conv2d(config.num_channels, half_channels, kernel_size=3, stride=2, padding=1)
        self.batchnorm_before = nn.BatchNorm2d(half_channels, eps=config.batch_norm_eps)

        self.convolution2 = nn.Conv2d(half_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.batchnorm_after = nn.BatchNorm2d(out_channels, eps=config.batch_norm_eps)

        # One ReLU instance is shared by both convolution blocks.
        self.activation = nn.ReLU()

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Run both conv -> batch norm -> ReLU blocks on `pixel_values`."""
        features = self.convolution1(pixel_values)
        features = self.activation(self.batchnorm_before(features))
        features = self.convolution2(features)
        return self.activation(self.batchnorm_after(features))
168
+
169
+
170
class EfficientFormerPooling(nn.Module):
    """Pooling token mixer: the stride-1 local average of the input minus the input itself."""

    def __init__(self, pool_size: int):
        super().__init__()
        # Stride 1 with `pool_size // 2` padding; padded positions are excluded from each average.
        self.pool = nn.AvgPool2d(
            kernel_size=pool_size,
            stride=1,
            padding=pool_size // 2,
            count_include_pad=False,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `avg_pool(hidden_states) - hidden_states`."""
        pooled = self.pool(hidden_states)
        return pooled - hidden_states
178
+
179
+
180
class EfficientFormerDenseMlp(nn.Module):
    """
    Two-layer feed-forward network built from linear layers.

    `hidden_features` and `out_features` fall back to `in_features` when they are not given (or are falsy). The
    activation is looked up from `config.hidden_act`, and one dropout module with probability
    `config.hidden_dropout_prob` is applied after the activation and again after the output layer.
    """

    def __init__(
        self,
        config: EfficientFormerConfig,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
    ):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        self.linear_in = nn.Linear(in_features, hidden_features)
        self.activation = ACT2FN[config.hidden_act]
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.linear_out = nn.Linear(hidden_features, out_features)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply linear -> activation -> dropout -> linear -> dropout."""
        hidden_states = self.dropout(self.activation(self.linear_in(hidden_states)))
        return self.dropout(self.linear_out(hidden_states))
205
+
206
+
207
class EfficientFormerConvMlp(nn.Module):
    """
    Two-layer feed-forward network built from 1x1 convolutions, each followed by batch normalization.

    `hidden_features` and `out_features` fall back to `in_features` when they are not given (or are falsy). The
    activation is looked up from `config.hidden_act`; one dropout module with probability `drop` is applied after
    the activation and again at the very end.
    """

    def __init__(
        self,
        config: EfficientFormerConfig,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        drop: float = 0.0,
    ):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        self.convolution1 = nn.Conv2d(in_features, hidden_features, 1)
        self.activation = ACT2FN[config.hidden_act]
        self.convolution2 = nn.Conv2d(hidden_features, out_features, 1)
        self.dropout = nn.Dropout(drop)

        self.batchnorm_before = nn.BatchNorm2d(hidden_features, eps=config.batch_norm_eps)
        self.batchnorm_after = nn.BatchNorm2d(out_features, eps=config.batch_norm_eps)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Apply conv -> batch norm -> activation -> dropout -> conv -> batch norm -> dropout."""
        hidden_state = self.batchnorm_before(self.convolution1(hidden_state))
        hidden_state = self.dropout(self.activation(hidden_state))
        hidden_state = self.batchnorm_after(self.convolution2(hidden_state))
        return self.dropout(hidden_state)
240
+
241
+
242
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.

    Args:
        input: Tensor whose first dimension indexes the samples.
        drop_prob: Probability of zeroing a whole sample. `None` is treated like `0.0`.
        training: Paths are only dropped when this is true.

    Returns:
        `input` itself when nothing can be dropped (`drop_prob` is `None`/`0.0` or `training` is false); otherwise
        a tensor in which each sample is either zeroed or scaled by `1 / (1 - drop_prob)`. When `drop_prob >= 1`
        every sample is dropped and an all-zero tensor is returned.
    """
    if drop_prob is None or drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    if keep_prob <= 0.0:
        # Nothing survives; dividing by a zero keep probability would only produce NaN/inf.
        return torch.zeros_like(input)
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
    output = input.div(keep_prob) * random_tensor
    return output
260
+
261
+
262
class EfficientFormerDropPath(nn.Module):
    """Module wrapper around `drop_path` (Stochastic Depth per sample, for the main path of residual blocks)."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        # Drop probability forwarded to `drop_path` on every call.
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply `drop_path` using the stored probability and the module's current training flag."""
        return drop_path(hidden_states, drop_prob=self.drop_prob, training=self.training)

    def extra_repr(self) -> str:
        """Show the drop probability in the module's printed representation."""
        return f"p={self.drop_prob}"
274
+
275
+
276
class EfficientFormerFlat(nn.Module):
    """Turn a feature map into a token sequence by flattening all dimensions from index 2 on and swapping dims 1 and 2."""

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Flatten `hidden_states` from dimension 2 onwards and transpose dimensions 1 and 2.

        For a `(batch, channels, height, width)` input this yields a single tensor of shape
        `(batch, height * width, channels)`.
        """
        hidden_states = hidden_states.flatten(2).transpose(1, 2)
        return hidden_states
283
+
284
+
285
class EfficientFormerMeta3D(nn.Module):
    """
    Residual block operating on token sequences: layer norm -> self-attention, then layer norm -> dense MLP.

    Each branch optionally goes through a learnable per-channel layer scale (`config.use_layer_scale`) and through
    drop path (only instantiated when `drop_path > 0`) before being added to the residual stream.
    """

    def __init__(self, config: EfficientFormerConfig, dim: int, drop_path: float = 0.0):
        super().__init__()

        # NOTE(review): the attention width comes from `config.dim`, not from the `dim` argument used for the
        # layer norms and the MLP; presumably both equal `config.hidden_sizes[-1]` — confirm against the config.
        self.token_mixer = EfficientFormerSelfAttention(
            dim=config.dim,
            key_dim=config.key_dim,
            num_heads=config.num_attention_heads,
            attention_ratio=config.attention_ratio,
            resolution=config.resolution,
        )

        self.layernorm1 = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.layernorm2 = nn.LayerNorm(dim, eps=config.layer_norm_eps)

        self.mlp = EfficientFormerDenseMlp(
            config, in_features=dim, hidden_features=int(dim * config.mlp_expansion_ratio)
        )

        if drop_path > 0.0:
            self.drop_path = EfficientFormerDropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        self.use_layer_scale = config.use_layer_scale
        if config.use_layer_scale:
            initial_scale = config.layer_scale_init_value
            self.layer_scale_1 = nn.Parameter(initial_scale * torch.ones((dim)), requires_grad=True)
            self.layer_scale_2 = nn.Parameter(initial_scale * torch.ones((dim)), requires_grad=True)

    def forward(self, hidden_states: torch.Tensor, output_attentions: bool = False) -> Tuple[torch.Tensor]:
        """Return `(layer_output,)`, followed by the attention probabilities when `output_attentions` is true."""
        attention_output, *attention_probs = self.token_mixer(self.layernorm1(hidden_states), output_attentions)

        # Attention branch.
        attention_branch = attention_output
        if self.use_layer_scale:
            attention_branch = self.layer_scale_1.unsqueeze(0).unsqueeze(0) * attention_branch
        layer_output = hidden_states + self.drop_path(attention_branch)

        # MLP branch, computed from the already updated residual stream.
        mlp_branch = self.mlp(self.layernorm2(layer_output))
        if self.use_layer_scale:
            mlp_branch = self.layer_scale_2.unsqueeze(0).unsqueeze(0) * mlp_branch
        layer_output = layer_output + self.drop_path(mlp_branch)

        return (layer_output, *attention_probs)
328
+
329
+
330
class EfficientFormerMeta3DLayers(nn.Module):
    """Stack of `config.num_meta3d_blocks` `EfficientFormerMeta3D` blocks of width `config.hidden_sizes[-1]`."""

    def __init__(self, config: EfficientFormerConfig):
        super().__init__()
        # Block indices continue after the blocks of all stages but the last one.
        # NOTE(review): the rate is `drop_path_rate * index` without normalization by the total depth, so it can
        # exceed 1 for a non-zero `drop_path_rate` — confirm this is intended.
        index_offset = sum(config.depths[:-1])
        hidden_size = config.hidden_sizes[-1]
        blocks = []
        for block_idx in range(config.num_meta3d_blocks):
            rate = config.drop_path_rate * (block_idx + index_offset)
            blocks.append(EfficientFormerMeta3D(config, hidden_size, drop_path=rate))
        self.blocks = nn.ModuleList(blocks)

    def forward(self, hidden_states: torch.Tensor, output_attentions: bool = False) -> Tuple[torch.Tensor]:
        """
        Run all blocks in order.

        With `output_attentions`, returns `(last_hidden_state, attn_1, ..., attn_n)`. Otherwise returns whatever
        the last block returned (a 1-tuple holding the hidden state), or the input unchanged if there are no
        blocks.
        """
        attentions = []

        for block in self.blocks:
            # Every block returns a tuple; only its first element is fed to the next block.
            block_input = hidden_states[0] if isinstance(hidden_states, tuple) else hidden_states
            hidden_states = block(block_input, output_attentions)

            if output_attentions:
                attentions.append(hidden_states[1])

        if not output_attentions:
            return hidden_states

        return (hidden_states[0], *attentions)
358
+
359
+
360
class EfficientFormerMeta4D(nn.Module):
    """
    Residual block operating on feature maps: pooling token mixer followed by a convolutional MLP.

    Each branch optionally goes through a learnable per-channel layer scale (`config.use_layer_scale`) and through
    drop path (only instantiated when `drop_path > 0`) before being added to the residual stream.
    """

    def __init__(self, config: EfficientFormerConfig, dim: int, drop_path: float = 0.0):
        super().__init__()
        # A pool size of 3 is used when the configuration does not set one.
        pool_size = 3 if config.pool_size is None else config.pool_size
        self.token_mixer = EfficientFormerPooling(pool_size=pool_size)
        self.mlp = EfficientFormerConvMlp(
            config,
            in_features=dim,
            hidden_features=int(dim * config.mlp_expansion_ratio),
            drop=config.hidden_dropout_prob,
        )

        if drop_path > 0.0:
            self.drop_path = EfficientFormerDropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        self.use_layer_scale = config.use_layer_scale
        if config.use_layer_scale:
            initial_scale = config.layer_scale_init_value
            self.layer_scale_1 = nn.Parameter(initial_scale * torch.ones((dim)), requires_grad=True)
            self.layer_scale_2 = nn.Parameter(initial_scale * torch.ones((dim)), requires_grad=True)

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]:
        """Return the block output; the value returned is a single tensor."""
        # Token-mixing branch.
        mixer_branch = self.token_mixer(hidden_states)
        if self.use_layer_scale:
            # The two trailing singleton dims let the per-channel scale broadcast over the last two dimensions.
            mixer_branch = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * mixer_branch
        layer_output = hidden_states + self.drop_path(mixer_branch)

        # MLP branch, computed from the already updated residual stream.
        mlp_branch = self.mlp(layer_output)
        if self.use_layer_scale:
            mlp_branch = self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * mlp_branch
        layer_output = layer_output + self.drop_path(mlp_branch)

        return layer_output
390
+
391
+
392
class EfficientFormerMeta4DLayers(nn.Module):
    """
    Stack of `EfficientFormerMeta4D` blocks for the stage `stage_idx`.

    The stack has `config.depths[stage_idx]` blocks, except for `stage_idx == -1`, where
    `config.num_meta3d_blocks` of them are left out.
    """

    def __init__(self, config: EfficientFormerConfig, stage_idx: int):
        super().__init__()
        num_layers = config.depths[stage_idx]
        if stage_idx == -1:
            num_layers = num_layers - config.num_meta3d_blocks

        # Block indices continue after the blocks of all preceding stages.
        index_offset = sum(config.depths[:stage_idx])
        hidden_size = config.hidden_sizes[stage_idx]
        blocks = []
        for block_idx in range(num_layers):
            rate = config.drop_path_rate * (block_idx + index_offset)
            blocks.append(EfficientFormerMeta4D(config, hidden_size, drop_path=rate))
        self.blocks = nn.ModuleList(blocks)

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]:
        """Feed `hidden_states` through every block in order and return the result."""
        for block in self.blocks:
            hidden_states = block(hidden_states)
        return hidden_states
413
+
414
+
415
class EfficientFormerIntermediateStage(nn.Module):
    """Encoder stage `index`, consisting only of a stack of Meta4D blocks."""

    def __init__(self, config: EfficientFormerConfig, index: int):
        super().__init__()
        self.meta4D_layers = EfficientFormerMeta4DLayers(config, index)

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]:
        """Return the output of the stage's Meta4D stack."""
        return self.meta4D_layers(hidden_states)
423
+
424
+
425
class EfficientFormerLastStage(nn.Module):
    """Final encoder stage: Meta4D blocks, flattening into a token sequence, then Meta3D (attention) blocks."""

    def __init__(self, config: EfficientFormerConfig):
        super().__init__()
        self.meta4D_layers = EfficientFormerMeta4DLayers(config, -1)
        self.flat = EfficientFormerFlat()
        self.meta3D_layers = EfficientFormerMeta3DLayers(config)

    def forward(self, hidden_states: torch.Tensor, output_attentions: bool = False) -> Tuple[torch.Tensor]:
        """Return the output of the Meta3D stack, which includes attention maps when `output_attentions` is true."""
        tokens = self.flat(self.meta4D_layers(hidden_states))
        return self.meta3D_layers(tokens, output_attentions)
438
+
439
+
440
class EfficientFormerEncoder(nn.Module):
    """
    Full encoder: the intermediate stages (each optionally followed by a downsampling layer) and the last stage.

    A downsampling layer is inserted after intermediate stage `i` when `config.downsamples[i]` is truthy or when
    the hidden size changes between stage `i` and stage `i + 1`.
    """

    def __init__(self, config: EfficientFormerConfig):
        super().__init__()
        self.config = config
        hidden_sizes = config.hidden_sizes

        stages = []
        # Every stage except the last one is an intermediate stage.
        for stage_idx in range(len(config.depths) - 1):
            stages.append(EfficientFormerIntermediateStage(config, stage_idx))
            needs_downsample = config.downsamples[stage_idx] or hidden_sizes[stage_idx] != hidden_sizes[stage_idx + 1]
            if needs_downsample:
                stages.append(
                    EfficientFormerPatchEmbeddings(config, hidden_sizes[stage_idx], hidden_sizes[stage_idx + 1])
                )

        self.intermediate_stages = nn.ModuleList(stages)
        self.last_stage = EfficientFormerLastStage(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
        return_dict: bool = True,
    ) -> BaseModelOutput:
        """
        Run the encoder.

        Hidden states are collected for the input, after every intermediate module (stages and downsampling layers
        alike) and after the last stage. Returns a `BaseModelOutput`, or a tuple of its non-`None` fields when
        `return_dict` is false.
        """
        collected_states = [hidden_states] if output_hidden_states else None

        for stage in self.intermediate_stages:
            hidden_states = stage(hidden_states)
            if collected_states is not None:
                collected_states.append(hidden_states)

        layer_output = self.last_stage(hidden_states, output_attentions=output_attentions)
        last_hidden_state = layer_output[0]

        # Everything after the first element of the last stage's output is an attention map.
        all_self_attentions = () + layer_output[1:] if output_attentions else None

        if collected_states is not None:
            collected_states.append(last_hidden_state)
        all_hidden_states = tuple(collected_states) if collected_states is not None else None

        if not return_dict:
            candidates = (last_hidden_state, all_hidden_states, all_self_attentions)
            return tuple(value for value in candidates if value is not None)

        return BaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
495
+
496
+
497
class EfficientFormerPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = EfficientFormerConfig
    base_model_prefix = "efficientformer"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = False

    def _init_weights(self, module: nn.Module):
        """
        Initialize the weights of `module`.

        Linear and 2D convolution layers get normally distributed weights (std `config.initializer_range`) and
        zero biases; layer norms get unit weights and zero biases. Other modules are left untouched.
        """
        if isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()
            return

        if not isinstance(module, (nn.Linear, nn.Conv2d)):
            return

        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if module.bias is not None:
            module.bias.data.zero_()
517
+
518
+
519
+ EFFICIENTFORMER_START_DOCSTRING = r"""
520
+ This model is a PyTorch [nn.Module](https://pytorch.org/docs/stable/nn.html#nn.Module) subclass. Use it as a
521
+ regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.
522
+
523
+ Parameters:
524
+ config ([`EfficientFormerConfig`]): Model configuration class with all the parameters of the model.
525
+ Initializing with a config file does not load the weights associated with the model, only the
526
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
527
+ """
528
+
529
+ EFFICIENTFORMER_INPUTS_DOCSTRING = r"""
530
+ Args:
531
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
532
+ Pixel values. Pixel values can be obtained using [`ViTImageProcessor`]. See
533
+ [`ViTImageProcessor.preprocess`] for details.
534
+ output_attentions (`bool`, *optional*):
535
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
536
+ tensors for more detail.
537
+ output_hidden_states (`bool`, *optional*):
538
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
539
+ more detail.
540
+ return_dict (`bool`, *optional*):
541
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
542
+ """
543
+
544
+
545
@add_start_docstrings(
    "The bare EfficientFormer Model transformer outputting raw hidden-states without any specific head on top.",
    EFFICIENTFORMER_START_DOCSTRING,
)
class EfficientFormerModel(EfficientFormerPreTrainedModel):
    # Pipeline: convolutional stem -> encoder -> final layer norm over the last hidden size.

    def __init__(self, config: EfficientFormerConfig):
        super().__init__(config)
        self.config = config
        # NOTE(review): this method used to assign `_no_split_modules = ["EfficientFormerMeta4D"]` to a local
        # variable, which had no effect and has been removed. If sharding support is wanted it has to become a
        # class attribute (the Meta3D blocks probably need to be listed as well — confirm before enabling).

        self.patch_embed = EfficientFormerConvStem(config, config.hidden_sizes[0])
        self.encoder = EfficientFormerEncoder(config)
        self.layernorm = nn.LayerNorm(config.hidden_sizes[-1], eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(EFFICIENTFORMER_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        # Matches what `forward` actually returns; no pooled output is produced.
        output_type=BaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutput]:
        # Unset flags fall back to the configuration.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedding_output = self.patch_embed(pixel_values)
        # The encoder is always called with its default `return_dict=True`; the tuple form is built below.
        encoder_outputs = self.encoder(
            embedding_output, output_attentions=output_attentions, output_hidden_states=output_hidden_states
        )

        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)

        if not return_dict:
            head_outputs = (sequence_output,)
            return head_outputs + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
603
+
604
+
605
@add_start_docstrings(
    """
    EfficientFormer Model transformer with an image classification head on top (a linear layer on top of the final
    hidden state of the [CLS] token) e.g. for ImageNet.
    """,
    EFFICIENTFORMER_START_DOCSTRING,
)
class EfficientFormerForImageClassification(EfficientFormerPreTrainedModel):
    def __init__(self, config: EfficientFormerConfig):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.efficientformer = EfficientFormerModel(config)

        # Classifier head: an identity when the configuration declares no labels.
        if config.num_labels > 0:
            self.classifier = nn.Linear(config.hidden_sizes[-1], config.num_labels)
        else:
            self.classifier = nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(EFFICIENTFORMER_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.efficientformer(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # The head consumes the final hidden states averaged over the second-to-last dimension.
        pooled_output = outputs[0].mean(-2)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            # Infer the problem type once and store it on the configuration.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            problem_type = self.config.problem_type
            if problem_type == "regression":
                criterion = MSELoss()
                if self.num_labels == 1:
                    loss = criterion(logits.squeeze(), labels.squeeze())
                else:
                    loss = criterion(logits, labels)
            elif problem_type == "single_label_classification":
                criterion = CrossEntropyLoss()
                loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                criterion = BCEWithLogitsLoss()
                loss = criterion(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            if loss is None:
                return output
            return (loss,) + output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
694
+
695
+
696
@dataclass
class EfficientFormerForImageClassificationWithTeacherOutput(ModelOutput):
    """
    Output type of [`EfficientFormerForImageClassificationWithTeacher`].

    Args:
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Prediction scores as the average of the cls_logits and distillation logits.
        cls_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Prediction scores of the classification head (i.e. the linear layer applied to the final hidden states
            averaged over the sequence dimension).
        distillation_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Prediction scores of the distillation head (i.e. a second linear layer applied to the same averaged
            final hidden states).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    """

    # Every field defaults to `None`; the field order is the order used when the output is read as a tuple.
    logits: Optional[torch.FloatTensor] = None
    cls_logits: Optional[torch.FloatTensor] = None
    distillation_logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
725
+
726
+
727
@add_start_docstrings(
    """
    EfficientFormer Model transformer with image classification heads on top (a linear layer on top of the final hidden
    state of the [CLS] token and a linear layer on top of the final hidden state of the distillation token) e.g. for
    ImageNet.

    <Tip warning={true}>

           This model supports inference-only. Fine-tuning with distillation (i.e. with a teacher) is not yet
           supported.

    </Tip>
    """,
    EFFICIENTFORMER_START_DOCSTRING,
)
class EfficientFormerForImageClassificationWithTeacher(EfficientFormerPreTrainedModel):
    def __init__(self, config: EfficientFormerConfig):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.efficientformer = EfficientFormerModel(config)

        # Both heads consume the backbone's final hidden states, whose width is `config.hidden_sizes[-1]` (the
        # size of the backbone's final layer norm), so they are sized from that value, exactly like the head of
        # `EfficientFormerForImageClassification`.
        head_input_size = config.hidden_sizes[-1]

        # Classifier head
        self.classifier = nn.Linear(head_input_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
        # Distillation head
        self.distillation_classifier = (
            nn.Linear(head_input_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(EFFICIENTFORMER_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=EfficientFormerForImageClassificationWithTeacherOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, EfficientFormerForImageClassificationWithTeacherOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.efficientformer(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # Average over the second-to-last dimension once and feed the same features to both heads.
        pooled_output = sequence_output.mean(-2)
        cls_logits = self.classifier(pooled_output)
        distillation_logits = self.distillation_classifier(pooled_output)

        # during inference, return the average of both classifier predictions
        logits = (cls_logits + distillation_logits) / 2

        if not return_dict:
            output = (logits, cls_logits, distillation_logits) + outputs[1:]
            return output

        return EfficientFormerForImageClassificationWithTeacherOutput(
            logits=logits,
            cls_logits=cls_logits,
            distillation_logits=distillation_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
800
+
801
+
802
# Names exported by `from <module> import *`: the three model classes and their shared pretrained base class.
__all__ = [
    "EfficientFormerForImageClassification",
    "EfficientFormerForImageClassificationWithTeacher",
    "EfficientFormerModel",
    "EfficientFormerPreTrainedModel",
]
docs/transformers/build/lib/transformers/models/deprecated/efficientformer/modeling_tf_efficientformer.py ADDED
@@ -0,0 +1,1198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 Snapchat Research and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """TensorFlow EfficientFormer model."""
16
+
17
+ import itertools
18
+ from dataclasses import dataclass
19
+ from typing import Optional, Tuple, Union
20
+
21
+ import tensorflow as tf
22
+
23
+ from ....activations_tf import ACT2FN
24
+ from ....modeling_tf_outputs import (
25
+ TFBaseModelOutput,
26
+ TFBaseModelOutputWithPooling,
27
+ TFImageClassifierOutput,
28
+ )
29
+ from ....modeling_tf_utils import (
30
+ TFPreTrainedModel,
31
+ TFSequenceClassificationLoss,
32
+ get_initializer,
33
+ keras,
34
+ keras_serializable,
35
+ unpack_inputs,
36
+ )
37
+ from ....tf_utils import shape_list, stable_softmax
38
+ from ....utils import (
39
+ ModelOutput,
40
+ add_code_sample_docstrings,
41
+ add_start_docstrings,
42
+ add_start_docstrings_to_model_forward,
43
+ logging,
44
+ )
45
+ from .configuration_efficientformer import EfficientFormerConfig
46
+
47
+
48
# Module-level logger obtained through the library's `logging` helper.
logger = logging.get_logger(__name__)

# The constants below are presumably consumed by the docstring decorators imported above
# (`add_code_sample_docstrings` and friends); their use sites are not visible in this excerpt.

# General docstring
_CONFIG_FOR_DOC = "EfficientFormerConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "snap-research/efficientformer-l1-300"
_EXPECTED_OUTPUT_SHAPE = [1, 49, 448]

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "snap-research/efficientformer-l1-300"
_IMAGE_CLASS_EXPECTED_OUTPUT = "LABEL_281"
60
+
61
+
62
class TFEfficientFormerPatchEmbeddings(keras.layers.Layer):
    """
    Downsampling layer used between two stages: zero padding, a strided convolution that projects the input to
    `embed_dim` filters and, optionally, batch normalization over the last axis.

    `call` asserts that the last axis of `pixel_values` has size `num_channels`.

    Args:
        config ([`EfficientFormerConfig`]):
            Model configuration. `downsample_pad`, `downsample_patch_size`, `downsample_stride` and `batch_norm_eps`
            are read from it.
        num_channels (`int`):
            Expected size of the last axis of the input.
        embed_dim (`int`):
            Number of filters of the projection convolution.
        apply_norm (`bool`, *optional*, defaults to `True`):
            Whether to batch-normalize the projected features. When `False`, `self.norm` is `tf.identity`.
    """

    def __init__(
        self, config: EfficientFormerConfig, num_channels: int, embed_dim: int, apply_norm: bool = True, **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.num_channels = num_channels

        self.padding = keras.layers.ZeroPadding2D(padding=config.downsample_pad)
        self.projection = keras.layers.Conv2D(
            filters=embed_dim,
            kernel_size=config.downsample_patch_size,
            strides=config.downsample_stride,
            padding="valid",
            name="projection",
        )
        # Use same default momentum and epsilon as PyTorch equivalent for BatchNormalization
        self.norm = (
            keras.layers.BatchNormalization(axis=-1, epsilon=config.batch_norm_eps, momentum=0.9, name="norm")
            if apply_norm
            else tf.identity
        )
        self.embed_dim = embed_dim

    def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Pad, project and (optionally) normalize `pixel_values`; returns the resulting feature map."""
        tf.debugging.assert_shapes(
            [(pixel_values, (..., None, None, self.num_channels))],
            message="Make sure that the channel dimension of the pixel values match with the one set in the configuration.",
        )
        embeddings = self.projection(self.padding(pixel_values))
        if isinstance(self.norm, keras.layers.Layer):
            embeddings = self.norm(embeddings, training=training)
        else:
            # `apply_norm=False`: `self.norm` is the plain `tf.identity` function, which does not accept a
            # `training` keyword argument.
            embeddings = self.norm(embeddings)
        return embeddings

    def build(self, input_shape=None):
        """Build the sub-layers under their own name scopes; does nothing when the layer is already built."""
        if self.built:
            return
        self.built = True
        if getattr(self, "projection", None) is not None:
            with tf.name_scope(self.projection.name):
                self.projection.build([None, None, None, self.num_channels])
        if getattr(self, "norm", None) is not None:
            # `tf.identity` has no `name` attribute and nothing to build.
            if hasattr(self.norm, "name"):
                with tf.name_scope(self.norm.name):
                    self.norm.build([None, None, None, self.embed_dim])
110
+
111
+
112
class TFEfficientFormerSelfAttention(keras.layers.Layer):
    """
    Multi-head self-attention with a learned additive bias per head and per absolute 2D offset between two positions
    of a `resolution x resolution` grid.

    Args:
        dim (`int`):
            Size of the last axis of the input; also the number of units of the output projection.
        key_dim (`int`):
            Per-head size of the queries and keys.
        num_heads (`int`):
            Number of attention heads.
        attention_ratio (`int`):
            The per-head size of the values is `int(attention_ratio * key_dim)`.
        resolution (`int`):
            Side length of the grid used to index the attention biases. The biases are gathered into a
            `(num_heads, resolution**2, resolution**2)` tensor that is added to the attention scores, so the sequence
            length of the input has to be compatible with `resolution**2`.
        config ([`EfficientFormerConfig`]):
            Model configuration; `initializer_range` is read from it.
    """

    def __init__(
        self,
        dim: int,
        key_dim: int,
        num_heads: int,
        attention_ratio: int,
        resolution: int,
        config: EfficientFormerConfig,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.num_heads = num_heads
        self.key_dim = key_dim
        self.attention_ratio = attention_ratio
        self.scale = key_dim**-0.5
        self.total_key_dim = key_dim * num_heads
        self.expanded_key_dim = int(attention_ratio * key_dim)
        self.total_expanded_key_dim = int(self.expanded_key_dim * num_heads)
        # A single dense layer produces queries, keys and values for all heads at once.
        hidden_size = self.total_expanded_key_dim + self.total_key_dim * 2

        self.qkv = keras.layers.Dense(
            units=hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="qkv"
        )
        self.projection = keras.layers.Dense(
            units=dim, kernel_initializer=get_initializer(config.initializer_range), name="projection"
        )
        self.resolution = resolution
        self.dim = dim

    def build(self, input_shape: tf.TensorShape) -> None:
        """
        Create the attention bias table and its index map, then build the dense sub-layers.

        The `built` check comes first so that calling `build` again does not add the weights a second time.
        """
        if self.built:
            return

        points = list(itertools.product(range(self.resolution), range(self.resolution)))
        num_points = len(points)
        # Maps every distinct absolute offset (|dx|, |dy|) to a column of `attention_biases`, in order of first
        # appearance.
        attention_offsets = {}

        idxs = []

        for point_1 in points:
            for point_2 in points:
                offset = (abs(point_1[0] - point_2[0]), abs(point_1[1] - point_2[1]))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])

        self.attention_biases = self.add_weight(
            shape=(self.num_heads, len(attention_offsets)),
            initializer=keras.initializers.zeros(),
            trainable=True,
            name="attention_biases",
        )
        self.attention_bias_idxs = self.add_weight(
            shape=(num_points, num_points),
            trainable=False,
            dtype=tf.int32,
            name="attention_bias_idxs",
        )

        self.attention_bias_idxs.assign(tf.reshape(tf.cast(idxs, dtype=tf.int32), (num_points, num_points)))

        self.built = True
        if getattr(self, "qkv", None) is not None:
            with tf.name_scope(self.qkv.name):
                self.qkv.build([None, None, self.dim])
        if getattr(self, "projection", None) is not None:
            with tf.name_scope(self.projection.name):
                self.projection.build([None, None, self.total_expanded_key_dim])

    def call(
        self, hidden_states: tf.Tensor, output_attentions: bool = False, training: bool = False
    ) -> Tuple[tf.Tensor]:
        """
        Apply self-attention to `hidden_states`.

        Returns:
            `(context_layer,)`, or `(context_layer, attention_probs)` when `output_attentions` is truthy.
        """
        batch_size, sequence_length, *_ = shape_list(hidden_states)
        qkv = self.qkv(inputs=hidden_states)

        # Separate the heads, then split each head's features into query, key and value parts.
        query_layer, key_layer, value_layer = tf.split(
            tf.reshape(tensor=qkv, shape=(batch_size, sequence_length, self.num_heads, -1)),
            num_or_size_splits=[self.key_dim, self.key_dim, self.expanded_key_dim],
            axis=3,
        )

        # (batch, sequence, heads, features) -> (batch, heads, sequence, features)
        query_layer = tf.transpose(query_layer, perm=[0, 2, 1, 3])
        key_layer = tf.transpose(key_layer, perm=[0, 2, 1, 3])
        value_layer = tf.transpose(value_layer, perm=[0, 2, 1, 3])

        attention_probs = tf.matmul(query_layer, tf.transpose(key_layer, perm=[0, 1, 3, 2]))
        scale = tf.cast(self.scale, dtype=attention_probs.dtype)
        attention_probs = tf.multiply(attention_probs, scale)

        # Look up the bias of every (query position, key position) pair for each head.
        attention_biases = tf.gather(params=self.attention_biases, indices=self.attention_bias_idxs, axis=1)
        attention_probs = attention_probs + attention_biases
        attention_probs = stable_softmax(logits=attention_probs, axis=-1)

        context_layer = tf.matmul(attention_probs, value_layer)
        context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])

        # Merge the heads back into a single feature axis before the output projection.
        context_layer = tf.reshape(
            tensor=context_layer, shape=(batch_size, sequence_length, self.total_expanded_key_dim)
        )
        context_layer = self.projection(context_layer)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
217
+
218
+
219
class TFEfficientFormerConvStem(keras.layers.Layer):
    """
    Convolutional stem: two rounds of (zero padding of 1, 3x3 convolution with stride 2, batch normalization, ReLU).
    The first convolution produces `out_channels // 2` filters, the second one `out_channels`.
    """

    def __init__(self, config: EfficientFormerConfig, out_channels: int, **kwargs):
        super().__init__(**kwargs)

        # Use same default momentum and epsilon as PyTorch equivalent for BatchNormalization
        batchnorm_kwargs = {"axis": -1, "epsilon": config.batch_norm_eps, "momentum": 0.9}

        self.padding = keras.layers.ZeroPadding2D(padding=1)
        self.convolution1 = keras.layers.Conv2D(
            filters=out_channels // 2, kernel_size=3, strides=2, padding="valid", name="convolution1"
        )
        self.batchnorm_before = keras.layers.BatchNormalization(name="batchnorm_before", **batchnorm_kwargs)

        self.convolution2 = keras.layers.Conv2D(
            filters=out_channels, kernel_size=3, strides=2, padding="valid", name="convolution2"
        )
        self.batchnorm_after = keras.layers.BatchNormalization(name="batchnorm_after", **batchnorm_kwargs)

        self.activation = keras.layers.Activation(activation=keras.activations.relu, name="activation")
        self.out_channels = out_channels
        self.config = config

    def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
        # First round: pad -> conv -> batch norm -> ReLU
        stem = self.convolution1(self.padding(pixel_values))
        stem = self.batchnorm_before(stem, training=training)
        stem = self.activation(stem)

        # Second round, reusing the same padding and activation layers
        stem = self.convolution2(self.padding(stem))
        stem = self.batchnorm_after(stem, training=training)
        return self.activation(stem)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True

        half_channels = self.out_channels // 2
        # (attribute name, shape handed to the sub-layer's `build`), in build order
        build_plan = (
            ("convolution1", [None, None, None, self.config.num_channels]),
            ("batchnorm_before", [None, None, None, half_channels]),
            ("convolution2", [None, None, None, half_channels]),
            ("batchnorm_after", [None, None, None, self.out_channels]),
            ("activation", None),
        )
        for attribute, shape in build_plan:
            sublayer = getattr(self, attribute, None)
            if sublayer is None:
                continue
            with tf.name_scope(sublayer.name):
                sublayer.build(shape)
274
+
275
+
276
class TFEfficientFormerPooling(keras.layers.Layer):
    """Token mixer that returns the stride-1, same-padded average pooling of its input minus the input itself."""

    def __init__(self, pool_size: int, **kwargs):
        super().__init__(**kwargs)
        self.pool = keras.layers.AveragePooling2D(pool_size=pool_size, strides=1, padding="same")

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        pooled = self.pool(hidden_states)
        return pooled - hidden_states
285
+
286
+
287
class TFEfficientFormerDenseMlp(keras.layers.Layer):
    """
    Feed-forward block made of two dense projections (`linear_in`, `linear_out`). The configured activation follows
    the first projection and one shared dropout layer is applied after each projection.
    """

    def __init__(
        self,
        config: EfficientFormerConfig,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Sizes that are not given (or are falsy) fall back to `in_features`.
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        def make_dense(units: int, name: str) -> keras.layers.Dense:
            return keras.layers.Dense(
                units=units, kernel_initializer=get_initializer(config.initializer_range), name=name
            )

        self.linear_in = make_dense(hidden_features, "linear_in")
        self.activation = ACT2FN[config.hidden_act]
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)

        self.linear_out = make_dense(out_features, "linear_out")
        self.hidden_features = hidden_features
        self.in_features = in_features

    def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
        projected = self.linear_in(inputs=hidden_states)
        activated = self.activation(projected)
        activated = self.dropout(inputs=activated, training=training)

        output = self.linear_out(inputs=activated)
        return self.dropout(inputs=output, training=training)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True

        # (attribute name, size of the last input axis), in build order
        for attribute, input_features in (
            ("linear_in", self.in_features),
            ("linear_out", self.hidden_features),
        ):
            dense = getattr(self, attribute, None)
            if dense is None:
                continue
            with tf.name_scope(dense.name):
                dense.build([None, None, input_features])
331
+
332
+
333
class TFEfficientFormerConvMlp(keras.layers.Layer):
    """
    Feed-forward block built from two 1x1 convolutions, each followed by batch normalization. The configured
    activation is applied after the first normalization, and one shared dropout layer (rate `drop`) is applied after
    the activation and after the second normalization.
    """

    def __init__(
        self,
        config: EfficientFormerConfig,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        drop: float = 0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Sizes that are not given (or are falsy) fall back to `in_features`.
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.convolution1 = keras.layers.Conv2D(
            filters=hidden_features, kernel_size=1, name="convolution1", padding="valid"
        )

        self.activation = ACT2FN[config.hidden_act]

        self.convolution2 = keras.layers.Conv2D(
            filters=out_features, kernel_size=1, name="convolution2", padding="valid"
        )

        self.dropout = keras.layers.Dropout(rate=drop)

        # Use same default momentum and epsilon as PyTorch equivalent for BatchNormalization
        batchnorm_kwargs = {"axis": -1, "epsilon": config.batch_norm_eps, "momentum": 0.9}
        self.batchnorm_before = keras.layers.BatchNormalization(name="batchnorm_before", **batchnorm_kwargs)
        self.batchnorm_after = keras.layers.BatchNormalization(name="batchnorm_after", **batchnorm_kwargs)

        self.hidden_features = hidden_features
        self.in_features = in_features
        self.out_features = out_features

    def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor:
        # conv -> batch norm -> activation -> dropout
        expanded = self.batchnorm_before(self.convolution1(hidden_state), training=training)
        expanded = self.dropout(self.activation(expanded), training=training)

        # conv -> batch norm -> dropout
        reduced = self.batchnorm_after(self.convolution2(expanded), training=training)
        return self.dropout(reduced, training=training)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True

        # (attribute name, size of the last input axis), in build order
        build_plan = (
            ("convolution1", self.in_features),
            ("convolution2", self.hidden_features),
            ("batchnorm_before", self.hidden_features),
            ("batchnorm_after", self.out_features),
        )
        for attribute, channels in build_plan:
            sublayer = getattr(self, attribute, None)
            if sublayer is None:
                continue
            with tf.name_scope(sublayer.name):
                sublayer.build([None, None, None, channels])
403
+
404
+
405
+ # Copied from transformers.models.convnext.modeling_tf_convnext.TFConvNextDropPath with ConvNext->EfficientFormer
406
class TFEfficientFormerDropPath(keras.layers.Layer):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    When `training` is truthy, a random 0/1 factor is drawn for each entry of the first axis and the input is
    rescaled by `1 / (1 - drop_path)`; otherwise the input is returned unchanged.

    References:
        (1) github.com:rwightman/pytorch-image-models
    """

    def __init__(self, drop_path: float, **kwargs):
        super().__init__(**kwargs)
        self.drop_path = drop_path

    def call(self, x: tf.Tensor, training=None):
        if not training:
            return x

        keep_prob = 1 - self.drop_path
        # One draw per entry of the first axis, broadcast over all remaining axes.
        mask_shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
        # floor(keep_prob + u) with u drawn from [0, 1) is either 0 or 1.
        keep_mask = tf.floor(keep_prob + tf.random.uniform(mask_shape, 0, 1))
        return (x / keep_prob) * keep_mask
424
+
425
+
426
class TFEfficientFormerFlat(keras.layers.Layer):
    """Collapse the two middle axes of a 4D tensor: `[batch, d1, d2, channels]` -> `[batch, d1 * d2, channels]`."""

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        # A single tensor is returned (not a tuple); the middle size is inferred through `-1`.
        batch_size, _, _, in_channels = shape_list(hidden_states)
        hidden_states = tf.reshape(hidden_states, shape=[batch_size, -1, in_channels])
        return hidden_states
434
+
435
+
436
class TFEfficientFormerMeta3D(keras.layers.Layer):
    """
    Attention block: layer norm -> self-attention -> residual, then layer norm -> dense MLP -> residual. Both
    residual branches go through drop path and, when `config.use_layer_scale` is set, are multiplied by a learned
    per-feature scale.

    Args:
        config ([`EfficientFormerConfig`]):
            Model configuration.
        dim (`int`):
            Feature size used for the layer norms, the MLP and the layer-scale weights.
        drop_path (`float`, *optional*, defaults to 0.0):
            Drop path rate; values `<= 0.0` turn drop path into a linear (no-op) activation.
    """

    def __init__(self, config: EfficientFormerConfig, dim: int, drop_path: float = 0.0, **kwargs):
        super().__init__(**kwargs)

        # NOTE(review): the attention layer is sized from `config.dim`, not from the `dim` argument; the two are
        # presumably expected to be equal -- confirm against the configuration.
        self.token_mixer = TFEfficientFormerSelfAttention(
            dim=config.dim,
            key_dim=config.key_dim,
            num_heads=config.num_attention_heads,
            attention_ratio=config.attention_ratio,
            resolution=config.resolution,
            name="token_mixer",
            config=config,
        )
        self.dim = dim
        self.config = config

        self.layernorm1 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm1")
        self.layernorm2 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm2")
        mlp_hidden_dim = int(dim * config.mlp_expansion_ratio)
        self.mlp = TFEfficientFormerDenseMlp(config, in_features=dim, hidden_features=mlp_hidden_dim, name="mlp")

        # Using `layers.Activation` instead of `tf.identity` to better control `training` behavior.
        self.drop_path = (
            TFEfficientFormerDropPath(drop_path, name="drop_path")
            if drop_path > 0.0
            else keras.layers.Activation("linear", name="drop_path")
        )

    def build(self, input_shape=None):
        """
        Create the optional layer-scale weights and build the sub-layers.

        The `built` check comes first so that calling `build` again neither resets nor re-creates the weights.
        """
        if self.built:
            return

        self.layer_scale_1 = None
        self.layer_scale_2 = None

        if self.config.use_layer_scale:
            self.layer_scale_1 = self.add_weight(
                shape=(self.dim,),
                initializer=keras.initializers.Constant(value=self.config.layer_scale_init_value),
                trainable=True,
                name="layer_scale_1",
            )
            self.layer_scale_2 = self.add_weight(
                shape=(self.dim,),
                initializer=keras.initializers.Constant(value=self.config.layer_scale_init_value),
                trainable=True,
                name="layer_scale_2",
            )

        self.built = True
        if getattr(self, "token_mixer", None) is not None:
            with tf.name_scope(self.token_mixer.name):
                self.token_mixer.build(None)
        if getattr(self, "layernorm1", None) is not None:
            with tf.name_scope(self.layernorm1.name):
                self.layernorm1.build([None, None, self.dim])
        if getattr(self, "layernorm2", None) is not None:
            with tf.name_scope(self.layernorm2.name):
                self.layernorm2.build([None, None, self.dim])
        if getattr(self, "mlp", None) is not None:
            with tf.name_scope(self.mlp.name):
                self.mlp.build(None)
        if getattr(self, "drop_path", None) is not None:
            with tf.name_scope(self.drop_path.name):
                self.drop_path.build(None)

    def call(
        self, hidden_states: tf.Tensor, output_attentions: bool = False, training: bool = False
    ) -> Tuple[tf.Tensor]:
        """
        Run the block on `hidden_states`.

        Returns:
            `(layer_output,)`, followed by the attention probabilities when `output_attentions` is truthy.
        """
        self_attention_outputs = self.token_mixer(
            hidden_states=self.layernorm1(hidden_states, training=training),
            output_attentions=output_attentions,
            training=training,
        )

        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        if self.config.use_layer_scale:
            # The scales have shape (dim,); two leading axes are added so that they broadcast over the inputs.
            layer_output = hidden_states + self.drop_path(
                tf.expand_dims(tf.expand_dims(self.layer_scale_1, 0), 0) * attention_output,
                training=training,
            )
            layer_output = layer_output + self.drop_path(
                tf.expand_dims(tf.expand_dims(self.layer_scale_2, 0), 0)
                * self.mlp(hidden_states=self.layernorm2(inputs=layer_output, training=training), training=training),
                training=training,
            )
        else:
            layer_output = hidden_states + self.drop_path(attention_output, training=training)
            layer_output = layer_output + self.drop_path(
                self.mlp(hidden_states=self.layernorm2(inputs=layer_output, training=training), training=training),
                training=training,
            )

        outputs = (layer_output,) + outputs

        return outputs
534
+
535
+
536
class TFEfficientFormerMeta3DLayers(keras.layers.Layer):
    """
    Stack of `config.num_meta3d_blocks` [`TFEfficientFormerMeta3D`] blocks of width `config.hidden_sizes[-1]`.
    Block `k` uses the drop path rate `config.drop_path_rate * (k + sum(config.depths[:-1]))`.
    """

    def __init__(self, config: EfficientFormerConfig, **kwargs):
        super().__init__(**kwargs)
        # Number of blocks in all stages that come before the last one
        depth_offset = sum(config.depths[:-1])
        self.blocks = [
            TFEfficientFormerMeta3D(
                config,
                config.hidden_sizes[-1],
                drop_path=config.drop_path_rate * (block_idx + depth_offset),
                name=f"blocks.{block_idx}",
            )
            for block_idx in range(config.num_meta3d_blocks)
        ]

    def call(
        self, hidden_states: tf.Tensor, output_attentions: bool = False, training: bool = False
    ) -> Tuple[tf.Tensor]:
        collected_attentions = () if output_attentions else None

        for block in self.blocks:
            # Every block returns a tuple; only its first element is fed to the next block.
            if isinstance(hidden_states, tuple):
                hidden_states = hidden_states[0]

            hidden_states = block(
                hidden_states=hidden_states, output_attentions=output_attentions, training=training
            )
            if output_attentions:
                collected_attentions = collected_attentions + (hidden_states[1],)

        if not output_attentions:
            # Output of the last block, returned as is
            return hidden_states

        return (hidden_states[0],) + collected_attentions

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        blocks = getattr(self, "blocks", None)
        if blocks is None:
            return
        for block in blocks:
            with tf.name_scope(block.name):
                block.build(None)
577
+
578
+
579
class TFEfficientFormerMeta4D(keras.layers.Layer):
    """
    Pooling block: pooling token mixer -> residual, then convolutional MLP -> residual. Both residual branches go
    through drop path and, when `config.use_layer_scale` is set, are multiplied by a learned per-feature scale.

    Args:
        config ([`EfficientFormerConfig`]):
            Model configuration.
        dim (`int`):
            Feature size used for the MLP and the layer-scale weights.
        drop_path (`float`, *optional*, defaults to 0.0):
            Drop path rate; values `<= 0.0` turn drop path into a linear (no-op) activation.
    """

    def __init__(self, config: EfficientFormerConfig, dim: int, drop_path: float = 0.0, **kwargs):
        super().__init__(**kwargs)
        # A pool size of 3 is used when the configuration does not specify one.
        pool_size = config.pool_size if config.pool_size is not None else 3
        self.token_mixer = TFEfficientFormerPooling(pool_size=pool_size, name="token_mixer")
        self.dim = dim
        mlp_hidden_dim = int(dim * config.mlp_expansion_ratio)
        self.mlp = TFEfficientFormerConvMlp(
            config=config, in_features=dim, hidden_features=mlp_hidden_dim, drop=config.hidden_dropout_prob, name="mlp"
        )

        self.drop_path = (
            TFEfficientFormerDropPath(drop_path, name="drop_path")
            if drop_path > 0.0
            else keras.layers.Activation("linear", name="drop_path")
        )
        self.config = config

    def build(self, input_shape=None):
        """
        Create the optional layer-scale weights and build the sub-layers.

        The `built` check comes first so that calling `build` again neither resets nor re-creates the weights.
        """
        if self.built:
            return

        self.layer_scale_1 = None
        self.layer_scale_2 = None

        if self.config.use_layer_scale:
            # The shape is spelled as a one-element tuple (`(self.dim)` is just an int).
            self.layer_scale_1 = self.add_weight(
                shape=(self.dim,),
                initializer=keras.initializers.Constant(value=self.config.layer_scale_init_value),
                trainable=True,
                name="layer_scale_1",
            )
            self.layer_scale_2 = self.add_weight(
                shape=(self.dim,),
                initializer=keras.initializers.Constant(value=self.config.layer_scale_init_value),
                trainable=True,
                name="layer_scale_2",
            )

        self.built = True
        if getattr(self, "token_mixer", None) is not None:
            with tf.name_scope(self.token_mixer.name):
                self.token_mixer.build(None)
        if getattr(self, "mlp", None) is not None:
            with tf.name_scope(self.mlp.name):
                self.mlp.build(None)
        if getattr(self, "drop_path", None) is not None:
            with tf.name_scope(self.drop_path.name):
                self.drop_path.build(None)

    def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Run the block on `hidden_states` and return the resulting tensor."""
        outputs = self.token_mixer(hidden_states)

        if self.config.use_layer_scale:
            # The scales have shape (dim,); two leading axes are added so that they broadcast over the inputs.
            layer_output = hidden_states + self.drop_path(
                tf.expand_dims(tf.expand_dims(self.layer_scale_1, 0), 0) * outputs,
                training=training,
            )

            layer_output = layer_output + self.drop_path(
                tf.expand_dims(tf.expand_dims(self.layer_scale_2, 0), 0)
                * self.mlp(hidden_state=layer_output, training=training),
                training=training,
            )

        else:
            layer_output = hidden_states + self.drop_path(outputs, training=training)
            layer_output = layer_output + self.drop_path(
                self.mlp(hidden_state=layer_output, training=training), training=training
            )

        return layer_output
650
+
651
+
652
class TFEfficientFormerMeta4DLayers(keras.layers.Layer):
    """
    Stack of [`TFEfficientFormerMeta4D`] blocks for the stage `stage_idx`. The stack holds
    `config.depths[stage_idx]` blocks, minus `config.num_meta3d_blocks` when `stage_idx` is `-1`. Block `k` uses the
    drop path rate `config.drop_path_rate * (k + sum(config.depths[:stage_idx]))`.
    """

    def __init__(self, config: EfficientFormerConfig, stage_idx: int, **kwargs):
        super().__init__(**kwargs)
        num_layers = config.depths[stage_idx]
        if stage_idx == -1:
            # The last stage reserves some of its depth for the Meta3D blocks.
            num_layers = num_layers - config.num_meta3d_blocks

        # Number of blocks in the stages that come before this one
        depth_offset = sum(config.depths[:stage_idx])

        self.blocks = [
            TFEfficientFormerMeta4D(
                config=config,
                dim=config.hidden_sizes[stage_idx],
                drop_path=config.drop_path_rate * (block_idx + depth_offset),
                name=f"blocks.{block_idx}",
            )
            for block_idx in range(num_layers)
        ]

    def call(self, hidden_states: tf.Tensor, training: bool = False) -> Tuple[tf.Tensor]:
        # Apply the blocks one after the other.
        for block in self.blocks:
            hidden_states = block(hidden_states=hidden_states, training=training)
        return hidden_states

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        blocks = getattr(self, "blocks", None)
        if blocks is None:
            return
        for block in blocks:
            with tf.name_scope(block.name):
                block.build(None)
682
+
683
+
684
class TFEfficientFormerIntermediateStage(keras.layers.Layer):
    """Intermediate stage: a thin wrapper around the [`TFEfficientFormerMeta4DLayers`] stack of stage `index`."""

    def __init__(self, config: EfficientFormerConfig, index: int, **kwargs):
        super().__init__(**kwargs)
        self.meta4D_layers = TFEfficientFormerMeta4DLayers(config=config, stage_idx=index, name="meta4D_layers")

    def call(self, hidden_states: tf.Tensor, training: bool = False) -> Tuple[tf.Tensor]:
        return self.meta4D_layers(hidden_states=hidden_states, training=training)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        stack = getattr(self, "meta4D_layers", None)
        if stack is None:
            return
        with tf.name_scope(stack.name):
            stack.build(None)
700
+
701
+
702
class TFEfficientFormerLastStage(keras.layers.Layer):
    """
    Last stage: the Meta4D blocks of stage `-1`, a flattening layer and finally the Meta3D (attention) blocks.
    """

    def __init__(self, config: EfficientFormerConfig, **kwargs):
        super().__init__(**kwargs)
        self.meta4D_layers = TFEfficientFormerMeta4DLayers(config=config, stage_idx=-1, name="meta4D_layers")
        self.flat = TFEfficientFormerFlat(name="flat")
        self.meta3D_layers = TFEfficientFormerMeta3DLayers(config, name="meta3D_layers")

    def call(
        self, hidden_states: tf.Tensor, output_attentions: bool = False, training: bool = False
    ) -> Tuple[tf.Tensor]:
        features_4d = self.meta4D_layers(hidden_states=hidden_states, training=training)
        flattened = self.flat(hidden_states=features_4d)
        # The result of the Meta3D stack is returned untouched.
        return self.meta3D_layers(
            hidden_states=flattened, output_attentions=output_attentions, training=training
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # Sub-layers in build order
        for attribute in ("meta4D_layers", "flat", "meta3D_layers"):
            sublayer = getattr(self, attribute, None)
            if sublayer is None:
                continue
            with tf.name_scope(sublayer.name):
                sublayer.build(None)
733
+
734
+
735
class TFEfficientFormerEncoder(keras.layers.Layer):
    """
    Encoder: one intermediate stage per entry of `config.depths` except the last, each optionally followed by a
    downsampling [`TFEfficientFormerPatchEmbeddings`] layer, and then the last stage.

    A downsampling layer is inserted after stage `i` when `config.downsamples[i]` is truthy or when
    `config.hidden_sizes[i]` differs from `config.hidden_sizes[i + 1]`. The layers are named
    `intermediate_stages.<position in the list>`.
    """

    def __init__(self, config: EfficientFormerConfig, **kwargs):
        super().__init__(**kwargs)

        self.config = config

        stages = []
        for stage_idx in range(len(config.depths) - 1):
            in_dim = config.hidden_sizes[stage_idx]
            out_dim = config.hidden_sizes[stage_idx + 1]

            # `len(stages)` is the position the new layer gets in the list.
            stages.append(
                TFEfficientFormerIntermediateStage(config, stage_idx, name=f"intermediate_stages.{len(stages)}")
            )
            if config.downsamples[stage_idx] or in_dim != out_dim:
                stages.append(
                    TFEfficientFormerPatchEmbeddings(
                        config,
                        in_dim,
                        out_dim,
                        name=f"intermediate_stages.{len(stages)}",
                    )
                )

        self.intermediate_stages = stages
        self.last_stage = TFEfficientFormerLastStage(config, name="last_stage")

    def call(
        self,
        hidden_states: tf.Tensor,
        output_hidden_states: bool,
        output_attentions: bool,
        return_dict: bool,
        training: bool = False,
    ) -> TFBaseModelOutput:
        # The encoder input is the first recorded hidden state.
        all_hidden_states = (hidden_states,) if output_hidden_states else None

        for stage in self.intermediate_stages:
            hidden_states = stage(hidden_states, training=training)

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        last_stage_outputs = self.last_stage(hidden_states, output_attentions=output_attentions, training=training)
        last_hidden_state = last_stage_outputs[0]

        # Everything after the first element of the last stage's output is an attention map.
        all_self_attentions = () + last_stage_outputs[1:] if output_attentions else None

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (last_hidden_state,)

        if return_dict:
            return TFBaseModelOutput(
                last_hidden_state=last_hidden_state,
                hidden_states=all_hidden_states,
                attentions=all_self_attentions,
            )

        # Tuple form: entries that were not requested are left out.
        candidates = [last_hidden_state, all_hidden_states, all_self_attentions]
        return tuple(value for value in candidates if value is not None)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        last_stage = getattr(self, "last_stage", None)
        if last_stage is not None:
            with tf.name_scope(last_stage.name):
                last_stage.build(None)
        for stage in self.intermediate_stages:
            with tf.name_scope(stage.name):
                stage.build(None)
813
+
814
+
815
@keras_serializable
class TFEfficientFormerMainLayer(keras.layers.Layer):
    """
    Core EfficientFormer layer: convolutional stem, encoder and a final layer normalization.

    Inputs arrive as `(batch_size, num_channels, height, width)` and are transposed to channels-last before the
    stem is applied.
    """

    config_class = EfficientFormerConfig

    def __init__(self, config: EfficientFormerConfig, **kwargs) -> None:
        super().__init__(**kwargs)
        self.config = config

        self.patch_embed = TFEfficientFormerConvStem(config, config.hidden_sizes[0], name="patch_embed")
        self.encoder = TFEfficientFormerEncoder(config, name="encoder")
        self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")

    @unpack_inputs
    def call(
        self,
        pixel_values: Optional[tf.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor, ...]]:
        """
        Embed the pixels, run the encoder and normalize the last hidden state.

        Args:
            pixel_values: Image batch; required.
            output_attentions: Falls back to `config.output_attentions` when `None`.
            output_hidden_states: Falls back to `config.output_hidden_states` when `None`.
            return_dict: Falls back to `config.use_return_dict` when `None`.
            training: Forwarded to the sub-layers.

        Returns:
            A `TFBaseModelOutput`, or a plain tuple when `return_dict` is falsy. In both cases every returned
            hidden state except the last one is transposed to `(batch_size, num_channels, height, width)`.

        Raises:
            ValueError: If `pixel_values` is `None`.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # When running on CPU, keras.layers.Conv2D and keras.layers.AveragePool2D do not
        # support channels first NCHW format. A number of blocks contain both.
        # So change the input format from (batch_size, num_channels, height, width) to
        # (batch_size, height, width, num_channels) here.
        # shape = (batch_size, in_height, in_width, in_channels=num_channels)
        pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
        embedding_output = self.patch_embed(pixel_values, training=training)

        encoder_outputs = self.encoder(
            hidden_states=embedding_output,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output, training=training)

        # Change the hidden states from (batch_size, height, width, num_channels) to
        # (batch_size, num_channels, height, width).
        # The hidden states are in (batch_size, height, width, num_channels)
        # shape after all stages except the MB3D blocks.
        if output_hidden_states:
            hidden_states = tuple([tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1][:-1]]) + (
                encoder_outputs[1][-1],
            )

        if not return_dict:
            if output_hidden_states:
                # Return the transposed hidden states here too, so the tuple and dict outputs agree on layout.
                # Previously the transposed tuple was computed and then dropped on this path.
                return (sequence_output, hidden_states) + encoder_outputs[2:]
            return (sequence_output,) + encoder_outputs[1:]

        return TFBaseModelOutput(
            last_hidden_state=sequence_output,
            hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def build(self, input_shape=None):
        """Build the stem, encoder and layer norm once, each under its own name scope."""
        if self.built:
            return
        self.built = True
        if getattr(self, "patch_embed", None) is not None:
            with tf.name_scope(self.patch_embed.name):
                self.patch_embed.build(None)
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)
        if getattr(self, "layernorm", None) is not None:
            with tf.name_scope(self.layernorm.name):
                # The layer norm is applied to features of size `hidden_sizes[-1]`.
                self.layernorm.build([None, None, self.config.hidden_sizes[-1]])
897
+
898
+
899
class TFEfficientFormerPreTrainedModel(TFPreTrainedModel):
    """
    Common base class of the TensorFlow EfficientFormer models.

    It only declares the class-level settings below; all other behavior is inherited from `TFPreTrainedModel`.
    """

    # Configuration class associated with these models.
    config_class = EfficientFormerConfig
    # Name of the attribute that holds the main layer in the concrete model classes.
    base_model_prefix = "efficientformer"
    # Name of the primary input argument of `call`.
    main_input_name = "pixel_values"
908
+
909
+
910
# Shared class-level documentation injected into every TF EfficientFormer model via `add_start_docstrings`.
EFFICIENTFORMER_START_DOCSTRING = r"""
    This model is a TensorFlow
    [keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer). Use it as a regular
    TensorFlow Module and refer to the TensorFlow documentation for all matter related to general usage and behavior.


    Parameters:
        config ([`EfficientFormerConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Shared `call` argument documentation injected via `add_start_docstrings_to_model_forward`.
EFFICIENTFORMER_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`EfficientFormerImageProcessor.__call__`] for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
936
+
937
+
938
@add_start_docstrings(
    "The bare EfficientFormer Model transformer outputting raw hidden-states without any specific head on top.",
    EFFICIENTFORMER_START_DOCSTRING,
)
class TFEfficientFormerModel(TFEfficientFormerPreTrainedModel):
    def __init__(self, config: EfficientFormerConfig, **kwargs) -> None:
        super().__init__(config, **kwargs)

        # All computation is delegated to the shared main layer.
        self.efficientformer = TFEfficientFormerMainLayer(config, name="efficientformer")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(EFFICIENTFORMER_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        # `call` returns the main layer's `TFBaseModelOutput`, which has no pooled output,
        # so that is the type the generated documentation must describe.
        output_type=TFBaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def call(
        self,
        pixel_values: Optional[tf.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[Tuple, TFBaseModelOutput]:
        # Pure pass-through: defaults for the optional flags are resolved inside the main layer.
        outputs = self.efficientformer(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        return outputs

    def build(self, input_shape=None):
        # Build the main layer once, under its own name scope.
        if self.built:
            return
        self.built = True
        if getattr(self, "efficientformer", None) is not None:
            with tf.name_scope(self.efficientformer.name):
                self.efficientformer.build(None)
981
+
982
+
983
@add_start_docstrings(
    """
    EfficientFormer Model transformer with an image classification head on top of pooled last hidden state, e.g. for
    ImageNet.
    """,
    EFFICIENTFORMER_START_DOCSTRING,
)
class TFEfficientFormerForImageClassification(TFEfficientFormerPreTrainedModel, TFSequenceClassificationLoss):
    # `**kwargs` is accepted and forwarded to the base class, matching `TFEfficientFormerModel`.
    def __init__(self, config: EfficientFormerConfig, **kwargs) -> None:
        super().__init__(config, **kwargs)

        self.num_labels = config.num_labels
        self.efficientformer = TFEfficientFormerMainLayer(config, name="efficientformer")

        # Classifier head: a dense projection, or an identity ("linear" activation) when there are no labels.
        self.classifier = (
            keras.layers.Dense(config.num_labels, name="classifier")
            if config.num_labels > 0
            else keras.layers.Activation("linear", name="classifier")
        )
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(EFFICIENTFORMER_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=TFImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def call(
        self,
        pixel_values: Optional[tf.Tensor] = None,
        labels: Optional[tf.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[Tuple, TFImageClassifierOutput]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.efficientformer(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        sequence_output = outputs[0]

        # Pool by averaging over the second-to-last axis before classifying.
        logits = self.classifier(tf.reduce_mean(sequence_output, axis=-2))

        loss = None if labels is None else self.hf_compute_loss(labels, logits)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TFImageClassifierOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )

    def build(self, input_shape=None):
        # Build the main layer and the classifier once, each under its own name scope.
        if self.built:
            return
        self.built = True
        if getattr(self, "efficientformer", None) is not None:
            with tf.name_scope(self.efficientformer.name):
                self.efficientformer.build(None)
        if getattr(self, "classifier", None) is not None:
            if hasattr(self.classifier, "name"):
                with tf.name_scope(self.classifier.name):
                    # The classifier consumes features of size `hidden_sizes[-1]`.
                    self.classifier.build([None, None, self.config.hidden_sizes[-1]])
1063
+
1064
+
1065
@dataclass
class TFEfficientFormerForImageClassificationWithTeacherOutput(ModelOutput):
    """
    Output type of [`TFEfficientFormerForImageClassificationWithTeacher`].

    Args:
        logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
            Prediction scores, computed as the average of `cls_logits` and `distillation_logits`.
        cls_logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
            Prediction scores of the classification head.
        distillation_logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
            Prediction scores of the distillation head.
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when
            `config.output_hidden_states=True`):
            Hidden states collected by the model: the embedding output followed by the output of each stage.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when
            `config.output_attentions=True`):
            Attention weights returned by the model's attention layers.
    """

    # Every field defaults to `None`, so an output can be created with any subset of values.
    logits: Optional[tf.Tensor] = None
    cls_logits: Optional[tf.Tensor] = None
    distillation_logits: Optional[tf.Tensor] = None
    hidden_states: Optional[Tuple[tf.Tensor]] = None
    attentions: Optional[Tuple[tf.Tensor]] = None
1095
+
1096
+
1097
@add_start_docstrings(
    """
    EfficientFormer Model transformer with image classification heads on top (a linear layer on top of the final hidden
    state and a linear layer on top of the final hidden state of the distillation token) e.g. for ImageNet.

    .. warning::
            This model supports inference-only. Fine-tuning with distillation (i.e. with a teacher) is not yet
            supported.
    """,
    EFFICIENTFORMER_START_DOCSTRING,
)
class TFEfficientFormerForImageClassificationWithTeacher(TFEfficientFormerPreTrainedModel):
    # `**kwargs` is accepted and forwarded to the base class, matching `TFEfficientFormerModel`.
    def __init__(self, config: EfficientFormerConfig, **kwargs) -> None:
        super().__init__(config, **kwargs)

        self.num_labels = config.num_labels
        self.efficientformer = TFEfficientFormerMainLayer(config, name="efficientformer")

        # Classifier heads: dense projections, or identities ("linear" activation) when there are no labels.
        self.classifier = (
            keras.layers.Dense(config.num_labels, name="classifier")
            if config.num_labels > 0
            else keras.layers.Activation("linear", name="classifier")
        )
        self.distillation_classifier = (
            keras.layers.Dense(config.num_labels, name="distillation_classifier")
            if config.num_labels > 0
            else keras.layers.Activation("linear", name="distillation_classifier")
        )

    @unpack_inputs
    @add_start_docstrings_to_model_forward(EFFICIENTFORMER_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=TFEfficientFormerForImageClassificationWithTeacherOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def call(
        self,
        pixel_values: Optional[tf.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[tuple, TFEfficientFormerForImageClassificationWithTeacherOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if training:
            # A specific exception type instead of bare `Exception`; it is still caught by `except Exception`.
            raise NotImplementedError(
                "This model supports inference-only. Fine-tuning with distillation (i.e. with a teacher) is not yet supported."
            )

        outputs = self.efficientformer(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        sequence_output = outputs[0]

        # Both heads consume the same pooled features, so compute the mean once.
        pooled_output = tf.reduce_mean(sequence_output, axis=-2)
        cls_logits = self.classifier(pooled_output)
        distillation_logits = self.distillation_classifier(pooled_output)
        logits = (cls_logits + distillation_logits) / 2

        if not return_dict:
            output = (logits, cls_logits, distillation_logits) + outputs[1:]
            return output

        return TFEfficientFormerForImageClassificationWithTeacherOutput(
            logits=logits,
            cls_logits=cls_logits,
            distillation_logits=distillation_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def build(self, input_shape=None):
        # Build the main layer and both heads once, each under its own name scope.
        if self.built:
            return
        self.built = True
        if getattr(self, "efficientformer", None) is not None:
            with tf.name_scope(self.efficientformer.name):
                self.efficientformer.build(None)
        if getattr(self, "classifier", None) is not None:
            if hasattr(self.classifier, "name"):
                with tf.name_scope(self.classifier.name):
                    self.classifier.build([None, None, self.config.hidden_sizes[-1]])
        if getattr(self, "distillation_classifier", None) is not None:
            if hasattr(self.distillation_classifier, "name"):
                with tf.name_scope(self.distillation_classifier.name):
                    self.distillation_classifier.build([None, None, self.config.hidden_sizes[-1]])
1191
+
1192
+
1193
# Names exported by `from ... import *` for the TF EfficientFormer module.
__all__ = [
    "TFEfficientFormerForImageClassification",
    "TFEfficientFormerForImageClassificationWithTeacher",
    "TFEfficientFormerModel",
    "TFEfficientFormerPreTrainedModel",
]
docs/transformers/build/lib/transformers/models/deprecated/ernie_m/__init__.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace and Baidu Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ....utils import _LazyModule
17
+ from ....utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_ernie_m import *
22
+ from .modeling_ernie_m import *
23
+ from .tokenization_ernie_m import *
24
+ else:
25
+ import sys
26
+
27
+ _file = globals()["__file__"]
28
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/build/lib/transformers/models/deprecated/ernie_m/configuration_ernie_m.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 Xuan Ouyang, Shuohuan Wang, Chao Pang, Yu Sun, Hao Tian, Hua Wu, Haifeng Wang and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ErnieM model configuration"""
16
+ # Adapted from original paddlenlp repository.(https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/transformers/ernie_m/configuration.py)
17
+
18
+ from __future__ import annotations
19
+
20
+ from typing import Dict
21
+
22
+ from ....configuration_utils import PretrainedConfig
23
+
24
+
25
class ErnieMConfig(PretrainedConfig):
    r"""
    Configuration class for an [`ErnieMModel`]. It stores the arguments that define the model architecture; the
    default values yield a configuration similar to that of the `Ernie-M`
    [susnato/ernie-m-base_pytorch](https://huggingface.co/susnato/ernie-m-base_pytorch) architecture.


    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 250002):
            Vocabulary size of `inputs_ids` in [`ErnieMModel`], i.e. the size of the token embedding matrix. Defines
            the number of different tokens that can be represented by the `inputs_ids` passed when calling
            [`ErnieMModel`].
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the embedding layer, encoder layers and pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the feed-forward (ff) layer in the encoder. Input tensors to feed-forward layers are
            first projected from hidden_size to intermediate_size, and then projected back to hidden_size. Typically
            intermediate_size is larger than hidden_size.
        hidden_act (`str`, *optional*, defaults to `"gelu"`):
            The non-linear activation function in the feed-forward layer. `"gelu"`, `"relu"` and any other torch
            supported activation functions are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings and encoder.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability used in `MultiHeadAttention` in all encoder layers to drop some attention target.
        max_position_embeddings (`int`, *optional*, defaults to 514):
            The maximum value of the dimensionality of position encoding, which dictates the maximum supported length
            of an input sequence.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the normal initializer for initializing all weight matrices.
        pad_token_id (`int`, *optional*, defaults to 1):
            Padding token id, i.e. the index of the padding token in the token vocabulary.
        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the layer normalization layers.
        classifier_dropout (`float`, *optional*):
            The dropout ratio for the classification head.
        act_dropout (`float`, *optional*, defaults to 0.0):
            This dropout probability is used in `ErnieMEncoderLayer` after activation.

    A normal_initializer initializes weight matrices as normal distributions. See
    `ErnieMPretrainedModel._init_weights()` for how weights are initialized in `ErnieMModel`.
    """

    model_type = "ernie_m"
    # Alternative attribute names that resolve to the canonical ones.
    attribute_map: Dict[str, str] = {"dropout": "classifier_dropout", "num_classes": "num_labels"}

    def __init__(
        self,
        vocab_size: int = 250002,
        hidden_size: int = 768,
        num_hidden_layers: int = 12,
        num_attention_heads: int = 12,
        intermediate_size: int = 3072,
        hidden_act: str = "gelu",
        hidden_dropout_prob: float = 0.1,
        attention_probs_dropout_prob: float = 0.1,
        max_position_embeddings: int = 514,
        initializer_range: float = 0.02,
        pad_token_id: int = 1,
        layer_norm_eps: float = 1e-05,
        classifier_dropout: float | None = None,
        act_dropout: float = 0.0,
        **kwargs,
    ):
        # `pad_token_id` is handled by the base class; everything else is stored on this instance.
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        # Sizes
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size

        # Activation and dropout settings
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob

        # Positions, initialization and normalization
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps

        # Head and post-activation dropout
        self.classifier_dropout = classifier_dropout
        self.act_dropout = act_dropout
112
+
113
+
114
# Names exported by `from ... import *` for the ErnieM configuration module.
__all__ = ["ErnieMConfig"]
docs/transformers/build/lib/transformers/models/deprecated/ernie_m/modeling_ernie_m.py ADDED
@@ -0,0 +1,1058 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 Xuan Ouyang, Shuohuan Wang, Chao Pang, Yu Sun, Hao Tian, Hua Wu, Haifeng Wang The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch ErnieM model."""
16
+
17
+ import math
18
+ from typing import List, Optional, Tuple, Union
19
+
20
+ import torch
21
+ import torch.utils.checkpoint
22
+ from torch import nn, tensor
23
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
24
+
25
+ from ....activations import ACT2FN
26
+ from ....modeling_outputs import (
27
+ BaseModelOutputWithPastAndCrossAttentions,
28
+ BaseModelOutputWithPoolingAndCrossAttentions,
29
+ MultipleChoiceModelOutput,
30
+ QuestionAnsweringModelOutput,
31
+ SequenceClassifierOutput,
32
+ TokenClassifierOutput,
33
+ )
34
+ from ....modeling_utils import PreTrainedModel
35
+ from ....pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
36
+ from ....utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
37
+ from .configuration_ernie_m import ErnieMConfig
38
+
39
+
40
+ logger = logging.get_logger(__name__)
41
+
42
+ _CHECKPOINT_FOR_DOC = "susnato/ernie-m-base_pytorch"
43
+ _CONFIG_FOR_DOC = "ErnieMConfig"
44
+ _TOKENIZER_FOR_DOC = "ErnieMTokenizer"
45
+
46
+
47
+ # Adapted from paddlenlp.transformers.ernie_m.modeling.ErnieEmbeddings
48
class ErnieMEmbeddings(nn.Module):
    """Construct the embeddings from word and position embeddings."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size, padding_idx=config.pad_token_id
        )
        self.layer_norm = nn.LayerNorm(normalized_shape=config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(p=config.hidden_dropout_prob)
        self.padding_idx = config.pad_token_id

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        """
        Sum word and position embeddings, then apply layer normalization and dropout.

        Args:
            input_ids: Token ids; used only when `inputs_embeds` is `None`.
            position_ids: Position indices. When `None`, positions `0..seq_len-1` are generated and shifted by
                `past_key_values_length`. The tensor passed in is never modified.
            inputs_embeds: Pre-computed word embeddings; when given, `input_ids` is ignored.
            past_key_values_length: Offset added to generated positions when greater than zero.

        Returns:
            The embedding tensor, with the same leading dimensions as `inputs_embeds`.
        """
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        if position_ids is None:
            input_shape = inputs_embeds.size()[:-1]
            ones = torch.ones(input_shape, dtype=torch.int64, device=inputs_embeds.device)
            seq_length = torch.cumsum(ones, dim=1)
            position_ids = seq_length - ones

            if past_key_values_length > 0:
                position_ids = position_ids + past_key_values_length
        # to mimic paddlenlp implementation
        # Out-of-place on purpose: `position_ids += 2` would modify a caller-supplied tensor, so reusing the
        # same `position_ids` across calls would shift the positions by two more each time.
        position_ids = position_ids + 2
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = inputs_embeds + position_embeddings
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)

        return embeddings
87
+
88
+
89
class ErnieMSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Projects hidden states into queries, keys and values, optionally adds
    learned relative-position terms to the scores, and returns the merged
    per-head context. Handles cross-attention (keys/values taken from
    `encoder_hidden_states`) and reuse of cached keys/values.
    """

    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        head_count = config.num_attention_heads
        head_dim = int(config.hidden_size / config.num_attention_heads)
        self.num_attention_heads = head_count
        self.attention_head_size = head_dim
        self.all_head_size = head_count * head_dim

        self.q_proj = nn.Linear(config.hidden_size, self.all_head_size)
        self.k_proj = nn.Linear(config.hidden_size, self.all_head_size)
        self.v_proj = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

        # An explicit argument wins; otherwise fall back to the config, then "absolute".
        if not position_embedding_type:
            position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.position_embedding_type = position_embedding_type

        if self.position_embedding_type in ("relative_key", "relative_key_query"):
            self.max_position_embeddings = config.max_position_embeddings
            # One embedding per signed distance in [-(max - 1), max - 1].
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Split the last dim into (heads, head_size) and move heads to dim 1."""
        leading_dims = x.size()[:-1]
        split = x.view(leading_dims + (self.num_attention_heads, self.attention_head_size))
        return split.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Run attention over `hidden_states`.

        Returns a tuple `(context,)`, followed by the attention probabilities
        when `output_attentions` is truthy, followed by the `(key, value)` pair
        when the module was built with `config.is_decoder`.
        """
        query_layer = self.transpose_for_scores(self.q_proj(hidden_states))

        # Remember whether a cache was supplied before the name is rebound below.
        use_cache = past_key_value is not None

        if encoder_hidden_states is not None:
            # Cross-attention: keys/values come from the encoder, and the encoder's
            # mask is used so that its padding tokens are not attended to.
            attention_mask = encoder_attention_mask
            if use_cache:
                # reuse k,v, cross_attentions
                key_layer, value_layer = past_key_value[0], past_key_value[1]
            else:
                key_layer = self.transpose_for_scores(self.k_proj(encoder_hidden_states))
                value_layer = self.transpose_for_scores(self.v_proj(encoder_hidden_states))
        else:
            key_layer = self.transpose_for_scores(self.k_proj(hidden_states))
            value_layer = self.transpose_for_scores(self.v_proj(hidden_states))
            if use_cache:
                # Prepend the cached keys/values along the sequence axis.
                key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
                value_layer = torch.cat([past_key_value[1], value_layer], dim=2)

        if self.is_decoder:
            # Hand the (possibly extended) keys/values back so later calls can
            # reuse them: the cross-attention branch reads them as-is, the
            # self-attention branch concatenates new projections onto them.
            past_key_value = (key_layer, value_layer)

        # Raw attention scores: dot product between "query" and "key".
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type in ("relative_key", "relative_key_query"):
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            device = hidden_states.device
            if use_cache:
                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=device).view(-1, 1)
            else:
                position_ids_l = torch.arange(query_length, dtype=torch.long, device=device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=device).view(1, -1)
            distance = position_ids_l - position_ids_r

            # Shift distances to be non-negative before the embedding lookup.
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # The mask is additive (precomputed for all layers in ErnieMModel.forward()).
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        # Weighted sum of values, then merge the heads back into one dimension.
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        merged_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(merged_shape)

        outputs = (context_layer,)
        if output_attentions:
            outputs = (context_layer, attention_probs)
        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs
221
+
222
+
223
class ErnieMAttention(nn.Module):
    """Self-attention followed by an output projection, with head pruning support."""

    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        self.self_attn = ErnieMSelfAttention(config, position_embedding_type=position_embedding_type)
        self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)
        # Indices of heads that have already been removed from this module.
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads from the projection layers."""
        if len(heads) == 0:
            return

        attn = self.self_attn
        heads, index = find_pruneable_heads_and_indices(
            heads, attn.num_attention_heads, attn.attention_head_size, self.pruned_heads
        )

        # Prune linear layers: q/k/v along their default dim, out_proj along dim 1.
        attn.q_proj = prune_linear_layer(attn.q_proj, index)
        attn.k_proj = prune_linear_layer(attn.k_proj, index)
        attn.v_proj = prune_linear_layer(attn.v_proj, index)
        self.out_proj = prune_linear_layer(self.out_proj, index, dim=1)

        # Update hyper params and store pruned heads
        attn.num_attention_heads = attn.num_attention_heads - len(heads)
        attn.all_head_size = attn.attention_head_size * attn.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Apply self-attention and project its context output.

        Returns a tuple whose first element is the projected attention output;
        any further elements produced by `self.self_attn` are passed through
        unchanged.
        """
        attn_results = self.self_attn(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        context, *extras = attn_results
        # add attentions if we output them
        return (self.out_proj(context),) + tuple(extras)
270
+
271
+
272
class ErnieMEncoderLayer(nn.Module):
    """One post-norm transformer encoder layer: self-attention then feed-forward.

    Each sub-block is followed by dropout, a residual connection and layer norm.
    """

    def __init__(self, config):
        super().__init__()
        # to mimic paddlenlp implementation
        dropout = 0.1 if config.hidden_dropout_prob is None else config.hidden_dropout_prob
        # Fall back to the resolved `dropout` (never None) rather than the raw
        # config value, so a missing `hidden_dropout_prob` cannot leak `None`
        # into `nn.Dropout`.
        act_dropout = dropout if config.act_dropout is None else config.act_dropout

        self.self_attn = ErnieMAttention(config)
        self.linear1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.dropout = nn.Dropout(act_dropout)
        self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size)
        self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        # `hidden_act` may be a registry key or an already-built callable.
        if isinstance(config.hidden_act, str):
            self.activation = ACT2FN[config.hidden_act]
        else:
            self.activation = config.hidden_act

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = True,
    ):
        """Run the layer.

        Returns:
            `(hidden_states, attention_weights)` when `output_attentions` is
            truthy, otherwise just `hidden_states`.
        """
        residual = hidden_states

        # `self_attn` always returns a tuple: the attention output first, then
        # the attention weights when requested, and possibly more entries after
        # that. Index it explicitly instead of unpacking a fixed number of
        # elements or feeding the whole tuple to dropout.
        attn_outputs = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
        )
        hidden_states = attn_outputs[0]
        attention_opt_weights = attn_outputs[1] if output_attentions else None

        hidden_states = residual + self.dropout1(hidden_states)
        hidden_states = self.norm1(hidden_states)
        residual = hidden_states

        # Position-wise feed-forward block.
        hidden_states = self.linear1(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.linear2(hidden_states)
        hidden_states = residual + self.dropout2(hidden_states)
        hidden_states = self.norm2(hidden_states)

        if output_attentions:
            return hidden_states, attention_opt_weights
        return hidden_states
333
+
334
+
335
class ErnieMEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` encoder layers applied in sequence."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([ErnieMEncoderLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(
        self,
        input_embeds: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        """Feed `input_embeds` through every layer.

        Returns either a `BaseModelOutputWithPastAndCrossAttentions` or, when
        `return_dict` is falsy, a tuple of the non-`None` values among
        (last hidden state, all hidden states, all attentions).
        """
        collected_states = (input_embeds,) if output_hidden_states else None
        collected_attn = () if output_attentions else None

        output = input_embeds
        for index, layer in enumerate(self.layers):
            layer_head_mask = None if head_mask is None else head_mask[index]
            layer_past = None if past_key_values is None else past_key_values[index]

            # The layer is invoked without `output_attentions`, so it runs with
            # its own default for that flag; this call site unpacks a
            # (hidden_states, attention_weights) pair.
            output, opt_attn_weights = layer(
                hidden_states=output,
                attention_mask=attention_mask,
                head_mask=layer_head_mask,
                past_key_value=layer_past,
            )

            if output_hidden_states:
                collected_states = collected_states + (output,)
            if output_attentions:
                collected_attn = collected_attn + (opt_attn_weights,)

        if return_dict:
            return BaseModelOutputWithPastAndCrossAttentions(
                last_hidden_state=output, hidden_states=collected_states, attentions=collected_attn
            )
        return tuple(item for item in (output, collected_states, collected_attn) if item is not None)
380
+
381
+
382
class ErnieMPooler(nn.Module):
    """Pool a sequence by passing its first token's state through Linear + Tanh."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return the transformed hidden state found at index 0 of dim 1."""
        # "Pooling" here simply means selecting the first token's hidden state.
        first_token_tensor = hidden_states[:, 0]
        return self.activation(self.dense(first_token_tensor))
395
+
396
+
397
class ErnieMPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = ErnieMConfig
    base_model_prefix = "ernie_m"

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.LayerNorm):
            # Identity affine transform: zero bias, unit scale.
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return

        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            # Keep the padding row at zero after the random fill.
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
421
+
422
+
423
+ ERNIE_M_START_DOCSTRING = r"""
424
+
425
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
426
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
427
+ etc.)
428
+
429
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
430
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
431
+ behavior.
432
+
433
+ Parameters:
434
+ config ([`ErnieMConfig`]): Model configuration class with all the parameters of the model.
435
+ Initializing with a config file does not load the weights associated with the model, only the
436
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
437
+ """
438
+
439
+ ERNIE_M_INPUTS_DOCSTRING = r"""
440
+ Args:
441
+ input_ids (`torch.LongTensor` of shape `({0})`):
442
+ Indices of input sequence tokens in the vocabulary.
443
+
444
+ Indices can be obtained using [`ErnieMTokenizer`]. See [`PreTrainedTokenizer.encode`] and
445
+ [`PreTrainedTokenizer.__call__`] for details.
446
+
447
+ [What are input IDs?](../glossary#input-ids)
448
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
449
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
450
+
451
+ - 1 for tokens that are **not masked**,
452
+ - 0 for tokens that are **masked**.
453
+
454
+ [What are attention masks?](../glossary#attention-mask)
455
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
456
+ Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
457
+ config.max_position_embeddings - 1]`.
458
+
459
+ [What are position IDs?](../glossary#position-ids)
460
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
461
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
462
+
463
+ - 1 indicates the head is **not masked**,
464
+ - 0 indicates the head is **masked**.
465
+
466
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
467
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
468
+ is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
469
+ model's internal embedding lookup matrix.
470
+ output_attentions (`bool`, *optional*):
471
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
472
+ tensors for more detail.
473
+ output_hidden_states (`bool`, *optional*):
474
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
475
+ more detail.
476
+ return_dict (`bool`, *optional*):
477
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
478
+ """
479
+
480
+
481
# Base ERNIE-M model: embeddings -> encoder stack -> optional first-token pooler.
# (Kept as comments so the decorator-generated docstrings are not altered.)
@add_start_docstrings(
    "The bare ErnieM Model transformer outputting raw hidden-states without any specific head on top.",
    ERNIE_M_START_DOCSTRING,
)
class ErnieMModel(ErnieMPreTrainedModel):
    def __init__(self, config, add_pooling_layer=True):
        """Build the embeddings, the encoder and (optionally) the pooler."""
        super().__init__(config)
        self.initializer_range = config.initializer_range
        self.embeddings = ErnieMEmbeddings(config)
        self.encoder = ErnieMEncoder(config)
        self.pooler = ErnieMPooler(config) if add_pooling_layer else None
        self.post_init()

    def get_input_embeddings(self):
        """Return the word-embedding module."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word-embedding module."""
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layers[layer].self_attn.prune_heads(heads)

    @add_start_docstrings_to_model_forward(ERNIE_M_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPastAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        # NOTE: `use_cache` is accepted for interface compatibility but is not
        # read anywhere in this method.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time.")
        if input_ids is None and inputs_embeds is None:
            # Fail early with a clear message instead of an AttributeError further down.
            raise ValueError("You have to specify either input_ids or inputs_embeds.")

        # init the default bool value
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        past_key_values_length = 0
        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]

        # Adapted from paddlenlp.transformers.ernie_m.ErnieMModel
        # The mask is additive: 0 keeps a position, the dtype's minimum removes it.
        if attention_mask is None:
            if input_ids is not None:
                attention_mask = (input_ids == self.config.pad_token_id).to(torch.float32)
                attention_mask *= torch.finfo(attention_mask.dtype).min
            else:
                # Padding cannot be inferred from embeddings alone, so attend to
                # every position.
                attention_mask = torch.zeros(
                    inputs_embeds.shape[:-1], dtype=torch.float32, device=inputs_embeds.device
                )
            if past_key_values is not None:
                batch_size = past_key_values[0][0].shape[0]
                # The mask is still 2-D here, so the block covering the cached
                # positions must be 2-D as well (and live on the same device)
                # for the concatenation to be valid.
                past_mask = torch.zeros(
                    [batch_size, past_key_values_length],
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )
                attention_mask = torch.cat([past_mask, attention_mask], dim=-1)
        # For 2D attention_mask from tokenizer
        elif attention_mask.ndim == 2:
            attention_mask = attention_mask.to(torch.float32)
            attention_mask = 1.0 - attention_mask
            attention_mask *= torch.finfo(attention_mask.dtype).min

        # Insert two singleton axes after the batch axis so the mask broadcasts
        # against the 4-D attention scores.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            sequence_output = encoder_outputs[0]
            pooler_output = self.pooler(sequence_output) if self.pooler is not None else None
            return (sequence_output, pooler_output) + encoder_outputs[1:]

        sequence_output = encoder_outputs["last_hidden_state"]
        pooler_output = self.pooler(sequence_output) if self.pooler is not None else None
        hidden_states = None if not output_hidden_states else encoder_outputs["hidden_states"]
        attentions = None if not output_attentions else encoder_outputs["attentions"]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooler_output,
            hidden_states=hidden_states,
            attentions=attentions,
        )
592
+
593
+
594
# ERNIE-M with a dropout + linear classifier applied to the pooled output.
@add_start_docstrings(
    """ErnieM Model transformer with a sequence classification/regression head on top (a linear layer on top of
    the pooled output) e.g. for GLUE tasks.""",
    ERNIE_M_START_DOCSTRING,
)
class ErnieMForSequenceClassification(ErnieMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.ernie_m = ErnieMModel(config)
        # Prefer the dedicated classifier dropout; otherwise reuse the hidden dropout.
        if config.classifier_dropout is not None:
            classifier_dropout = config.classifier_dropout
        else:
            classifier_dropout = config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def _compute_loss(self, logits, labels):
        # Picks the loss from `config.problem_type`, inferring and storing the
        # problem type on first use when it is unset. Returns None for an
        # unrecognised problem type.
        if self.config.problem_type is None:
            if self.num_labels == 1:
                self.config.problem_type = "regression"
            elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                self.config.problem_type = "single_label_classification"
            else:
                self.config.problem_type = "multi_label_classification"

        problem_type = self.config.problem_type
        if problem_type == "regression":
            criterion = MSELoss()
            if self.num_labels == 1:
                return criterion(logits.squeeze(), labels.squeeze())
            return criterion(logits, labels)
        if problem_type == "single_label_classification":
            return CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))
        if problem_type == "multi_label_classification":
            return BCEWithLogitsLoss()(logits, labels)
        return None

    @add_start_docstrings_to_model_forward(ERNIE_M_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.Tensor]] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = True,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple[torch.FloatTensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.ernie_m(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=return_dict,
        )

        # Index 1 of the base model's output is the pooled representation.
        logits = self.classifier(self.dropout(outputs[1]))

        loss = None if labels is None else self._compute_loss(logits, labels)

        if return_dict:
            return SequenceClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        # Tuple output: drop the sequence/pooled entries, prepend the loss when present.
        output = (logits,) + outputs[2:]
        if loss is None:
            return output
        return (loss,) + output
693
+
694
+
695
# ERNIE-M with a single-logit linear head scored per choice.
@add_start_docstrings(
    """ErnieM Model with a multiple choice classification head on top (a linear layer on top of
    the pooled output and a softmax) e.g. for RocStories/SWAG tasks.""",
    ERNIE_M_START_DOCSTRING,
)
class ErnieMForMultipleChoice(ErnieMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.ernie_m = ErnieMModel(config)
        # Prefer the dedicated classifier dropout; otherwise reuse the hidden dropout.
        if config.classifier_dropout is not None:
            classifier_dropout = config.classifier_dropout
        else:
            classifier_dropout = config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(ERNIE_M_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple[torch.FloatTensor], MultipleChoiceModelOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # The number of choices is read from dim 1 of whichever input was supplied.
        reference = input_ids if input_ids is not None else inputs_embeds
        num_choices = reference.shape[1]

        def _merge_choices(tensor):
            # Collapse every leading dim into one, keeping only the last dim.
            return None if tensor is None else tensor.view(-1, tensor.size(-1))

        input_ids = _merge_choices(input_ids)
        attention_mask = _merge_choices(attention_mask)
        position_ids = _merge_choices(position_ids)
        if inputs_embeds is not None:
            # Embeddings keep their last two dims.
            inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))

        outputs = self.ernie_m(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 1 of the base model's output is the pooled representation.
        logits = self.classifier(self.dropout(outputs[1]))
        # One score per flattened row -> regroup into rows of `num_choices`.
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss = CrossEntropyLoss()(reshaped_logits, labels)

        if return_dict:
            return MultipleChoiceModelOutput(
                loss=loss,
                logits=reshaped_logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        output = (reshaped_logits,) + outputs[2:]
        if loss is None:
            return output
        return (loss,) + output
782
+
783
+
784
# ERNIE-M (built without the pooler) with a dropout + linear classifier applied
# to every position of the sequence output.
@add_start_docstrings(
    """ErnieM Model with a token classification head on top (a linear layer on top of
    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks.""",
    ERNIE_M_START_DOCSTRING,
)
class ErnieMForTokenClassification(ErnieMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.ernie_m = ErnieMModel(config, add_pooling_layer=False)
        # Prefer the dedicated classifier dropout; otherwise reuse the hidden dropout.
        if config.classifier_dropout is not None:
            classifier_dropout = config.classifier_dropout
        else:
            classifier_dropout = config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(ERNIE_M_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.Tensor]] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = True,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple[torch.FloatTensor], TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.ernie_m(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 0 of the base model's output is the per-token sequence output.
        logits = self.classifier(self.dropout(outputs[0]))

        loss = None
        if labels is not None:
            # Flatten both to one row per token before the cross-entropy.
            loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))

        if return_dict:
            return TokenClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        output = (logits,) + outputs[2:]
        if loss is None:
            return output
        return (loss,) + output
862
+
863
+
864
@add_start_docstrings(
    """ErnieM Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).""",
    ERNIE_M_START_DOCSTRING,
)
class ErnieMForQuestionAnswering(ErnieMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Backbone without the pooling head: only per-token hidden states are consumed below.
        self.ernie_m = ErnieMModel(config, add_pooling_layer=False)
        # Projects every token to `config.num_labels` scores; forward() unpacks them as (start, end).
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(ERNIE_M_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple[torch.FloatTensor], QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_out = self.ernie_m(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Split the last dimension into single-channel slices and drop that dimension;
        # the two-name unpacking requires exactly two channels.
        start_logits, end_logits = (
            channel.squeeze(-1).contiguous() for channel in self.qa_outputs(encoder_out[0]).split(1, dim=-1)
        )

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)

            # Out-of-range targets are clamped onto `seq_len`, which is then used as the ignored class
            # so that those examples contribute nothing to the loss.
            seq_len = start_logits.size(1)
            start_positions = start_positions.clamp(0, seq_len)
            end_positions = end_positions.clamp(0, seq_len)

            criterion = CrossEntropyLoss(ignore_index=seq_len)
            total_loss = (criterion(start_logits, start_positions) + criterion(end_logits, end_positions)) / 2

        if return_dict:
            return QuestionAnsweringModelOutput(
                loss=total_loss,
                start_logits=start_logits,
                end_logits=end_logits,
                hidden_states=encoder_out.hidden_states,
                attentions=encoder_out.attentions,
            )

        # Tuple output: logits followed by whatever the backbone returned after its first two entries.
        output = (start_logits, end_logits) + encoder_out[2:]
        return output if total_loss is None else (total_loss,) + output
958
+
959
+
960
@add_start_docstrings(
    """ErnieMForInformationExtraction is a Ernie-M Model with two linear layer on top of the hidden-states output to
    compute `start_prob` and `end_prob`, designed for Universal Information Extraction.""",
    ERNIE_M_START_DOCSTRING,
)
class ErnieMForInformationExtraction(ErnieMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.ernie_m = ErnieMModel(config)
        # Two independent per-token scorers: one logit for "span starts here", one for "span ends here".
        self.linear_start = nn.Linear(config.hidden_size, 1)
        self.linear_end = nn.Linear(config.hidden_size, 1)
        # Not applied in forward() (the loss consumes raw logits); kept so the module attribute still exists.
        self.sigmoid = nn.Sigmoid()
        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(ERNIE_M_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple[torch.FloatTensor], QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for position (index) for computing the start_positions loss. Position outside of the sequence are
            not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for position (index) for computing the end_positions loss. Position outside of the sequence are not
            taken into account for computing the loss.
        """
        # Resolve an explicit `None` the same way the other ErnieM heads do, so the
        # output container always matches what the backbone was asked to return.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        result = self.ernie_m(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Index 0 is the sequence output for both the tuple and the ModelOutput return forms.
        sequence_output = result[0]

        start_logits = self.linear_start(sequence_output).squeeze(-1)
        end_logits = self.linear_end(sequence_output).squeeze(-1)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            # BCEWithLogitsLoss needs floating-point targets; the labels are documented as LongTensor,
            # so cast them to the logits dtype (a no-op for labels that already match).
            start_positions = start_positions.to(start_logits.dtype)
            end_positions = end_positions.to(end_logits.dtype)

            loss_fct = BCEWithLogitsLoss()
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            # `result` is a plain tuple here, so it has no `.hidden_states` / `.attentions` attributes.
            # Slice off the first two entries instead, as the sibling heads do with `outputs[2:]`.
            output = (start_logits, end_logits) + tuple(result[2:])
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=result.hidden_states,
            attentions=result.attentions,
        )
1048
+
1049
+
1050
+ __all__ = [
1051
+ "ErnieMForMultipleChoice",
1052
+ "ErnieMForQuestionAnswering",
1053
+ "ErnieMForSequenceClassification",
1054
+ "ErnieMForTokenClassification",
1055
+ "ErnieMModel",
1056
+ "ErnieMPreTrainedModel",
1057
+ "ErnieMForInformationExtraction",
1058
+ ]
docs/transformers/build/lib/transformers/models/deprecated/ernie_m/tokenization_ernie_m.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 Xuan Ouyang, Shuohuan Wang, Chao Pang, Yu Sun, Hao Tian, Hua Wu, Haifeng Wang and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for Ernie-M."""
16
+
17
+ import io
18
+ import os
19
+ import unicodedata
20
+ from typing import Any, Dict, List, Optional, Tuple
21
+
22
+ import sentencepiece as spm
23
+
24
+ from ....tokenization_utils import PreTrainedTokenizer
25
+ from ....utils import logging
26
+ from ....utils.import_utils import requires
27
+
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+ SPIECE_UNDERLINE = "▁"
32
+
33
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "sentencepiece_model_ckpt": "sentencepiece.bpe.model"}
34
+
35
+ RESOURCE_FILES_NAMES = {
36
+ "sentencepiece_model_file": "sentencepiece.bpe.model",
37
+ "vocab_file": "vocab.txt",
38
+ }
39
+
40
+
41
+ # Adapted from paddlenlp.transformers.ernie_m.tokenizer.ErnieMTokenizer
42
@requires(backends=("sentencepiece",))
class ErnieMTokenizer(PreTrainedTokenizer):
    r"""
    Constructs a Ernie-M tokenizer. It uses the `sentencepiece` tools to cut the words to sub-words.

    Args:
        sentencepiece_model_ckpt (`str`):
            The file path of the sentencepiece model.
        vocab_file (`str`, *optional*):
            The file path of the vocabulary. When omitted, the vocabulary is derived from the pieces of the
            sentencepiece model.
        do_lower_case (`bool`, *optional*, defaults to `False`):
            Whether or not to lowercase the input when tokenizing.
        encoding (`str`, *optional*, defaults to `"utf8"`):
            Forwarded unchanged to the base tokenizer constructor.
        unk_token (`str`, *optional*, defaults to `"[UNK]"`):
            A special token representing the `unknown (out-of-vocabulary)` token. An unknown token is set to be
            `unk_token` in order to be converted to an ID.
        sep_token (`str`, *optional*, defaults to `"[SEP]"`):
            A special token separating two different sentences in the same input.
        pad_token (`str`, *optional*, defaults to `"[PAD]"`):
            A special token used to make arrays of tokens the same size for batching purposes.
        cls_token (`str`, *optional*, defaults to `"[CLS]"`):
            A special token used for sequence classification. It is the first token of the sequence when built with
            special tokens.
        mask_token (`str`, *optional*, defaults to `"[MASK]"`):
            A special token representing a masked token. This is the token used in the masked language modeling task
            which the model tries to predict the original unmasked ones.
        sp_model_kwargs (`dict`, *optional*):
            Keyword arguments passed to the `SentencePieceProcessor` constructor. The keys `enable_sampling`,
            `nbest_size` and `alpha` are additionally read by `_tokenize`.
    """

    # Ernie-M model doesn't have token_type embedding.
    model_input_names: List[str] = ["input_ids"]

    vocab_files_names = VOCAB_FILES_NAMES
    resource_files_names = RESOURCE_FILES_NAMES

    # Character substitutions applied by `clean_text` and `get_offset_mapping`. Both methods read this
    # attribute, so it has to exist on the class. The table folds the full-width ASCII block
    # (U+FF01..U+FF5E) onto its half-width counterparts and leaves the full-width tilde (U+FF5E) as is.
    # NOTE(review): contents follow the PaddleNLP tokenizer this class is adapted from - confirm against it.
    SP_CHAR_MAPPING: Dict[str, str] = {
        chr(code): chr(code) if code == 0xFF5E else chr(code - 0xFEE0) for code in range(0xFF01, 0xFF5F)
    }

    def __init__(
        self,
        sentencepiece_model_ckpt,
        vocab_file=None,
        do_lower_case=False,
        encoding="utf8",
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        # Mask token behave like a normal word, i.e. include the space before it and
        # is included in the raw text, there should be a match in a non-normalized sentence.

        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs

        self.do_lower_case = do_lower_case
        self.sentencepiece_model_ckpt = sentencepiece_model_ckpt
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(sentencepiece_model_ckpt)

        # to mimic paddlenlp.transformers.ernie_m.tokenizer.ErnieMTokenizer functioning
        if vocab_file is not None:
            self.vocab = self.load_vocab(filepath=vocab_file)
        else:
            self.vocab = {self.sp_model.id_to_piece(id): id for id in range(self.sp_model.get_piece_size())}
        self.reverse_vocab = {v: k for k, v in self.vocab.items()}

        # The vocabulary must be in place before the base constructor runs.
        super().__init__(
            do_lower_case=do_lower_case,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            vocab_file=vocab_file,
            encoding=encoding,
            sp_model_kwargs=self.sp_model_kwargs,
            **kwargs,
        )

    def get_offset_mapping(self, text):
        """
        Map every token of `text` to a `(start, end)` character span in the original string.

        Args:
            text (`str`, *optional*): The raw text to tokenize.

        Returns:
            `List[Tuple[int, int]]`, or `None` when `text` is `None`.

        Raises:
            ValueError: If a token cannot be located in the normalized text.
        """
        if text is None:
            return None

        split_tokens = self.tokenize(text)

        # Build the normalized, whitespace-free text and remember, for each of its characters, the index
        # of the original character it came from (one normalized character may expand to several).
        normalized_chars, char_mapping = [], []
        for i, ch in enumerate(text):
            if ch in self.SP_CHAR_MAPPING:
                ch = self.SP_CHAR_MAPPING.get(ch)
            else:
                ch = unicodedata.normalize("NFKC", ch)
            if self.is_whitespace(ch):
                continue
            normalized_chars.append(ch)
            char_mapping.extend([i] * len(ch))

        text, token_mapping, offset = "".join(normalized_chars), [], 0

        if self.do_lower_case:
            text = text.lower()

        for token in split_tokens:
            if token[:1] == "▁":
                token = token[1:]
            # Search forward from the end of the previous token without re-slicing the text.
            start = text.index(token, offset)
            end = start + len(token)

            token_mapping.append((char_mapping[start], char_mapping[end - 1] + 1))
            offset = end
        return token_mapping

    @property
    def vocab_size(self):
        """Number of entries in the base vocabulary (added tokens excluded)."""
        return len(self.vocab)

    def get_vocab(self):
        """Return the base vocabulary merged with the added tokens as a new dict."""
        return dict(self.vocab, **self.added_tokens_encoder)

    def __getstate__(self):
        # The sentencepiece processor is dropped from the pickled state and rebuilt in `__setstate__`.
        state = self.__dict__.copy()
        state["sp_model"] = None
        return state

    def __setstate__(self, d):
        self.__dict__ = d

        # for backward compatibility
        if not hasattr(self, "sp_model_kwargs"):
            self.sp_model_kwargs = {}

        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(self.sentencepiece_model_ckpt)

    def clean_text(self, text):
        """Replace every character found in `SP_CHAR_MAPPING` with its mapped value; keep the others."""
        return "".join((self.SP_CHAR_MAPPING.get(c, c) for c in text))

    def _tokenize(self, text, enable_sampling=False, nbest_size=64, alpha=0.1):
        """
        Tokenize a string.

        The sentencepiece pieces are post-processed: CJK characters and punctuation are emitted as tokens of
        their own, and pieces are split at every boundary between digits and non-digits.
        """

        # Values found in `sp_model_kwargs` take precedence over the call arguments.
        if self.sp_model_kwargs.get("enable_sampling") is True:
            enable_sampling = True
        if self.sp_model_kwargs.get("alpha") is not None:
            alpha = self.sp_model_kwargs.get("alpha")
        if self.sp_model_kwargs.get("nbest_size") is not None:
            nbest_size = self.sp_model_kwargs.get("nbest_size")

        if not enable_sampling:
            pieces = self.sp_model.EncodeAsPieces(text)
        else:
            pieces = self.sp_model.SampleEncodeAsPieces(text, nbest_size, alpha)
        new_pieces = []
        for pi, piece in enumerate(pieces):
            if piece == SPIECE_UNDERLINE:
                # A bare underline is kept only when it sits between two pieces and the following piece
                # does not carry its own underline. The bounds check prevents an IndexError when the bare
                # underline is the last piece.
                has_next = pi + 1 < len(pieces)
                if pi != 0 and has_next and not pieces[pi + 1].startswith(SPIECE_UNDERLINE):
                    new_pieces.append(SPIECE_UNDERLINE)
                continue
            lst_i = 0  # start of the not-yet-emitted part of `piece`
            for i, chunk in enumerate(piece):
                if chunk == SPIECE_UNDERLINE:
                    continue
                if self.is_ch_char(chunk) or self.is_punct(chunk):
                    if i > lst_i and piece[lst_i:i] != SPIECE_UNDERLINE:
                        new_pieces.append(piece[lst_i:i])
                    new_pieces.append(chunk)
                    lst_i = i + 1
                elif chunk.isdigit() and i > 0 and not piece[i - 1].isdigit():
                    if i > lst_i and piece[lst_i:i] != SPIECE_UNDERLINE:
                        new_pieces.append(piece[lst_i:i])
                    lst_i = i
                elif not chunk.isdigit() and i > 0 and piece[i - 1].isdigit():
                    if i > lst_i and piece[lst_i:i] != SPIECE_UNDERLINE:
                        new_pieces.append(piece[lst_i:i])
                    lst_i = i
            if len(piece) > lst_i:
                new_pieces.append(piece[lst_i:])
        return new_pieces

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (strings for sub-words) in a single string."""
        out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
        return out_string

    def convert_ids_to_string(self, ids):
        """
        Converts a sequence of ids in a single string.
        """
        tokens = self.convert_ids_to_tokens(ids)
        out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
        return out_string

    # to mimic paddlenlp.transformers.ernie_m.tokenizer.ErnieMTokenizer functioning
    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab, falling back to the id of `unk_token`."""
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    # to mimic paddlenlp.transformers.ernie_m.tokenizer.ErnieMTokenizer functioning
    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.reverse_vocab.get(index, self.unk_token)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        r"""
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. An ErnieM sequence has the following format:

        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] [SEP] B [SEP]`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
        Returns:
            `List[int]`: List of input_id with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        _cls = [self.cls_token_id]
        _sep = [self.sep_token_id]
        return _cls + token_ids_0 + _sep + _sep + token_ids_1 + _sep

    def build_offset_mapping_with_special_tokens(self, offset_mapping_0, offset_mapping_1=None):
        r"""
        Build offset map from a pair of offset map by concatenating and adding offsets of special tokens. An Ernie-M
        offset_mapping has the following format:

        - single sequence: `(0,0) X (0,0)`
        - pair of sequences: `(0,0) A (0,0) (0,0) B (0,0)`

        Args:
            offset_mapping_0 (`List[tuple]`):
                List of char offsets to which the special tokens will be added.
            offset_mapping_1 (`List[tuple]`, *optional*):
                Optional second list of wordpiece offsets for offset mapping pairs.
        Returns:
            `List[tuple]`: List of wordpiece offsets with the appropriate offsets of special tokens.
        """
        if offset_mapping_1 is None:
            return [(0, 0)] + offset_mapping_0 + [(0, 0)]

        return [(0, 0)] + offset_mapping_0 + [(0, 0), (0, 0)] + offset_mapping_1 + [(0, 0)]

    def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
        r"""
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `encode` method.

        Args:
            token_ids_0 (`List[int]`):
                List of ids of the first sequence.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.
        Returns:
            `List[int]`:
                The list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        Raises:
            ValueError: If `already_has_special_tokens` is set and a second sequence is supplied.
        """

        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formatted with special tokens for the model."
                )
            return [1 if x in [self.sep_token_id, self.cls_token_id] else 0 for x in token_ids_0]

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create the token type IDs corresponding to the sequences passed. [What are token type
        IDs?](../glossary#token-type-ids) Should be overridden in a subclass if the model has a special way of
        building: those.

        Args:
            token_ids_0 (`List[int]`):
                The first tokenized sequence.
            token_ids_1 (`List[int]`, *optional*):
                The second tokenized sequence.
        Returns:
            `List[int]`: The token type ids.
        """
        # called when `add_special_tokens` is True, so align with `build_inputs_with_special_tokens` method
        if token_ids_1 is None:
            # [CLS] X [SEP]
            return (len(token_ids_0) + 2) * [0]

        # [CLS] A [SEP] [SEP] B [SEP]
        return [0] * (len(token_ids_0) + 1) + [1] * (len(token_ids_1) + 3)

    def is_ch_char(self, char):
        """
        Return `True` when `char` lies in the range U+4E00..U+9FFF.
        """
        if "\u4e00" <= char <= "\u9fff":
            return True
        return False

    def is_alpha(self, char):
        """
        Return `True` when `char` is an ASCII letter (`a`-`z` or `A`-`Z`).
        """
        if ("a" <= char <= "z") or ("A" <= char <= "Z"):
            return True
        return False

    def is_punct(self, char):
        """
        Return `True` when `char` occurs in the punctuation string below.
        """
        if char in ",;:.?!~,;:。?!《》【】":
            return True
        return False

    def is_whitespace(self, char):
        """
        Return `True` for space, tab, newline, carriage return, or a single character of Unicode category `Zs`.
        """
        if char == " " or char == "\t" or char == "\n" or char == "\r":
            return True
        if len(char) == 1:
            cat = unicodedata.category(char)
            if cat == "Zs":
                return True
        return False

    def load_vocab(self, filepath):
        """
        Read a vocabulary file with one token per line.

        Args:
            filepath (`str`): Path of the UTF-8 encoded vocabulary file.

        Returns:
            `Dict[str, int]`: Mapping from token to its zero-based line index.
        """
        token_to_idx = {}
        with io.open(filepath, "r", encoding="utf-8") as f:
            for index, line in enumerate(f):
                token = line.rstrip("\n")
                token_to_idx[token] = int(index)

        return token_to_idx

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Write the vocabulary (one token per line, ordered by id) and the serialized sentencepiece model.

        Args:
            save_directory (`str`): Target directory.
            filename_prefix (`str`, *optional*): Prefix prepended, followed by `-`, to the vocabulary file name.

        Returns:
            `Tuple[str]`: A one-element tuple holding the path of the written vocabulary file.
        """
        index = 0
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(
                save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
            )
        else:
            vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning(
                        f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
                        " Please check that the vocabulary is not corrupted!"
                    )
                    index = token_index
                writer.write(token + "\n")
                index += 1

        # NOTE(review): this path is always joined onto `save_directory`, even when the branch above
        # treated `save_directory` as a file path - confirm whether that case is meant to be supported.
        tokenizer_model_file = os.path.join(save_directory, "sentencepiece.bpe.model")
        with open(tokenizer_model_file, "wb") as fi:
            content_spiece_model = self.sp_model.serialized_model_proto()
            fi.write(content_spiece_model)

        return (vocab_file,)
408
+
409
+
410
+ __all__ = ["ErnieMTokenizer"]
docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/__init__.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ....utils import _LazyModule
17
+ from ....utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_gptsan_japanese import *
22
+ from .modeling_gptsan_japanese import *
23
+ from .tokenization_gptsan_japanese import *
24
+ else:
25
+ import sys
26
+
27
+ _file = globals()["__file__"]
28
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/configuration_gptsan_japanese.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023, HuggingFace Inc.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """GPTSAN-japanese model configuration"""
16
+
17
+ from ....configuration_utils import PretrainedConfig
18
+ from ....utils import logging
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
class GPTSanJapaneseConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`GPTSanJapaneseModel`]. It is used to instantiate
    a GPTSANJapanese model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the GPTSANJapanese
    [Tanrei/GPTSAN-japanese](https://huggingface.co/Tanrei/GPTSAN-japanese) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Arguments:
        vocab_size (`int`, *optional*, defaults to 36000):
            Vocabulary size of the GPTSANJapanese model. Defines the number of different tokens that can be represented
            by the `inputs_ids` passed when calling [`GPTSanJapaneseModel`].
        max_position_embeddings (`int`, *optional*, defaults to 1280):
            The maximum sequence length that this model might ever be used with. Defaults set this to 1280.
        d_model (`int`, *optional*, defaults to 1024):
            Size of the encoder layers and the pooler layer.
        d_ff (`int`, *optional*, defaults to 8192):
            Size of the intermediate feed forward layer in each `SwitchTransformersBlock`.
        d_ext (`int`, *optional*, defaults to 4096):
            Size of the intermediate feed forward layer in each Extra-layers.
        d_spout (`int`, *optional*, defaults to 128):
            Size of the `spout` vector.
        num_switch_layers (`int`, *optional*, defaults to 10):
            Number of layers in the Switch Transformer layer.
        num_ext_layers (`int`, *optional*, defaults to 0):
            Number of layers in the Extra-layers.
        num_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_experts (`int`, *optional*, defaults to 16):
            Number of experts for each SwitchTransformer layer.
        expert_capacity (`int`, *optional*, defaults to 128):
            Number of tokens that can be stored in each expert. If set to 1, the model will behave like a regular
            Transformer.
        dropout_rate (`float`, *optional*, defaults to 0.0):
            The ratio for all dropout layers.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
            The epsilon used by the layer normalization layers.
        router_bias (`bool`, *optional*, defaults to `False`):
            Whether to add a bias to the router.
        router_jitter_noise (`float`, *optional*, defaults to 0.0):
            Amount of noise to add to the router. Set it to 0.0 during prediction or set small value (usually 1e-2)
            during training.
        router_dtype (`str`, *optional*, defaults to `"float32"`):
            The `dtype` used for the routers. It is preferable to keep the `dtype` to `"float32"` as specified in the
            *selective precision* discussion in [the paper](https://arxiv.org/abs/2101.03961).
        router_ignore_padding_tokens (`bool`, *optional*, defaults to `False`):
            Whether to ignore padding tokens when routing.
        output_hidden_states (`bool`, *optional*, defaults to `False`):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        output_attentions (`bool`, *optional*, defaults to `False`):
            Whether or not to return the attentions tensors of all attention layers.
        initializer_factor (`float`, *optional*, defaults to 0.002):
            A factor for initializing all weight matrices.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether or not to return the router logits of all experts.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models)
        separator_token_id (`int`, *optional*, defaults to 35998):
            Token id forwarded to [`PretrainedConfig`] as `separator_token_id`.
        pad_token_id (`int`, *optional*, defaults to 35995):
            Token id forwarded to [`PretrainedConfig`] as `pad_token_id`.
        eos_token_id (`int`, *optional*, defaults to 35999):
            Token id forwarded to [`PretrainedConfig`] as `eos_token_id`.
    """

    model_type = "gptsan-japanese"
    keys_to_ignore_at_inference = [
        "past_key_values",
    ]
    # Generic attribute names resolved onto this configuration's own field names.
    attribute_map = {
        "hidden_size": "d_model",
        "num_attention_heads": "num_heads",
        "num_hidden_layers": "num_layers",
    }

    def __init__(
        self,
        vocab_size=36000,
        max_position_embeddings=1280,
        d_model=1024,
        d_ff=8192,
        d_ext=4096,
        d_spout=128,
        num_switch_layers=10,
        num_ext_layers=0,
        num_heads=16,
        num_experts=16,
        expert_capacity=128,
        dropout_rate=0.0,
        layer_norm_epsilon=1e-5,
        router_bias=False,
        router_jitter_noise=0.0,
        router_dtype="float32",
        router_ignore_padding_tokens=False,
        output_hidden_states=False,
        output_attentions=False,
        initializer_factor=0.002,
        output_router_logits=False,
        use_cache=True,
        separator_token_id=35998,
        pad_token_id=35995,
        eos_token_id=35999,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.d_model = d_model
        self.d_ff = d_ff
        self.d_ext = d_ext
        self.d_spout = d_spout
        self.num_switch_layers = num_switch_layers
        self.num_ext_layers = num_ext_layers
        # Total depth is derived, not configurable: switch layers plus extra layers.
        self.num_layers = num_switch_layers + num_ext_layers
        self.num_heads = num_heads
        self.num_experts = num_experts
        self.expert_capacity = expert_capacity
        self.dropout_rate = dropout_rate
        self.layer_norm_epsilon = layer_norm_epsilon
        self.router_bias = router_bias
        self.router_jitter_noise = router_jitter_noise
        self.router_dtype = router_dtype
        self.router_ignore_padding_tokens = router_ignore_padding_tokens
        self.initializer_factor = initializer_factor
        self.output_router_logits = output_router_logits
        self.use_cache = use_cache

        # `output_hidden_states` and `output_attentions` are owned by the base class, which initialises them
        # from its own keyword arguments. They must be forwarded; assigning them on `self` beforehand would
        # be overwritten by the base-class defaults.
        super().__init__(
            separator_token_id=separator_token_id,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            **kwargs,
        )
155
+
156
+
157
+ __all__ = ["GPTSanJapaneseConfig"]
docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/convert_gptsan_tf_checkpoint_to_pytorch.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Convert GPTSANJapanese checkpoints from the original repository to pytorch model."""
17
+
18
+ import argparse
19
+ import json
20
+ import os
21
+ from collections import OrderedDict
22
+
23
+ import numpy as np
24
+ import tensorflow as tf
25
+ import torch
26
+
27
+
28
def _linear_weight_from_kernel(kernel):
    """Build a tensor for a PyTorch ``weight`` entry from a 2-D checkpoint kernel.

    The two axes are swapped and the result is copied so the tensor is created from a contiguous array.
    """
    return torch.tensor(kernel.transpose([1, 0]).copy())


def convert_tf_gptsan_to_pt(args):
    """Convert a GPTSAN TensorFlow checkpoint directory into a PyTorch state-dict file.

    Every variable of the checkpoint is read, cast to float16, renamed to the key used by the PyTorch model and
    (for 2-D kernels) transposed. Variables whose names match none of the known patterns are ignored.

    Args:
        args: Namespace-like object exposing two string attributes:
            `tf_model_dir`, the directory holding the TensorFlow checkpoint and its `parameters.json`, and
            `output`, the destination path. `.pt` is appended (and written back to `args.output`) when missing.

    Raises:
        ValueError: If `parameters.json` parses to an empty value.
    """
    parameter_file = os.path.join(args.tf_model_dir, "parameters.json")
    # Use a context manager so the file handle is closed deterministically.
    with open(parameter_file) as parameter_handle:
        params = json.load(parameter_handle)
    if not params:
        raise ValueError(
            f"It seems that the json file at {parameter_file} is empty. Make sure you have a correct json file."
        )
    if not args.output.endswith(".pt"):
        args.output = args.output + ".pt"

    new_state = OrderedDict()
    with tf.device("/CPU:0"):
        reader = tf.train.load_checkpoint(args.tf_model_dir)
        for key_name in reader.get_variable_to_shape_map():
            # Optimizer slots are not model weights: skip them before paying for the tensor read.
            if key_name.endswith(("/adam_m", "/adam_v")):
                continue
            vnp = reader.get_tensor(key_name).astype(np.float16)

            if key_name.startswith("pasts/"):
                if key_name.startswith("pasts/mlp"):
                    player = int(key_name[9])
                elif key_name.startswith("pasts/out"):
                    player = 8
                else:
                    # Unknown "pasts/" variable: ignore it like any other unrecognised key instead of
                    # reusing a layer index left over from a previous iteration.
                    continue
                # Per the original note, `sqout` is an nn.Sequential interleaved with Tanh, hence index * 2.
                new_state[f"model.sqout.{player * 2}.weight"] = _linear_weight_from_kernel(vnp)
            elif key_name.startswith("model/moe"):
                player = int(key_name[9:].split("/")[0])
                if key_name.endswith("/switch_gating/kernel"):
                    name = f"model.blocks.{player}.feed_forward.mlp.router.classifier.weight"
                    new_state[name] = _linear_weight_from_kernel(vnp)
                elif key_name.endswith("/softmlp/kernel"):
                    name = f"model.blocks.{player}.feed_forward.soft_bypass_mlp.weight"
                    new_state[name] = _linear_weight_from_kernel(vnp)
                elif key_name.endswith(("/wo/kernel", "/wi/kernel")):
                    nlayer = key_name[-9:-7]  # "wo" or "wi"
                    # The checkpoint stores all experts in one array; the leading axis indexes the expert,
                    # so split along it rather than assuming a fixed number of experts.
                    for i in range(vnp.shape[0]):
                        name = f"model.blocks.{player}.feed_forward.mlp.experts.expert_{i}.{nlayer}.weight"
                        new_state[name] = _linear_weight_from_kernel(vnp[i])
            elif key_name.startswith("model/mlp"):
                player = int(key_name[9:].split("/")[0])
                if key_name.endswith("/p1/kernel"):
                    new_state[f"model.blocks.{player}.feed_forward.mlp.wi.weight"] = _linear_weight_from_kernel(vnp)
                elif key_name.endswith("/p1/bias"):
                    # One-dimensional, so no transposition is needed.
                    new_state[f"model.blocks.{player}.feed_forward.mlp.wi.bias"] = torch.tensor(vnp)
                elif key_name.endswith("/p2/kernel"):
                    new_state[f"model.blocks.{player}.feed_forward.mlp.wo.weight"] = _linear_weight_from_kernel(vnp)
                elif key_name.endswith("/p2/bias"):
                    new_state[f"model.blocks.{player}.feed_forward.mlp.wo.bias"] = torch.tensor(vnp)
            elif key_name.startswith("model/ln"):
                player = int(key_name[8:].split("/")[0])
                if key_name.endswith("/b"):
                    new_state[f"model.blocks.{player}.feed_forward.norm.bias"] = torch.tensor(vnp)
                elif key_name.endswith("/g"):
                    new_state[f"model.blocks.{player}.feed_forward.norm.weight"] = torch.tensor(vnp)
            elif key_name.startswith("model/att"):
                player = int(key_name[9:].split("/")[0])
                if key_name.endswith("/qkv/kernel"):
                    # Axis 1 selects q / k / v; the two trailing axes are merged before transposing.
                    for slot, projection in enumerate(("q_proj", "k_proj", "v_proj")):
                        part = vnp[:, slot, :, :]
                        part = part.reshape([part.shape[0], part.shape[1] * part.shape[2]])
                        name = f"model.blocks.{player}.self_attn.self_attn.{projection}.weight"
                        new_state[name] = _linear_weight_from_kernel(part)
                elif key_name.endswith("/o/kernel"):
                    # Here the two leading axes are merged before transposing.
                    merged = vnp.reshape([vnp.shape[0] * vnp.shape[1], vnp.shape[2]])
                    name = f"model.blocks.{player}.self_attn.self_attn.out_proj.weight"
                    new_state[name] = _linear_weight_from_kernel(merged)
            elif key_name.startswith("model/an"):
                player = int(key_name[8:].split("/")[0])
                if key_name.endswith("/b"):
                    new_state[f"model.blocks.{player}.self_attn.norm.bias"] = torch.tensor(vnp)
                elif key_name.endswith("/g"):
                    new_state[f"model.blocks.{player}.self_attn.norm.weight"] = torch.tensor(vnp)
            elif key_name.startswith(("model/wte", "model/wpe", "model/ete")):
                nlayer = {"wte": "embed_tokens", "wpe": "position_embeddings", "ete": "extra_position_embeddings"}[
                    key_name[-3:]
                ]
                # Embedding tables are copied as they are.
                new_state[f"model.{nlayer}.weight"] = torch.tensor(vnp)
                if key_name.startswith("model/wte"):
                    # The token embedding is also stored (as an independent copy) under the LM head key.
                    new_state["lm_head.weight"] = torch.tensor(vnp)
            elif key_name.startswith("model/wob"):
                new_state["final_logits_bias"] = torch.tensor(vnp.reshape((1, -1)))
            elif key_name == "model/dense/kernel":
                new_state["model.last_project.weight"] = _linear_weight_from_kernel(vnp)
            elif key_name == "model/dense_1/bias":
                new_state["model.last_project.bias"] = torch.tensor(vnp)
    torch.save(new_state, args.output)
172
+
173
+
174
if __name__ == "__main__":
    # Command-line entry point: both paths are mandatory and are handed to the converter unchanged.
    cli = argparse.ArgumentParser(
        description="model converter.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    for flag, help_text in (("--tf_model_dir", "import model"), ("--output", "output model")):
        cli.add_argument(flag, metavar="PATH", type=str, required=True, help=help_text)
    convert_tf_gptsan_to_pt(cli.parse_args())
docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py ADDED
@@ -0,0 +1,1337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 Toshiyuki Sakamoto(tanreinama) and HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch GPTSANJapanese model."""
16
+
17
+ import copy
18
+ from typing import List, Optional, Tuple, Union
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+
23
+ from ....activations import ACT2FN
24
+ from ....modeling_outputs import MoECausalLMOutputWithPast, MoEModelOutputWithPastAndCrossAttentions
25
+ from ....modeling_utils import PreTrainedModel
26
+ from ....utils import (
27
+ DUMMY_INPUTS,
28
+ DUMMY_MASK,
29
+ add_start_docstrings,
30
+ add_start_docstrings_to_model_forward,
31
+ is_torch_fx_proxy,
32
+ logging,
33
+ )
34
+ from .configuration_gptsan_japanese import GPTSanJapaneseConfig
35
+
36
+
37
+ logger = logging.get_logger(__name__)
38
+
39
+ _CONFIG_FOR_DOC = "GPTSanJapaneseConfig"
40
+ _CHECKPOINT_FOR_DOC = "Tanrei/GPTSAN-japanese"
41
+
42
+ ####################################################
43
+ # This dict contains ids and associated url
44
+ # for the pretrained weights provided with the models
45
+ ####################################################
46
+
47
+
48
def router_z_loss_func(router_logits: torch.Tensor) -> float:
    r"""
    Router z-loss, PyTorch implementation.

    Introduced in [Designing Effective Sparse Expert Models](https://arxiv.org/abs/2202.08906), this loss pushes
    the router logits towards small magnitudes to improve stability.

    Args:
        router_logits (`torch.Tensor`):
            Router logits with three dimensions: [batch_size, sequence_length, num_experts].

    Returns:
        The z-loss as a 0-dimensional tensor: the squared log-sum-exp over the expert axis, averaged over all
        tokens.
    """
    group_count, group_size, _ = router_logits.shape
    squared_log_z = torch.logsumexp(router_logits, dim=-1).pow(2)
    return squared_log_z.sum() / (group_count * group_size)
66
+
67
+
68
def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.Tensor) -> float:
    r"""
    Auxiliary load balancing loss of the Switch Transformer, PyTorch implementation.

    Implements equations (4) - (6) of the Switch Transformer paper (https://arxiv.org/abs/2101.03961); the loss
    penalizes routing that is too unevenly spread over the experts.

    Args:
        router_probs (`torch.Tensor`):
            Probability assigned to each expert per token. Shape: [batch_size, sequence_length, num_experts].
        expert_indices (`torch.Tensor`):
            Indices tensor of shape [batch_size, sequence_length] identifying the selected expert for a given
            token.

    Returns:
        The auxiliary loss.
    """
    expert_count = router_probs.shape[-1]

    # one_hot requires int64 indices
    indices = expert_indices.to(torch.int64)
    if indices.dim() == 2:
        indices = indices.unsqueeze(2)

    # One-hot encode, then collapse the second-to-last axis: was the token routed to a given expert at all?
    routed = torch.nn.functional.one_hot(indices, expert_count).max(dim=-2).values

    # mean is not defined for integer tensors, hence the float32 cast
    token_share = routed.to(torch.float32).mean(dim=-2)
    prob_share = router_probs.mean(dim=-2)
    return (token_share * prob_share).mean() * (expert_count**2)
105
+
106
+
107
class GPTSanJapaneseDenseActDense(nn.Module):
    """
    Two-layer feed-forward block shared by the Switch Transformer layers and the extra layers.

    GPTSAN can mix Switch Transformer layers with normal Transformer layers. This module serves as an expert in
    the former and as the FFN in the latter; `ext_layer` selects the variant. Experts use `config.d_ff`, RELU,
    dropout and no bias. Extra layers use `config.d_ext`, Swish, biases and no dropout.
    """

    def __init__(self, config: GPTSanJapaneseConfig, ext_layer=False):
        super().__init__()
        if ext_layer:
            inner_dim = config.d_ext
            activation_name = "swish"
        else:
            inner_dim = config.d_ff
            activation_name = "relu"
        self.wi = nn.Linear(config.d_model, inner_dim, bias=ext_layer)
        self.wo = nn.Linear(inner_dim, config.d_model, bias=ext_layer)
        # Extra layers use an identity in place of dropout.
        self.dropout = nn.Identity() if ext_layer else nn.Dropout(config.dropout_rate)
        self.act = ACT2FN[activation_name]

    def forward(self, hidden_states):
        r"""
        Args:
            hidden_states (`torch.Tensor`) :
                [num_groups, tokens_per_group, hidden_dim] inputs to send to experts.
        Returns:
            torch.Tensor[num_groups, tokens_per_group, hidden_dim]

        """
        activated = self.act(self.wi(hidden_states))
        return self.wo(self.dropout(activated))
139
+
140
+
141
class GPTSanJapaneseTop1Router(nn.Module):
    """
    Router using tokens choose top-1 experts assignment.

    This router uses the same mechanism as in Switch Transformer (https://arxiv.org/abs/2101.03961) and V-MoE
    (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Tokens are taken in sequence order and
    routed to their chosen expert until that expert's `expert_capacity` is reached. **There is no guarantee that
    each token is processed by an expert**, or that each expert receives at least one token.

    """

    def __init__(self, config: GPTSanJapaneseConfig):
        super().__init__()
        self.num_experts = config.num_experts
        self.expert_capacity = config.expert_capacity
        self.classifier = nn.Linear(config.hidden_size, self.num_experts, bias=config.router_bias)
        self.jitter_noise = config.router_jitter_noise
        self.ignore_padding_tokens = config.router_ignore_padding_tokens
        # `router_dtype` is the name of a torch dtype, e.g. "float32".
        self.dtype = getattr(torch, config.router_dtype)

    def _compute_router_probabilities(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Computes router probabilities from input hidden states.

        The caller's `hidden_states` tensor is never modified.

        Args:
            hidden_states (`torch.Tensor`):
                (batch_size, sequence_length, hidden_dim) from which router probabilities are computed.
        Returns:
            router_probabilities (`torch.Tensor`):
                Tensor of shape (batch_size, sequence_length, num_experts) corresponding to the probabilities for each
                token and expert. Used for routing tokens to experts.
            router_logits (`torch.Tensor`):
                Logits tensor of shape (batch_size, sequence_length, num_experts) corresponding to raw router logits.
                This is used later for computing router z-loss.
        """
        # float32 is used to ensure stability. See the discussion of "selective precision" in
        # https://arxiv.org/abs/2101.03961.
        # We also store the previous dtype to cast back the output to the previous dtype
        self.input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(self.dtype)

        if self.training and self.jitter_noise > 0:
            # Multiply the router input by uniform noise. This must be out-of-place: when the input already
            # has the router dtype, `.to()` returns the very same tensor, and an in-place multiply would leak
            # the noise into the hidden states the caller later sends to the experts.
            noise = torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
            hidden_states = hidden_states * noise

        # Shape: [num_groups, tokens_per_group, num_experts]
        self._cast_classifier()
        router_logits = self.classifier(hidden_states)

        # Apply Softmax and cast back to the original `dtype`
        router_probabilities = nn.functional.softmax(router_logits, dim=-1, dtype=self.dtype).to(self.input_dtype)
        return router_probabilities, router_logits

    def _cast_classifier(self):
        r"""
        `bitsandbytes` `Linear8bitLt` layers do not support manual casting. Therefore we detect them through their
        special attributes (`SCB` / `CB`) and only cast the classifier when neither is present.
        """
        if not (hasattr(self.classifier, "SCB") or hasattr(self.classifier, "CB")):
            self.classifier = self.classifier.to(self.dtype)

    def forward(self, hidden_states: torch.Tensor) -> Tuple:
        r"""
        Route every token to its top-1 expert, subject to the per-expert capacity.

        The router probabilities and logits are computed from `hidden_states`; each token picks the expert with
        the highest probability. Within each sequence, once an expert has received `expert_capacity` tokens, any
        further token choosing it is masked out (its row in the returned mask is all zeros).

        Args:
            hidden_states (`torch.Tensor`) :
                [num_groups, tokens_per_group, hidden_dim] inputs to send to experts.
        Returns:
            Tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`] Tuple containing the one-hot expert mask of shape
            [num_groups, tokens_per_group, num_experts], the probability of the selected expert of shape
            [num_groups, tokens_per_group, 1], and the router logits. The router probabilities and logits are
            required to compute the loss.
        """
        router_probs, router_logits = self._compute_router_probabilities(hidden_states)

        expert_index = torch.argmax(router_probs, dim=-1)
        expert_index = torch.nn.functional.one_hot(expert_index, num_classes=self.num_experts)

        # Mask tokens outside expert capacity. Sum over each sequence
        token_priority = torch.cumsum(expert_index, dim=-2)
        # mask if the token routed to the expert will overflow
        expert_capacity_mask = token_priority <= self.expert_capacity
        expert_index = expert_index * expert_capacity_mask

        router_probs = torch.max(router_probs, dim=-1).values.unsqueeze(-1)
        return expert_index, router_probs, router_logits
232
+
233
+
234
class GPTSanJapaneseSparseMLP(nn.Module):
    r"""
    Sparse (mixture-of-experts) MLP module of the Switch Transformers: one top-1 router plus `config.num_experts`
    expert networks stored under the keys `expert_0`, `expert_1`, ...
    """

    def __init__(self, config: GPTSanJapaneseConfig, expert_class: nn.Module = GPTSanJapaneseDenseActDense):
        super().__init__()
        # The router is created first, then the experts in index order.
        self.router = GPTSanJapaneseTop1Router(config)
        self.experts = nn.ModuleDict()
        for expert_id in range(config.num_experts):
            self.experts[f"expert_{expert_id}"] = expert_class(config)

    def forward(self, hidden_states):
        r"""
        Run the mixture-of-experts layer:

        1- Ask the router for the `router_mask`, of shape `(batch_size, sequence_length, num_expert)`, together with
        the probability of the selected expert and the raw router logits. The probabilities are broadcast onto the
        hidden states and act as a scaling factor.

        2- Dispatch the tokens: loop over the experts and let each one process the hidden states of the tokens
        assigned to it.

        Returns the scaled hidden states and the tuple `(router_logits, expert_index)`.
        """
        router_mask, router_probs, router_logits = self.router(hidden_states)
        expert_index = torch.argmax(router_mask, dim=-1)

        # The router may leave some tokens unassigned; those keep their incoming hidden state. Hence we start
        # from a clone and overwrite only the positions an expert actually processes.
        dispatch = router_mask.bool()
        next_states = hidden_states.clone()
        for position, expert in enumerate(self.experts.values()):
            selected = dispatch[:, :, position]
            next_states[selected] = expert(hidden_states[selected]).to(next_states.dtype)

        scaled_states = router_probs * next_states
        return scaled_states, (router_logits, expert_index)
275
+
276
+
277
class GPTSanJapaneseLayerSparseFF(nn.Module):
    r"""
    Feed-forward layer of the Switch Transformers: wraps the Mixture of Experts module, adds a tanh-activated
    linear bypass to its output, normalizes the sum and adds it to the layer input.

    Parameters:
        config : ([`GPTSanJapaneseConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
    """

    def __init__(self, config: GPTSanJapaneseConfig):
        super().__init__()
        self.mlp = GPTSanJapaneseSparseMLP(config)
        self.soft_bypass_mlp = nn.Linear(config.d_model, config.d_model, bias=False)
        self.norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)

    def forward(self, hidden_states, output_router_logits):
        r"""
        Args:
            hidden_states (`torch.Tensor`) :
                [num_groups, tokens_per_group, hidden_dim] inputs to send to experts.
            output_router_logits (`bool`) :
                whether to also return the router output of the experts.
        Returns:
            torch.Tensor[num_groups, tokens_per_group, hidden_dim], followed by the router tuple when
            `output_router_logits` is set.

        """
        expert_states, router_tuple = self.mlp(hidden_states)
        bypass = torch.tanh(self.soft_bypass_mlp(hidden_states))
        output = hidden_states + self.norm(expert_states + bypass)

        if not output_router_logits or router_tuple is None:
            return output
        return output, router_tuple
312
+
313
+
314
class GPTSanJapaneseLayerDenseFF(nn.Module):
    r"""
    Feed-forward layer of the extra (dense) Transformer layers: a dense MLP whose normalized output is added to
    the layer input.

    Parameters:
        config : ([`GPTSanJapaneseConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
    """

    def __init__(self, config: GPTSanJapaneseConfig):
        super().__init__()
        # Always the dense ("extra layer") variant of the MLP.
        self.mlp = GPTSanJapaneseDenseActDense(config, ext_layer=True)
        self.norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)

    def forward(self, hidden_states):
        r"""
        Args:
            hidden_states (`torch.Tensor`) :
                [num_groups, tokens_per_group, hidden_dim] inputs to the dense MLP.
        Returns:
            torch.Tensor[num_groups, tokens_per_group, hidden_dim]

        """
        normalized = self.norm(self.mlp(hidden_states))
        return hidden_states + normalized
342
+
343
+
344
class GPTSanJapaneseAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[GPTSanJapaneseConfig] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        if self.embed_dim != self.head_dim * num_heads:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal

        # Creation order (k, v, q, out) is kept so random initialization stays reproducible.
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # (bsz, seq_len, embed_dim) -> (bsz, num_heads, seq_len, head_dim)
        per_head = tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        # Passing `key_value_states` turns this layer into a cross-attention layer for the decoder.
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # Queries are scaled right after projection.
        query_states = self.q_proj(hidden_states) * self.scaling

        # Cached cross-attention keys/values can be reused only when their sequence length matches the
        # provided `key_value_states` (this check is what supports prefix tuning).
        reuse_cached_cross = (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        )
        if reuse_cached_cross:
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        else:
            kv_input = key_value_states if is_cross_attention else hidden_states
            key_states = self._shape(self.k_proj(kv_input), -1, bsz)
            value_states = self._shape(self.v_proj(kv_input), -1, bsz)
            if not is_cross_attention and past_key_value is not None:
                # Self-attention with a cache: prepend the cached keys/values along the sequence axis.
                key_states = torch.cat([past_key_value[0], key_states], dim=2)
                value_states = torch.cat([past_key_value[1], value_states], dim=2)

        if self.is_decoder:
            # Decoder layers hand the (possibly extended) keys/values back as the new cache:
            # - for cross-attention, later calls can reuse them as they are;
            # - for uni-directional self-attention, later calls concatenate their new keys/values to them.
            # For encoder bi-directional self-attention `past_key_value` is always `None`.
            past_key_value = (key_states, value_states)

        flat_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*flat_shape)
        key_states = key_states.reshape(*flat_shape)
        value_states = value_states.reshape(*flat_shape)

        src_len = key_states.size(1)
        flat_weights_shape = (bsz * self.num_heads, tgt_len, src_len)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != flat_weights_shape:
            raise ValueError(
                f"Attention weights should be of size {flat_weights_shape}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            mask_shape = (bsz, 1, tgt_len, src_len)
            if attention_mask.size() != mask_shape:
                raise ValueError(
                    f"Attention mask should be of size {mask_shape}, but is {attention_mask.size()}"
                )
            # The mask is additive and broadcast over the head axis.
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(*flat_weights_shape)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            head_mask_shape = (self.num_heads,)
            if layer_head_mask.size() != head_mask_shape:
                raise ValueError(
                    f"Head mask for a single layer should be of size {head_mask_shape}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(*flat_weights_shape)

        attn_weights_reshaped = None
        if output_attentions:
            # Slightly awkward, but needed so that the returned weights keep their gradient:
            # reshape to the per-head layout, then derive the flat view used below from that tensor.
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(*flat_weights_shape)

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        output_shape = (bsz * self.num_heads, tgt_len, self.head_dim)
        if attn_output.size() != output_shape:
            raise ValueError(
                f"`attn_output` should be of size {output_shape}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim).transpose(1, 2)

        # Use the `embed_dim` stored on the module rather than the size of `hidden_states`, because
        # `attn_output` can be partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
500
+
501
+
502
class GPTSanJapaneseLayerSelfAttention(nn.Module):
    """
    Self Attention and Normalization Unit
    """

    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.self_attn = GPTSanJapaneseAttention(
            embed_dim=config.d_model,
            num_heads=config.num_heads,
            is_decoder=True,
            bias=has_relative_attention_bias,
        )
        self.norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        r"""
        Self-attention and normalize block.

        Args:
            hidden_states (`torch.FloatTensor`):
                Input to the self-attention layer. It is also used as the residual branch: the normalized attention
                output is added back onto it.
            past_key_value (`tuple(torch.Tensor)`, *optional*):
                Cached states of this layer. Only the first two entries (key states, value states) are forwarded to
                the attention module.
            attention_mask (`torch.FloatTensor`, *optional*):
                Multiplicative mask, converted here into an additive mask before it is handed to the attention
                module. When `None`, no mask is applied. Mask values selected in `[0, 1]`:

                - 1 for positions that are **not masked**,
                - 0 for positions that are **masked**.

            head_mask (`torch.FloatTensor`, *optional*):
                Mask to nullify selected heads of the attention module; forwarded as `layer_head_mask`. Mask values
                selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            use_cache (`bool`, *optional*):
                If set to `True`, the key/value states returned by the attention module are included in the output.
            output_attentions (`bool`, *optional*):
                Whether or not to include the attention weights in the output.
        Returns:
            `tuple`: `(hidden,)`, followed by the present key/value states when `use_cache` is set, followed by the
            attention weights when `output_attentions` is set.
        """
        # The cached self-attention key/value pair lives in the first two entries of `past_key_value`.
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None

        # Turn the {0, 1} mask into an additive one (0 where attention is allowed, dtype-min where it is blocked).
        # `attention_mask` is optional in the signature, so `None` is passed through untouched instead of failing
        # on `1 - None`; the attention module skips masking when it receives `None`.
        if attention_mask is None:
            additive_attention_mask = None
        else:
            additive_attention_mask = (1 - attention_mask) * torch.finfo(hidden_states.dtype).min

        atten_out = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=self_attn_past_key_value,
            attention_mask=additive_attention_mask,
            layer_head_mask=head_mask,
            output_attentions=output_attentions,
        )
        attn_weights = (atten_out[1],) if output_attentions else ()

        attention_output = atten_out[0]

        # Residual connection around the layer-normalized attention output.
        hidden = hidden_states + self.norm(attention_output)

        if use_cache:
            outputs = (hidden, atten_out[2])  # hidden, present, (attentions)
        else:
            outputs = (hidden,)  # hidden, (attentions)

        return outputs + attn_weights
586
+
587
+
588
class GPTSanJapaneseBlock(nn.Module):
    """
    One GPTSAN transformer layer: a self-attention unit followed by a feed-forward unit.

    The feed-forward unit is the dense variant when `ext_layer` is set and the sparse variant otherwise.
    """

    def __init__(self, config, ext_layer=False):
        super().__init__()
        self.self_attn = GPTSanJapaneseLayerSelfAttention(config)
        if ext_layer:
            self.feed_forward = GPTSanJapaneseLayerDenseFF(config)
        else:
            self.feed_forward = GPTSanJapaneseLayerSparseFF(config)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        output_router_tuple: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        r"""
        Run self-attention, then the feed-forward unit.

        Args:
            hidden_states (`torch.FloatTensor`):
                Input of the layer; forwarded to the self-attention unit.
            past_key_value (`tuple(torch.Tensor)`, *optional*):
                Cached key/value states of this layer; forwarded to the self-attention unit.
            attention_mask (`torch.FloatTensor`, *optional*):
                Mask forwarded to the self-attention unit. Mask values selected in `[0, 1]`:

                - 1 for positions that are **not masked**,
                - 0 for positions that are **masked**.

            head_mask (`torch.FloatTensor`, *optional*):
                Mask to nullify selected attention heads; forwarded to the self-attention unit. Mask values selected
                in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            use_cache (`bool`, *optional*):
                Forwarded to the self-attention unit, which then also returns its present key/value states.
            output_attentions (`bool`, *optional*):
                Forwarded to the self-attention unit, which then also returns its attention weights.
            output_router_tuple (`bool`, *optional*):
                Only honoured by the sparse feed-forward unit: when set, the router tuple it returns is appended as
                the last element of the output.
        Returns:
            `tuple`: `(hidden,)` followed by every extra element the self-attention unit returned, followed by the
            router tuple when the feed-forward unit is sparse and `output_router_tuple` is set.
        """
        attn_results = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attended = attn_results[0]
        attn_extras = attn_results[1:]  # present key/values and/or attention weights, if any

        # Dense feed-forward: a single tensor comes back and there is never a router tuple.
        if not isinstance(self.feed_forward, GPTSanJapaneseLayerSparseFF):
            return (self.feed_forward(attended),) + attn_extras

        ff_result = self.feed_forward(attended, output_router_tuple)

        # Sparse feed-forward without router output: the result is the hidden state itself.
        if not output_router_tuple:
            return (ff_result,) + attn_extras

        # Sparse feed-forward with router output: unpack and append the router tuple last.
        ff_hidden, routing = ff_result
        return (ff_hidden,) + attn_extras + (routing,)
668
+
669
+
670
class GPTSanJapanesePreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = GPTSanJapaneseConfig
    base_model_prefix = "gptsan_japanese"
    supports_gradient_checkpointing = False
    _no_split_modules = ["GPTSanJapaneseBlock"]
    _skip_keys_device_placement = "past_key_values"

    @property
    def dummy_inputs(self):
        """Return a dict with `input_ids` and `attention_mask` tensors built from the module-level dummy constants."""
        input_ids = torch.tensor(DUMMY_INPUTS)
        input_mask = torch.tensor(DUMMY_MASK)
        dummy_inputs = {
            "input_ids": input_ids,
            "attention_mask": input_mask,
        }
        return dummy_inputs

    def _init_weights(self, module):
        """
        Initialize the weights of `module` in place.

        The branch is selected by the type of `module`; every standard deviation is scaled by
        `config.initializer_factor`. Modules of any other type are left untouched.
        """
        factor = self.config.initializer_factor  # Used for testing weights initialization
        d_model = self.config.d_model

        if isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(factor * 1.0)
            module.bias.data.zero_()
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
            if hasattr(module, "bias") and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=factor * 1.0)
        elif isinstance(module, GPTSanJapaneseModel):
            # Mesh TensorFlow embeddings initialization
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
            module.embed_tokens.weight.data.normal_(mean=0.0, std=factor * 1.0)
            module.position_embeddings.weight.data.normal_(mean=0.0, std=factor * 1.0)
            if hasattr(module, "extra_position_embeddings") and module.extra_position_embeddings is not None:
                module.extra_position_embeddings.weight.data.normal_(mean=0.0, std=factor * 1.0)
        elif isinstance(module, GPTSanJapaneseForConditionalGeneration):
            # `GPTSanJapaneseModel` instances are always handled by the branch above, so only the
            # language-modeling wrapper can reach this one.
            # Mesh TensorFlow embeddings initialization
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
            module.final_logits_bias.data.normal_(mean=0.0, std=factor * 1.0)
            if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
                module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
        elif isinstance(module, GPTSanJapaneseDenseActDense):
            # Mesh TensorFlow FF initialization
            # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
            # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
            module.wi.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
            if hasattr(module.wi, "bias") and module.wi.bias is not None:
                module.wi.bias.data.zero_()
            module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
            if hasattr(module.wo, "bias") and module.wo.bias is not None:
                module.wo.bias.data.zero_()
        elif isinstance(module, GPTSanJapaneseAttention):
            # Multi-headed attention
            # Mesh TensorFlow attention initialization to avoid scaling before softmax
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
            key_value_proj_dim = self.config.d_model
            n_heads = self.config.num_heads
            qkv_std = factor * ((d_model * key_value_proj_dim) ** -0.5)
            module.k_proj.weight.data.normal_(mean=0.0, std=qkv_std)
            module.v_proj.weight.data.normal_(mean=0.0, std=qkv_std)
            module.q_proj.weight.data.normal_(mean=0.0, std=qkv_std)
            module.out_proj.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
        elif isinstance(module, GPTSanJapaneseSparseMLP):
            # Router classifier plus the input/output projections of every expert.
            module.router.classifier.weight.data.normal_(mean=0.0, std=factor * 1)
            for idx in range(self.config.num_experts):
                expert = module.experts[f"expert_{idx}"]
                expert.wi.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
                expert.wo.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))

    def _shift_right(self, input_ids):
        """
        Return a copy of `input_ids` shifted one position to the right along the last dimension.

        The first position is filled with `config.decoder_start_token_id`, the last input token is dropped, and any
        `-100` left in the result is replaced by `config.pad_token_id`.

        Raises:
            ValueError: if `config.decoder_start_token_id` or `config.pad_token_id` is `None`.
        """
        decoder_start_token_id = self.config.decoder_start_token_id
        pad_token_id = self.config.pad_token_id

        # Validate both ids before doing any tensor work.
        if decoder_start_token_id is None:
            raise ValueError(
                "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. "
                "See T5 docs for more information."
            )
        if pad_token_id is None:
            raise ValueError("self.model.config.pad_token_id has to be defined.")

        # shift inputs to the right
        if is_torch_fx_proxy(input_ids):
            # Item assignment is not supported natively for proxies.
            shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
            shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
        else:
            shifted_input_ids = input_ids.new_zeros(input_ids.shape)
            shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
            shifted_input_ids[..., 0] = decoder_start_token_id

        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        return shifted_input_ids
773
+
774
+
775
# Shared introduction that is prepended to the docstrings of the GPTSAN-japanese model classes.
GPTSAN_JAPANESE_START_DOCSTRING = r"""

    The [GPTSAN-japanese](https://github.com/tanreinama/GPTSAN) model was proposed in General-purpose Switch transformer
    based Japanese language model.

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`GPTSanJapaneseConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
789
+
790
# Shared description of the `forward` arguments, attached to the model `forward` methods via
# `add_start_docstrings_to_model_forward`.
GPTSAN_JAPANESE_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. GPTSAN-japanese is a model that generates sentence
            continuations or predicts tokens at mask positions. Special tokens required for inputs to the model are
            automatically appended.
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            An input that masks the Prefix part in the Prefix-LM input. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **prefix** input,
            - 0 for tokens that are **not-prefix** input.
        spout (`torch.Tensor` of shape `(batch_size, config.d_spout)`):
            This vector is transformed through an 8-layer FFN and can be used instead of `past_key_values`.
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix. Note that [`GPTSanJapaneseModel`] currently raises
            `NotImplementedError` when this argument is provided; pass `input_ids` instead.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
            representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
            input (see `past_key_values`). This is useful if you want more control over how to convert
            `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        output_router_logits (`bool`, *optional*):
            Whether or not to return the router logits of the switch (sparse) layers. See `router_logits` below.
        router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
            Router logits of the decoder model, useful to compute the auxiliary loss for Mixture of Experts models.
"""
842
+
843
+
844
@add_start_docstrings(
    "The bare GPTSAN-japanese Model transformer outputting raw hidden-states without any specific head on top.",
    GPTSAN_JAPANESE_START_DOCSTRING,
)
class GPTSanJapaneseModel(GPTSanJapanesePreTrainedModel):
    # Token and position embeddings, `num_switch_layers` sparse blocks followed by `num_ext_layers` dense blocks,
    # an optional `spout` projection stack, and a final linear projection with a swish activation.
    def __init__(self, config: GPTSanJapaneseConfig):
        super().__init__(config)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.d_model)
        self.config = copy.deepcopy(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        self.last_project = nn.Linear(config.d_model, config.d_model, bias=True)
        self.act = ACT2FN["swish"]

        self.blocks = torch.nn.ModuleList([])
        for _ in range(config.num_switch_layers):
            self.blocks.append(GPTSanJapaneseBlock(config))
        for _ in range(config.num_ext_layers):
            self.blocks.append(GPTSanJapaneseBlock(config, ext_layer=True))

        if config.num_ext_layers > 0:
            self.extra_position_embeddings = nn.Embedding(config.max_position_embeddings, config.d_model)

        if config.d_spout:
            # Eight (Linear, Tanh) pairs followed by a projection to one key and one value vector per layer.
            spouts = []
            for _ in range(8):
                spouts.append(nn.Linear(config.d_spout, config.d_spout, bias=False))
                spouts.append(nn.Tanh())
            spouts.append(nn.Linear(config.d_spout, config.num_layers * 2 * config.d_model, bias=False))
            self.spout = nn.Sequential(*spouts)

        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, new_embeddings):
        self.embed_tokens = new_embeddings

    @add_start_docstrings_to_model_forward(GPTSAN_JAPANESE_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.FloatTensor] = None,
        spout: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        num_precontext: Optional[torch.LongTensor] = None,
    ) -> Union[MoEModelOutputWithPastAndCrossAttentions, Tuple[torch.FloatTensor]]:
        r"""
        num_precontext (`torch.LongTensor` of shape `(batch_size,1)`):
            length of `hybrid` input tokens in the input. Tokens up to this length refer to both front and back like
            BERT, tokens after that refer only to front like GPT. see also:
            https://github.com/tanreinama/GPTSAN/blob/main/report/model.md

        Returns:
            `MoEModelOutputWithPastAndCrossAttentions` or `tuple` if `return_dict` returns
            MoEModelOutputWithPastAndCrossAttentions instead of tuple
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        device = self.position_embeddings.weight.device
        if input_ids is None:
            input_ids = torch.zeros([1, 1]).int().to(device)  # dummy for input_ids was None
        if inputs_embeds is not None:
            raise NotImplementedError(
                "GPTSanJapaneseModel does not use `inputs_embeds`. Make sure to pass in `input_ids` instead."
            )
        num_pasts_contexts = 0
        num_batch = input_ids.shape[0]
        pasts_or_spout_value = None
        if past_key_values is not None:
            num_pasts_contexts = past_key_values[0][0].shape[2]
        elif self.config.d_spout and spout is not None:
            # `spout` is a special input vector specific to GPTSAN
            # This controls the output by projecting embedded information such as the class of sentences during learning.
            # It should passed instead of the first past_key_value.
            # See the original GPTSAN repository for details
            num_pasts_contexts += 1

        # If there is an attention_mask, increase first one for spout
        if self.config.d_spout and spout is not None and attention_mask is not None:
            attention_mask_with_spout = torch.ones(num_batch, attention_mask.shape[1] + 1, device=device)
            attention_mask_with_spout[:, 1:] -= 1 - attention_mask  # 1st token should be spout
            attention_mask = attention_mask_with_spout  # update attention_mask

        if num_precontext is not None:
            # `num_precontext` is the number of tokens that refer to each other in prefix-lm
            # created per batch, so dimension of num_precontext should be [batch, 1]
            if not (
                len(num_precontext.shape) == 2 and num_precontext.shape[1] == 1
            ):  # num_precontext Should be [batch,1]
                raise ValueError("num_precontext should be [batch, 1] size.")
            num_precontext = torch.reshape(num_precontext, [-1])
        else:
            num_precontext = torch.zeros([num_batch]).int().to(device)

        num_input_contexts = input_ids.shape[1]
        num_output_contexts = num_input_contexts + num_pasts_contexts

        hidden_states = self.embed_tokens(input_ids)

        if past_key_values is not None:
            pasts_or_spout_value = past_key_values
        elif self.config.d_spout and spout is not None:
            # Make vector from `spout` of GPTSAN to the same shape as past_key_values
            pasts_or_spout_value = self.spout(spout)  # projecting `spout` vector
            pasts_or_spout_value = torch.reshape(
                pasts_or_spout_value,
                [
                    num_batch,
                    self.config.num_layers,
                    2,
                    self.config.num_heads,
                    num_pasts_contexts,
                    self.config.d_model // self.config.num_heads,
                ],
            )
            pasts_or_spout_value = torch.split(pasts_or_spout_value, [1] * self.config.num_layers, dim=1)
            # make same shape as past_key_values
            pasts_or_spout_value = tuple(
                tuple([b.squeeze(1) for b in torch.split(a.squeeze(1), [1, 1], dim=1)]) for a in pasts_or_spout_value
            )
        else:
            pasts_or_spout_value = [None] * self.config.num_layers

        # Token position considering spout and pasts
        token_position = torch.arange(num_input_contexts).to(device) + num_pasts_contexts

        if attention_mask is None:
            attention_mask = torch.ones(num_batch, num_input_contexts, device=device)

        # positions for get position_embeddings
        gather_position = (
            (
                torch.zeros((num_batch, self.config.d_model, num_input_contexts)).to(device)
                + token_position.unsqueeze(0)
            )
            .transpose(1, 2)
            .long()
        )
        # When padding with padding_side="left", zeros line up on the left side of attention_mask, so position_embeddings is shifted accordingly
        gather_position -= (1 - attention_mask).argmin(dim=-1).unsqueeze(1).unsqueeze(2)
        gather_position = torch.clip(gather_position, num_pasts_contexts, self.config.max_position_embeddings - 1)

        # attention_mask is applied per batch
        for i in range(num_batch):
            hidden_states[i] += torch.gather(self.position_embeddings.weight, dim=0, index=gather_position[i])

        # Create a mask to be used when making the prefix Input length of Prefix-LM variable
        causal_mask = (
            torch.tril(torch.ones((num_output_contexts, num_output_contexts), dtype=torch.uint8))
            .view(1, 1, num_output_contexts, num_output_contexts)
            .to(device)
        )
        prefix_lm_mask = causal_mask[:, :, -num_input_contexts:, :]
        if token_type_ids is not None:
            token_type_ids = token_type_ids.unsqueeze(1).unsqueeze(2)
            prefix_lm_mask = ((prefix_lm_mask + token_type_ids) > 0).float()
        # Merge prefix_lm_mask and attention_mask
        extended_attention_mask = prefix_lm_mask * attention_mask.unsqueeze(1).unsqueeze(2)

        # Prepare head mask if needed
        if head_mask is not None:
            head_mask = self.get_head_mask(
                head_mask, self.config.num_switch_layers + self.config.num_ext_layers
            )  # n_layer x batch x n_heads x N x N
            # NOTE(review): the whole `head_mask` is handed to every block below rather than a per-layer slice —
            # confirm this is compatible with the shape the attention layer expects for `layer_head_mask`.

        # Each output flag is enabled by either the config or the call argument.
        want_cache = self.config.use_cache or use_cache
        want_hidden_states = self.config.output_hidden_states or output_hidden_states
        want_attentions = self.config.output_attentions or output_attentions
        want_router_logits = self.config.output_router_logits or output_router_logits

        # outputs
        present_key_value_states = () if want_cache else None
        all_hidden_states = () if want_hidden_states else None
        all_attentions = () if want_attentions else None
        all_router_probs = () if want_router_logits else None

        for layer, past in enumerate(pasts_or_spout_value):
            if layer == self.config.num_switch_layers:
                if self.config.num_ext_layers > 0:
                    # extra_position_embeddings are extra position embeddings that are only created when extending the model with code from the original GPTSAN repository. Not used in the default model.
                    # However, it is created when you create an additional layer and partially train only that location.
                    # Therefore, convert_gptsan_tf_checkpoint_to_pytorch.py is used when converting and loading models created in the original GPTSAN repository.
                    for i in range(num_batch):
                        hidden_states[i] += torch.gather(
                            self.extra_position_embeddings.weight, dim=0, index=gather_position[i]
                        )

            # Only the switch (sparse) layers have a router to report.
            output_router_tuple = want_router_logits and layer < self.config.num_switch_layers
            block_output = self.blocks[layer](
                hidden_states=hidden_states,
                past_key_value=past,
                attention_mask=extended_attention_mask,
                head_mask=head_mask,
                use_cache=want_cache,
                output_attentions=want_attentions,
                output_router_tuple=output_router_tuple,
            )

            # The block returns (hidden, [present], [attentions], [router_tuple]); walk it in that order.
            outpos = 0
            hidden_states = block_output[outpos]
            if want_hidden_states:
                all_hidden_states += (hidden_states,)
            if want_cache:
                outpos += 1
                present = block_output[outpos]
                present_key_value_states += (present,)
            if want_attentions:
                outpos += 1
                attention_probs = block_output[outpos]
                all_attentions += (attention_probs,)
            if output_router_tuple:
                outpos += 1
                router_tuple = block_output[outpos]
                # `all_router_probs` is a tuple, which has no `append`; extend it by concatenation like the
                # other accumulators.
                all_router_probs += (router_tuple[0],)

        hidden_states = self.last_project(hidden_states)
        hidden_states = self.act(hidden_states)

        if want_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    present_key_value_states,
                    all_hidden_states,
                    all_attentions,
                    all_router_probs,
                ]
                if v is not None
            )

        return MoEModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=present_key_value_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            router_probs=all_router_probs,
        )
1091
+
1092
+
1093
@add_start_docstrings(
    "The bare GPTSAN-japanese Model with a language modeling head.",
    GPTSAN_JAPANESE_START_DOCSTRING,
)
class GPTSanJapaneseForConditionalGeneration(GPTSanJapanesePreTrainedModel):
    # The LM head shares its weight matrix with the input embedding (see `__init__`).
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: GPTSanJapaneseConfig):
        """Build the backbone model, the additive logits bias buffer and the (tied) LM head."""
        super().__init__(config)
        self.model = GPTSanJapaneseModel(config)
        self.register_buffer("final_logits_bias", torch.zeros([1, config.vocab_size]))
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        if not self.config.torchscript:
            # Tie output projection to the token embedding unless exporting to TorchScript.
            self.lm_head.weight = self.model.embed_tokens.weight

    @add_start_docstrings_to_model_forward(GPTSAN_JAPANESE_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.FloatTensor] = None,
        spout: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple[torch.FloatTensor], MoECausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the language modeling loss. The logits and the labels are flattened and compared
            position by position, so one label per token is expected. Indices should be in `[-100, 0, ...,
            config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
            labels in `[0, ..., config.vocab_size - 1]`

        Returns:
            `MoECausalLMOutputWithPast` or `tuple` if `return_dict` returns MoECausalLMOutputWithPast instead of tuple

        Example:

        Text Generation with regular LM Model
        ```python
        >>> from transformers import AutoModel, AutoTokenizer, trainer_utils

        >>> device = "cuda"
        >>> model = AutoModel.from_pretrained("Tanrei/GPTSAN-japanese").to(device)
        >>> tokenizer = AutoTokenizer.from_pretrained("Tanrei/GPTSAN-japanese")
        >>> x_token = tokenizer("織田信長は、", return_tensors="pt")
        >>> trainer_utils.set_seed(30)
        >>> input_ids = x_token.input_ids.to(device)
        >>> gen_token = model.generate(input_ids, max_new_tokens=50)
        >>> tokenizer.decode(gen_token[0])
        "織田信長は、政治・軍事の中枢まで掌握した政治家であり、日本史上類を見ない驚異的な軍事侵攻を続け..."
        ```

        Text Generation with Prefix-LM Model
        ```python
        >>> from transformers import AutoModel, AutoTokenizer, trainer_utils

        >>> device = "cuda"
        >>> model = AutoModel.from_pretrained("Tanrei/GPTSAN-japanese").to(device)
        >>> tokenizer = AutoTokenizer.from_pretrained("Tanrei/GPTSAN-japanese")
        >>> x_token = tokenizer("", prefix_text="織田信長は、", return_tensors="pt")
        >>> trainer_utils.set_seed(30)
        >>> input_ids = x_token.input_ids.to(device)
        >>> token_type_ids = x_token.token_type_ids.to(device)
        >>> gen_token = model.generate(input_ids, token_type_ids=token_type_ids, max_new_tokens=50)
        >>> tokenizer.decode(gen_token[0])
        "織田信長は、政治・外交で数々の戦果を上げるが、1568年からは、いわゆる本能寺の変で細川晴元に暗殺される..."
        ```

        Simultaneously Text Generation And Masked Language Model
        ```python
        >>> from transformers import AutoModel, AutoTokenizer, trainer_utils

        >>> device = "cuda"
        >>> model = AutoModel.from_pretrained("Tanrei/GPTSAN-japanese").to(device)
        >>> tokenizer = AutoTokenizer.from_pretrained("Tanrei/GPTSAN-japanese")
        >>> masked_sentence = "武田信玄は、<|inputmask|>時代ファンならぜひ押さえ<|inputmask|>きたい名将の一人。"
        >>> x_token = tokenizer("", prefix_text=masked_sentence, return_tensors="pt")
        >>> trainer_utils.set_seed(30)
        >>> input_ids = x_token.input_ids.to(device)
        >>> token_type_ids = x_token.token_type_ids.to(device)
        >>> out_lm_token = model.generate(input_ids, token_type_ids=token_type_ids, max_new_tokens=50)
        >>> out_mlm_token = model(input_ids, token_type_ids=token_type_ids).logits.argmax(axis=-1)
        >>> tokenizer.decode(out_mlm_token[0])
        "武田信玄は、戦国時代ファンならぜひ押さえておきたい名将の一人。"

        >>> tokenizer.decode(out_lm_token[0][input_ids.shape[1] :])
        "武田氏の三代に渡った武田家のひとり\n甲斐市に住む、日本史上最大の戦国大名。..."
        ```"""
        SEG_TOKEN = self.config.separator_token_id
        use_cache = use_cache or self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # The backbone is always asked for a dict-style output because attributes such as
        # `outputs.router_probs` are read by name below, whatever the caller requested.
        model_return_dict = True
        num_precontext = None
        if input_ids is not None:
            # For every batch row, record the column index at which the separator token occurs;
            # rows without a separator keep 0. The result is passed on as a (batch, 1) tensor.
            num_batch = input_ids.shape[0]
            num_precontext = torch.zeros([num_batch]).int().to(input_ids.device)
            where_separators = torch.where(input_ids == SEG_TOKEN)
            num_precontext[where_separators[0]] += where_separators[1]
            num_precontext = num_precontext.unsqueeze(1)

        outputs = self.model(
            input_ids,
            attention_mask,
            token_type_ids,
            spout,
            past_key_values,
            head_mask,
            use_cache,
            inputs_embeds,
            decoder_inputs_embeds,
            output_attentions,
            output_hidden_states,
            model_return_dict,
            output_router_logits,
            num_precontext,
        )

        lm_logits = self.lm_head(outputs[0])
        # Only add the bias when its width still matches the logits' vocabulary dimension.
        if lm_logits.shape[-1] == self.final_logits_bias.shape[-1]:
            lm_logits = lm_logits + self.final_logits_bias

        loss = None
        z_loss = None
        router_probs = None
        aux_loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(lm_logits.device)

            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)

            if output_router_logits:
                # Compute the router loss (z_loss + auxiliary loss) for each router in the encoder and decoder
                router_logits, expert_indexes = self._unpack_router_logits(outputs.router_probs)
                z_loss = router_z_loss_func(router_logits)
                router_probs = nn.Softmax(dim=-1)(router_logits)
                aux_loss = load_balancing_loss_func(router_probs, expert_indexes)

            # NOTE: z_loss and aux_loss are returned separately; they are not added to `loss` here.
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))

        if not return_dict:
            # Entries that are None are dropped, so positions in the tuple depend on which
            # optional outputs were produced. NOTE(review): attentions are not part of this tuple.
            return tuple(
                v
                for v in [
                    loss,
                    lm_logits,
                    outputs.past_key_values,
                    outputs.hidden_states,
                    outputs.router_probs,
                    z_loss,
                    aux_loss,
                ]
                if v is not None
            )

        return MoECausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_probs,
            z_loss=z_loss,
            aux_loss=aux_loss,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.FloatTensor,
        token_type_ids: Optional[torch.FloatTensor] = None,
        spout: Optional[Union[List, torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        **kwargs,
    ):
        """Assemble the keyword arguments passed to `forward` for one generation step.

        A `spout` given as a Python list is converted to a float tensor (and moved to the device of
        `input_ids` when those are available). When `past_key_values` is provided, only the last
        position of `input_ids` and `token_type_ids` is forwarded; otherwise the full inputs are
        forwarded and `past_key_values` is reported as `None`.
        """
        if isinstance(spout, list):
            spout = torch.tensor(spout).float()
            if input_ids is not None:
                spout = spout.to(input_ids.device)
        if past_key_values is not None:
            return {
                "input_ids": input_ids[:, -1:] if input_ids is not None else None,
                "attention_mask": attention_mask,
                "token_type_ids": token_type_ids[:, -1:] if token_type_ids is not None else None,
                "spout": spout,
                "past_key_values": past_key_values,
            }
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "spout": spout,
            "past_key_values": None,
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        """Return `self._shift_right(labels)`."""
        return self._shift_right(labels)

    def resize_token_embeddings(
        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
    ) -> nn.Embedding:
        """Resize the token embeddings via the parent class and keep `final_logits_bias` the same width."""
        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
        # Use the actual row count of the new embedding, which may differ from `new_num_tokens`
        # (e.g. when `pad_to_multiple_of` is given).
        self._resize_final_logits_bias(new_embeddings.weight.shape[0])
        return new_embeddings

    def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
        """Truncate or zero-pad the `final_logits_bias` buffer to `new_num_tokens` columns."""
        old_num_tokens = self.final_logits_bias.shape[-1]
        if new_num_tokens <= old_num_tokens:
            new_bias = self.final_logits_bias[:, :new_num_tokens]
        else:
            # Create the padding with the buffer's own dtype: otherwise concatenating a default
            # float32 tensor would promote a half-precision buffer to float32.
            extra_bias = torch.zeros(
                (1, new_num_tokens - old_num_tokens),
                device=self.final_logits_bias.device,
                dtype=self.final_logits_bias.dtype,
            )
            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
        self.register_buffer("final_logits_bias", new_bias)

    def get_input_embeddings(self):
        """Return the input embeddings of the wrapped model."""
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        """Replace the input embeddings of the wrapped model."""
        self.model.set_input_embeddings(new_embeddings)

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head."""
        self.lm_head = new_embeddings

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.lm_head

    def _unpack_router_logits(self, router_outputs):
        """Concatenate per-layer router logits and expert indexes along dimension 1.

        Each element of `router_outputs` is expected to be a `(router_logits, expert_indexes)` pair;
        elements whose first item has at most one dimension are skipped.
        NOTE(review): if every element is skipped, `torch.cat` is called on an empty list and raises.
        """
        total_router_logits = []
        total_expert_indexes = []
        for router_output in router_outputs:
            if len(router_output[0].shape) > 1:
                router_logits, expert_indexes = router_output
                total_router_logits.append(router_logits)
                total_expert_indexes.append(expert_indexes)
        return torch.cat(total_router_logits, dim=1), torch.cat(total_expert_indexes, dim=1)
1335
+
1336
+
1337
+ __all__ = ["GPTSanJapaneseForConditionalGeneration", "GPTSanJapaneseModel", "GPTSanJapanesePreTrainedModel"]
docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/tokenization_gptsan_japanese.py ADDED
@@ -0,0 +1,518 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for GPTSANJapanese."""
16
+
17
+ import collections
18
+ import json
19
+ import os
20
+ import re
21
+ import sys
22
+ from typing import List, Optional, Tuple, Union
23
+
24
+ import numpy as np
25
+
26
+ from ....tokenization_utils import PreTrainedTokenizer
27
+ from ....tokenization_utils_base import (
28
+ BatchEncoding,
29
+ PreTokenizedInput,
30
+ PreTokenizedInputPair,
31
+ TextInput,
32
+ TextInputPair,
33
+ TruncationStrategy,
34
+ )
35
+ from ....utils import PaddingStrategy, logging
36
+
37
+
38
+ logger = logging.get_logger(__name__)
39
+
40
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "emoji_file": "emoji.json"}
41
+
42
+
43
def load_vocab_and_emoji(vocab_file, emoji_file):
    """Read the vocabulary and emoji files and build the lookup tables.

    Every line of `vocab_file` describes one token id (its zero-based line number). A line may
    list several comma-separated spellings that all share that id; a line consisting of a single
    comma is the comma token itself.

    Returns:
        A 4-tuple `(vocab, raw_vocab, ids_to_tokens, emoji)` where `vocab` maps every spelling to
        its id, `raw_vocab` maps the comma-joined line content to its id, `ids_to_tokens` maps an
        id to the list of spellings, and `emoji` is the parsed JSON content of `emoji_file`.
    """
    with open(emoji_file, "r", encoding="utf-8") as emoji_handle:
        emoji = json.load(emoji_handle)

    with open(vocab_file, "r", encoding="utf-8") as vocab_handle:
        raw_lines = vocab_handle.readlines()

    vocab = collections.OrderedDict()
    raw_vocab = collections.OrderedDict()
    ids_to_tokens = collections.OrderedDict()
    for token_id, line in enumerate(raw_lines):
        stripped = line.rstrip("\n")
        if line == ",\n" or "," not in line:
            # A lone comma (or a line without commas) is a single spelling.
            spellings = [stripped]
        else:
            spellings = stripped.split(",")
        ids_to_tokens[token_id] = spellings
        raw_vocab[",".join(spellings)] = token_id
        for spelling in spellings:
            vocab[spelling] = token_id

    return vocab, raw_vocab, ids_to_tokens, emoji
61
+
62
+
63
class GPTSanJapaneseTokenizer(PreTrainedTokenizer):
    """
    This tokenizer is based on GPTNeoXJapaneseTokenizer and has the following modifications
    - Decoding byte0~byte255 tokens correctly
    - Added bagofword token handling
    - Return token_type_ids for Prefix-LM model
    The bagofword token represents a repetition of the previous token and is converted to 3 consecutive tokens when
    decoding. In addition, the original Japanese special Sub-Word-Encoding has been released in this repository
    (https://github.com/tanreinama/Japanese-BPEEncoder_V2). The token_type_ids is a mask indicating the prefix input
    position of the Prefix-LM model. To specify a prefix position, specify a prefix input for prefix_text, or specify a
    sentence of the prefix part and the part after it as a text pair of batch input.

    Example:

    ```python
    >>> from transformers import GPTSanJapaneseTokenizer

    >>> tokenizer = GPTSanJapaneseTokenizer.from_pretrained("Tanrei/GPTSAN-japanese")
    >>> # You can confirm both 慶応 and 慶應 are encoded to 17750
    >>> tokenizer("吾輩は猫である🐯。実は慶応(慶應)大学出身")["input_ids"]
    [35993, 35998, 34347, 31459, 30647, 31448, 25, 30659, 35729, 35676, 32417, 30647, 17750, 35589, 17750, 35590, 321, 1281]

    >>> # Both 慶応 and 慶應 are decoded to 慶応
    >>> tokenizer.decode(tokenizer("吾輩は猫である🐯。実は慶応(慶應)大学出身")["input_ids"])
    '吾輩は猫である🐯。実は慶応(慶応)大学出身'
    ```

    Example for Prefix-LM:

    ```python
    >>> from transformers import GPTSanJapaneseTokenizer

    >>> tokenizer = GPTSanJapaneseTokenizer.from_pretrained("Tanrei/GPTSAN-japanese")
    >>> tokenizer("実は慶応(慶應)大学出身", prefix_text="吾輩は猫である🐯。")["input_ids"]
    [35993, 34347, 31459, 30647, 31448, 25, 30659, 35729, 35676, 35998, 32417, 30647, 17750, 35589, 17750, 35590, 321, 1281]

    >>> # Mask for Prefix-LM inputs
    >>> tokenizer("実は慶応(慶應)大学出身", prefix_text="吾輩は猫である🐯。")["token_type_ids"]
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    ```

    Example for batch encode:

    ```python
    >>> from transformers import GPTSanJapaneseTokenizer

    >>> tokenizer = GPTSanJapaneseTokenizer.from_pretrained("Tanrei/GPTSAN-japanese")
    >>> tokenizer([["武田信玄", "は、"], ["織田信長", "の配下の、"]], padding=True)["input_ids"]
    [[35993, 35998, 8640, 25948, 35993, 35998, 30647, 35675, 35999, 35999], [35993, 35998, 10382, 9868, 35993, 35998, 30646, 9459, 30646, 35675]]

    >>> # Mask for Prefix-LM inputs
    >>> tokenizer([["武田信玄", "は、"], ["織田信長", "の配下の、"]], padding=True)["token_type_ids"]
    [[1, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]]

    >>> # Mask for padding
    >>> tokenizer([["武田信玄", "は、"], ["織田信長", "の配下の、"]], padding=True)["attention_mask"]
    [[1, 1, 1, 1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
    ```

    Args:
        vocab_file (`str`):
            File containing the vocabulary.
        emoji_file (`str`):
            File containing the emoji.
        unk_token (`str`, *optional*, defaults to `"<|nottoken|>"`):
            The token used for unknown characters.
        pad_token (`str`, *optional*, defaults to `"<|separator|>"`):
            The token used for padding
        bos_token (`str`, *optional*, defaults to `"<|startoftext|>"`):
            The beginning of sequence token.
        eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The end of sequence token.
        sep_token (`str`, *optional*, defaults to `"<|segmenter|>"`):
            A special token to separate token to prefix part and general input part.
        do_clean_text (`bool`, *optional*, defaults to `False`):
            Whether or not to clean text for URL, EMAIL, TEL, Japanese DATE and Japanese PRICE.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask", "token_type_ids"]

    def __init__(
        self,
        vocab_file,
        emoji_file,
        unk_token="<|nottoken|>",
        pad_token="<|separator|>",
        bos_token="<|startoftext|>",
        eos_token="<|endoftext|>",
        sep_token="<|segmenter|>",
        do_clean_text=False,
        **kwargs,
    ):
        if not os.path.isfile(vocab_file):
            raise ValueError(
                f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
                " model use `tokenizer = GPTSanJapaneseTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
            )
        if not os.path.isfile(emoji_file):
            raise ValueError(
                f"Can't find a emoji file at path '{emoji_file}'. To load the emoji information from a Google"
                " pretrained model use `tokenizer = GPTSanJapaneseTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
            )
        # The lookup tables are set up before the parent constructor is invoked.
        self.do_clean_text = do_clean_text
        self.vocab, self.raw_vocab, self.ids_to_tokens, self.emoji = load_vocab_and_emoji(vocab_file, emoji_file)
        self.subword_tokenizer = SubWordJapaneseTokenizer(
            vocab=self.vocab, ids_to_tokens=self.ids_to_tokens, emoji=self.emoji
        )

        super().__init__(
            unk_token=unk_token,
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token,
            sep_token=sep_token,
            do_clean_text=do_clean_text,
            **kwargs,
        )

    @property
    def vocab_size(self):
        """Number of distinct token ids (one per vocabulary line)."""
        # self.vocab contains support for character fluctuation unique to Japanese, and has a large number of vocab
        return len(self.raw_vocab)

    def get_vocab(self):
        """Return the raw vocabulary merged with the added tokens (added tokens win on key clashes)."""
        return dict(self.raw_vocab, **self.added_tokens_encoder)

    def _tokenize(self, text):
        """Split `text` into sub-word tokens, cleaning it first when `do_clean_text` is set."""
        return self.subword_tokenizer.tokenize(text, clean=self.do_clean_text)

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.subword_tokenizer.convert_id_to_token(index)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        words = []
        byte_tokens = []
        for word in tokens:
            if word[:6] == "<|byte" and word[-2:] == "|>":
                # Collect consecutive byte tokens so that multi-byte characters decode together.
                byte_tokens.append(int(word[6:-2]))
            else:
                if len(byte_tokens) > 0:
                    words.append(bytearray(byte_tokens).decode("utf-8", errors="replace"))
                    byte_tokens = []
                if word[:7] == "<|emoji" and word[-2:] == "|>":
                    words.append(self.emoji["emoji_inv"][word])
                elif word == "<SP>":
                    words.append(" ")
                elif word == "<BR>":
                    words.append("\n")
                elif word == "<TAB>":
                    words.append("\t")
                elif word == "<BLOCK>":
                    words.append("▀")
                elif word == "<KIGOU>":
                    words.append("ǀ")
                elif word == "<U2000U2BFF>":
                    words.append("‖")
                elif word == "<|bagoftoken|>":
                    # Repeat the previously decoded piece three more times; ignored at the very start.
                    if len(words) > 0:
                        words.append(words[-1])
                        words.append(words[-1])
                        words.append(words[-1])
                elif word.startswith("<|") and word.endswith("|>"):
                    # Any other special token produces no text.
                    words.append("")
                else:
                    words.append(word)
        if len(byte_tokens) > 0:
            words.append(bytearray(byte_tokens).decode("utf-8", errors="replace"))
        text = "".join(words)
        return text

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str, str]:
        """Write the vocabulary and emoji files and return their two paths.

        When `save_directory` is an existing directory the files are created inside it; otherwise
        `save_directory` is used as a plain string prefix of the file names.
        """
        index = 0
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(
                save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
            )
            emoji_file = os.path.join(
                save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["emoji_file"]
            )
        else:
            # NOTE(review): the prefix is placed in front of `save_directory` here - confirm this is intended.
            vocab_file = (
                (filename_prefix + "-" if filename_prefix else "") + save_directory + VOCAB_FILES_NAMES["vocab_file"]
            )
            emoji_file = (
                (filename_prefix + "-" if filename_prefix else "") + save_directory + VOCAB_FILES_NAMES["emoji_file"]
            )
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token_index, token in self.ids_to_tokens.items():
                if index != token_index:
                    logger.warning(
                        f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
                        " Please check that the vocabulary is not corrupted!"
                    )
                    index = token_index
                writer.write(",".join(token) + "\n")
                index += 1
        with open(emoji_file, "w", encoding="utf-8") as writer:
            json.dump(self.emoji, writer)
        return vocab_file, emoji_file

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        # docstyle-ignore
        """
        The tokenizer returns token_type_ids as separators between the Prefix part and the rest.
        token_type_ids is 1 for the Prefix part and 0 for the rest of the token.

        Example:
        ```python
        >>> from transformers import GPTSanJapaneseTokenizer

        >>> tokenizer = GPTSanJapaneseTokenizer.from_pretrained("Tanrei/GPTSAN-japanese")
        >>> x_token = tokenizer("アイウエ")
        >>> # input_ids:      | SOT | SEG | ア | イ | ウ | エ |
        >>> # token_type_ids: | 1   | 0   | 0 | 0 | 0 | 0 |

        >>> x_token = tokenizer("", prefix_text="アイウエ")
        >>> # input_ids:      | SOT | ア | イ | ウ | エ | SEG |
        >>> # token_type_ids: | 1   | 1 | 1 | 1 | 1 | 0  |

        >>> x_token = tokenizer("ウエ", prefix_text="アイ")
        >>> # input_ids:      | SOT | ア | イ | SEG | ウ | エ |
        >>> # token_type_ids: | 1   | 1 | 1 | 0   | 0 | 0 |
        ```"""
        prefix_len = 0
        if self.sep_token in self.vocab:
            segid = self.vocab[self.sep_token]
            if segid in token_ids_0:
                # Everything before the first separator in the first sequence is the prefix.
                prefix_len = token_ids_0.index(segid)
        total_len = len(token_ids_0)
        if token_ids_1 is not None:
            # Sum the lengths instead of concatenating the two lists just to measure them.
            total_len += len(token_ids_1)
        return prefix_len * [1] + (total_len - prefix_len) * [0]

    def prepare_for_tokenization(self, text, prefix_text=None, add_sep_token=None, **kwargs):
        """Prepend the BOS token and the optional prefix, inserting the separator before `text`.

        When `add_sep_token` is not given, the separator is inserted only if `text` does not already
        contain it. BOS and separator are only emitted when present in the vocabulary.
        Returns the prepared text together with the untouched `kwargs`.
        """
        # GPTSAN inserts extra SEP tokens in Prefix-LM in addition to SOT for text generation.
        # SOT at the beginning of the text, and SEP at the separator between the Prefix part and the rest.
        if add_sep_token is None:
            add_sep_token = self.sep_token not in text  # If insert un-prefix position explicitly
        prepared = self.bos_token if self.bos_token in self.vocab else ""
        prepared += prefix_text if prefix_text is not None else ""
        if add_sep_token:
            prepared += self.sep_token if self.sep_token in self.vocab else ""
        prepared += text
        return (prepared, kwargs)

    def _batch_encode_plus(
        self,
        batch_text_or_text_pairs: Union[
            List[TextInput], List[TextInputPair], List[PreTokenizedInput], List[PreTokenizedInputPair]
        ],
        add_special_tokens: bool = True,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
        max_length: Optional[int] = None,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[str] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        """Encode a batch, turning `(prefix, text)` tuples into single prefix-LM inputs first."""
        # This tokenizer converts input text pairs into Prefix input and subsequent input.
        # Only tuple pairs are merged: the former second test, `isinstance(tuple(x), list)`,
        # could never be true and has been removed. An empty batch is passed through unchanged
        # instead of failing on `[0]`.
        if batch_text_or_text_pairs and isinstance(batch_text_or_text_pairs[0], tuple):
            # As a single text with an explicit un-prefix position
            batch_text_or_text_pairs = [
                prefix + self.sep_token + text for prefix, text in batch_text_or_text_pairs
            ]

        return super()._batch_encode_plus(
            batch_text_or_text_pairs,
            add_special_tokens,
            padding_strategy,
            truncation_strategy,
            max_length,
            stride,
            is_split_into_words,
            pad_to_multiple_of,
            return_tensors,
            return_token_type_ids,
            return_attention_mask,
            return_overflowing_tokens,
            return_special_tokens_mask,
            return_offsets_mapping,
            return_length,
            verbose,
            **kwargs,
        )
367
+
368
+
369
class SubWordJapaneseTokenizer:
    """
    This tokenizer is based on GPTNeoXJapaneseTokenizer and has the following modifications
    - Decoding byte0~byte255 tokens correctly
    - Added bagofword token handling

    https://github.com/tanreinama/Japanese-BPEEncoder_V2 This tokenizer class is under MIT License according to the
    original repository.

    MIT License

    Copyright (c) 2020 tanreinama

    Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
    documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
    rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
    permit persons to whom the Software is furnished to do so, subject to the following conditions:

    The above copyright notice and this permission notice shall be included in all copies or substantial portions of
    the Software.

    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
    THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
    TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    SOFTWARE.
    """

    def __init__(self, vocab, ids_to_tokens, emoji):
        """Store the lookup tables and pre-compile the text-cleaning patterns.

        Args:
            vocab: mapping from every token spelling to its id (must not be empty).
            ids_to_tokens: mapping from a token id to the list of its spellings.
            emoji: dict with an `"emoji"` mapping (character -> token) used while tokenizing.
        """
        self.vocab = vocab  # same as swe
        self.ids_to_tokens = ids_to_tokens  # same as bpe
        self.emoji = emoji
        # Longest spelling in the vocabulary; bounds the lookahead for "<...>" style tokens.
        self.maxlen = max(len(w) for w in self.vocab)
        self.content_repatter1 = re.compile(r"(https?|ftp)(:\/\/[-_\.!~*\'()a-zA-Z0-9;\/?:\@&=\+$,%#]+)")
        self.content_repatter2 = re.compile(r"[A-Za-z0-9\._+]*@[\-_0-9A-Za-z]+(\.[A-Za-z]+)*")
        self.content_repatter3 = re.compile(r"[\(]{0,1}[0-9]{2,4}[\)\-\(]{0,1}[0-9]{2,4}[\)\-]{0,1}[0-9]{3,4}")
        self.content_repatter4 = re.compile(
            r"([12]\d{3}[/\-年])*(0?[1-9]|1[0-2])[/\-月]((0?[1-9]|[12][0-9]|3[01])日?)*(\d{1,2}|:|\d{1,2}時|\d{1,2}分|\(日\)|\(月\)|\(火\)|\(水\)|\(木\)|\(金\)|\(土\)|㈰|㈪|㈫|㈬|㈭|㈮|㈯)*"
        )
        self.content_repatter5 = re.compile(
            r"(明治|大正|昭和|平成|令和|㍾|㍽|㍼|㍻|\u32ff)\d{1,2}年(0?[1-9]|1[0-2])月(0?[1-9]|[12][0-9]|3[01])日(\d{1,2}|:|\d{1,2}時|\d{1,2}分|\(日\)|\(月\)|\(火\)|\(水\)|\(木\)|\(金\)|\(土\)|㈰|㈪|㈫|㈬|㈭|㈮|㈯)*"
        )
        # The original version of this regex displays catastrophic backtracking behaviour. We avoid this using
        # possessive quantifiers in Py >= 3.11. In versions below this, we avoid the vulnerability using a slightly
        # different regex that should generally have the same behaviour in most non-pathological cases.
        if sys.version_info >= (3, 11):
            self.content_repatter6 = re.compile(
                r"(?:\d,\d{3}|[\d億])*+"
                r"(?:\d,\d{3}|[\d万])*+"
                r"(?:\d,\d{3}|[\d千])*+"
                r"(?:千円|万円|千万円|円|千ドル|万ドル|千万ドル|ドル|千ユーロ|万ユーロ|千万ユーロ|ユーロ)+"
                r"(?:\(税込\)|\(税抜\)|\+tax)*"
            )
        else:
            self.content_repatter6 = re.compile(
                r"(?:\d,\d{3}|[\d億万千])*"
                r"(?:千円|万円|千万円|円|千ドル|万ドル|千万ドル|ドル|千ユーロ|万ユーロ|千万ユーロ|ユーロ)+"
                r"(?:\(税込\)|\(税抜\)|\+tax)*"
            )
        keisen = "─━│┃┄┅┆┇┈┉┊┋┌┍┎┏┐┑┒┓└┕┖┗┘┙┚┛├┝┞┟┠┡┢┣┤┥┦┧┨┩┪┫┬┭┮┯┰┱┲┳┴┵┶┷┸┹┺┻┼┽┾┿╀╁╂╃╄╅╆╇╈╉╊╋╌╍╎╏═║╒╓╔╕╖╗╘╙╚╛╜╝╞╟╠╡╢╣╤╥╦╧╨╩╪╫╬╭╮╯╰╱╲╳╴╵╶╷╸╹╺╻╼╽╾╿"
        blocks = "▀▁▂▃▄▅▆▇█▉▊▋▌▍▎▏▐░▒▓▔▕▖▗▘▙▚▛▜▝▞▟"
        # Every box-drawing / block character is translated to the "<BLOCK>" placeholder.
        self.content_trans1 = str.maketrans(dict.fromkeys(keisen + blocks, "<BLOCK>"))

    def __len__(self):
        """Number of token ids."""
        return len(self.ids_to_tokens)

    def clean_text(self, content):
        """Replace URLs, e-mail addresses, phone numbers, dates and prices with placeholder tags.

        Box-drawing and block characters become `<BLOCK>`, and runs of `<BLOCK>` are collapsed to one.
        """
        content = self.content_repatter1.sub("<URL>", content)
        content = self.content_repatter2.sub("<EMAIL>", content)
        content = self.content_repatter3.sub("<TEL>", content)
        content = self.content_repatter4.sub("<DATE>", content)
        content = self.content_repatter5.sub("<DATE>", content)
        content = self.content_repatter6.sub("<PRICE>", content)
        content = content.translate(self.content_trans1)
        while "<BLOCK><BLOCK>" in content:
            content = content.replace("<BLOCK><BLOCK>", "<BLOCK>")
        return content

    @staticmethod
    def _is_kigou(char):
        """Return True for a single character whose 2-byte UTF-8 form lies in the `<KIGOU>` ranges."""
        encoded = char.encode()
        if len(char) == 1 and len(encoded) == 2:
            code = (int(encoded[0]) << 8) + int(encoded[1])
            return (
                0xC2A1 <= code <= 0xC2BF
                or 0xC780 <= code <= 0xC783
                or 0xCAB9 <= code <= 0xCBBF
                or 0xCC80 <= code <= 0xCDA2
            )
        return False

    @staticmethod
    def _is_u2000_u2bff(char):
        """Return True for a single character whose 3-byte UTF-8 form lies in 0xE28080..0xE2B07F."""
        encoded = char.encode()
        if len(char) == 1 and len(encoded) == 3:
            code = (int(encoded[0]) << 16) + (int(encoded[1]) << 8) + int(encoded[2])
            return 0xE28080 <= code <= 0xE2B07F
        return False

    def tokenize(self, text, clean=False):
        """Split `text` into a list of token strings.

        Whitespace and line breaks are first mapped to `<SP>`, `<BR>` and `<TAB>`, dashes are
        normalized, emoji are replaced by their tokens and, when `clean` is true, `clean_text` is
        applied. The text is then consumed left to right: among the vocabulary entries starting
        at the current position the one with the smallest id wins (a `<...>` token longer than two
        characters wins immediately). Characters without a vocabulary entry become `<KIGOU>`,
        `<U2000U2BFF>` or a sequence of `<|byteN|>` tokens.
        """
        replacements = (
            (" ", "<SP>"),
            # Ideographic (full-width) space, written as an escape so it cannot be confused with
            # or silently normalized to the ASCII space handled on the line above.
            ("\u3000", "<SP>"),
            ("\r\n", "<BR>"),
            ("\n", "<BR>"),
            ("\r", "<BR>"),
            ("\t", "<TAB>"),
            ("—", "ー"),
            ("−", "ー"),
        )
        for old, new in replacements:
            text = text.replace(old, new)
        for k, v in self.emoji["emoji"].items():
            if k in text:
                text = text.replace(k, v)
        if clean:
            text = self.clean_text(text)

        pos = 0
        result = []
        text_len = len(text)
        while pos < text_len:
            # "<...>" style tokens may be as long as the longest vocabulary entry; ordinary
            # sub-words are looked up with a window of at most three characters. The window is
            # clipped to the end of the text so the same slice is not tested repeatedly.
            window = self.maxlen + 1 if text[pos] == "<" else 3
            end = min(text_len, pos + window)
            candidates = []  # (token_id, token, pos)
            for e in range(end, pos, -1):
                wd = text[pos:e]
                if wd in self.vocab:
                    if wd[0] == "<" and len(wd) > 2:
                        candidates = [(self.vocab[wd], wd, e)]
                        break
                    candidates.append((self.vocab[wd], wd, e))
            if candidates:
                # the smallest token_id is adopted (first one found on ties, i.e. the longest match)
                _, wd, e = min(candidates, key=lambda candidate: candidate[0])
                result.append(wd)
                pos = e
            else:
                end = pos + 1
                wd = text[pos:end]
                if self._is_kigou(wd):
                    result.append("<KIGOU>")
                elif self._is_u2000_u2bff(wd):
                    result.append("<U2000U2BFF>")
                else:
                    # Fall back to one token per UTF-8 byte.
                    for i in wd.encode("utf-8"):
                        result.append("<|byte%d|>" % i)
                pos = end
        return result

    def convert_id_to_token(self, index):
        """Return the first (canonical) spelling registered for token id `index`."""
        return self.ids_to_tokens[index][0]
516
+
517
+
518
+ __all__ = ["GPTSanJapaneseTokenizer"]
docs/transformers/build/lib/transformers/models/deprecated/graphormer/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ....utils import _LazyModule
17
+ from ....utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_graphormer import *
22
+ from .modeling_graphormer import *
23
+ else:
24
+ import sys
25
+
26
+ _file = globals()["__file__"]
27
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/build/lib/transformers/models/deprecated/graphormer/algos_graphormer.pyx ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation and HuggingFace
2
+ # Licensed under the MIT License.
3
+
4
+ import cython
5
+
6
+ cimport numpy
7
+ from cython.parallel cimport parallel, prange
8
+
9
+ import numpy as np
10
+
11
+
12
+ # Reduce this number if matrices are too big for large graphs
13
+ UNREACHABLE_NODE_DISTANCE = 510
14
+
15
def floyd_warshall(adjacency_matrix):
    """
    Applies the Floyd-Warshall algorithm to the adjacency matrix, to compute the
    shortest paths distance between all nodes, up to UNREACHABLE_NODE_DISTANCE.

    Args:
        adjacency_matrix: square 2D array; it is copied and safely cast to int32. Off-diagonal entries equal to 0
            are treated as "no edge".

    Returns:
        A tuple `(M, path)` of int32 arrays of shape `[n, n]`:
        - `M[i][j]` is the shortest distance found from `i` to `j` (0 on the diagonal), or
          UNREACHABLE_NODE_DISTANCE when the distance is at or above that threshold.
        - `path[i][j]` is the last intermediate node `k` that shortened the `i -> j` distance, -1 if no
          intermediate node ever improved it, or UNREACHABLE_NODE_DISTANCE for pairs marked unreachable.
    """
    (nrows, ncols) = adjacency_matrix.shape
    assert nrows == ncols
    cdef unsigned int n = nrows

    # Work on a private, C-contiguous int32 copy so the raw-pointer arithmetic below is valid
    adj_mat_copy = adjacency_matrix.astype(np.int32, order='C', casting='safe', copy=True)
    assert adj_mat_copy.flags['C_CONTIGUOUS']
    cdef numpy.ndarray[numpy.int32_t, ndim=2, mode='c'] M = adj_mat_copy
    # -1 means "no intermediate node recorded for this pair"
    cdef numpy.ndarray[numpy.int32_t, ndim=2, mode='c'] path = -1 * np.ones([n, n], dtype=np.int32)

    cdef unsigned int i, j, k
    cdef numpy.int32_t M_ij, M_ik, cost_ikkj
    # Row pointers into M: row r starts at M_ptr + n*r
    cdef numpy.int32_t* M_ptr = &M[0,0]
    cdef numpy.int32_t* M_i_ptr
    cdef numpy.int32_t* M_k_ptr

    # set unreachable nodes distance to UNREACHABLE_NODE_DISTANCE
    for i in range(n):
        for j in range(n):
            if i == j:
                M[i][j] = 0
            elif M[i][j] == 0:
                M[i][j] = UNREACHABLE_NODE_DISTANCE

    # Floyd-Warshall relaxation: try every node k as an intermediate hop between i and j
    for k in range(n):
        M_k_ptr = M_ptr + n*k
        for i in range(n):
            M_i_ptr = M_ptr + n*i
            M_ik = M_i_ptr[k]
            for j in range(n):
                cost_ikkj = M_ik + M_k_ptr[j]
                M_ij = M_i_ptr[j]
                if M_ij > cost_ikkj:
                    M_i_ptr[j] = cost_ikkj
                    path[i][j] = k

    # set unreachable path to UNREACHABLE_NODE_DISTANCE
    for i in range(n):
        for j in range(n):
            if M[i][j] >= UNREACHABLE_NODE_DISTANCE:
                path[i][j] = UNREACHABLE_NODE_DISTANCE
                M[i][j] = UNREACHABLE_NODE_DISTANCE

    return M, path
64
+
65
+
66
def get_all_edges(path, i, j):
    """
    Recursively expands the intermediate nodes recorded in `path` between nodes `i` and `j`.

    `path[i][j]` is read as the intermediate node `k` on the way from `i` to `j`; -1 means there is none and ends
    the recursion. The result is the ordered list of intermediate nodes, excluding `i` and `j` themselves.

    NOTE(review): a `path[i][j]` holding the unreachable marker instead of -1 would be used as a node index here;
    the visible caller skips such pairs before calling.
    """
    cdef int k = path[i][j]
    if k == -1:
        return []
    else:
        return get_all_edges(path, i, k) + [k] + get_all_edges(path, k, j)
75
+
76
+
77
def gen_edge_input(max_dist, path, edge_feat):
    """
    Generates the full edge feature and adjacency matrix.
    Shape: num_nodes * num_nodes * max_distance_between_nodes * num_edge_features
    Dim 1 is the input node, dim 2 the output node of the edge, dim 3 the depth of the edge, dim 4 the feature

    Args:
        max_dist: size of the third (depth) dimension of the output.
        path: square 2D array of intermediate nodes; -1 means "no intermediate node" and
            UNREACHABLE_NODE_DISTANCE marks pairs to skip.
        edge_feat: integer array indexed as `[source, target, feature]`.

    Returns:
        int32 array of shape `[n, n, max_dist, edge_feat.shape[-1]]`, filled with -1 wherever no edge feature was
        written (diagonal, skipped pairs, and depths beyond the length of the expanded path).
    """
    (nrows, ncols) = path.shape
    assert nrows == ncols
    cdef unsigned int n = nrows
    cdef unsigned int max_dist_copy = max_dist

    # Use an explicit 64-bit dtype: the Python-2 alias `long` resolves to a platform-dependent integer width,
    # which can make the 'safe' cast of int64 edge features fail.
    path_copy = path.astype(np.int64, order='C', casting='safe', copy=True)
    edge_feat_copy = edge_feat.astype(np.int64, order='C', casting='safe', copy=True)
    assert path_copy.flags['C_CONTIGUOUS']
    assert edge_feat_copy.flags['C_CONTIGUOUS']

    cdef numpy.ndarray[numpy.int32_t, ndim=4, mode='c'] edge_fea_all = -1 * np.ones([n, n, max_dist_copy, edge_feat.shape[-1]], dtype=np.int32)
    cdef unsigned int i, j, k, num_path

    for i in range(n):
        for j in range(n):
            if i == j:
                continue
            if path_copy[i][j] == UNREACHABLE_NODE_DISTANCE:
                continue
            # Full node sequence i -> ... -> j; kept in its own name so the `path` argument is not rebound
            node_seq = [i] + get_all_edges(path_copy, i, j) + [j]
            num_path = len(node_seq) - 1
            for k in range(num_path):
                # k-th hop of the path contributes the features of edge (node_seq[k], node_seq[k+1])
                edge_fea_all[i, j, k, :] = edge_feat_copy[node_seq[k], node_seq[k+1], :]

    return edge_fea_all
docs/transformers/build/lib/transformers/models/deprecated/graphormer/collating_graphormer.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation and HuggingFace
2
+ # Licensed under the MIT License.
3
+
4
+ from typing import Any, Dict, List, Mapping
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+ from ....utils import is_cython_available, requires_backends
10
+
11
+
12
+ if is_cython_available():
13
+ import pyximport
14
+
15
+ pyximport.install(setup_args={"include_dirs": np.get_include()})
16
+ from . import algos_graphormer # noqa E402
17
+
18
+
19
def convert_to_single_emb(x, offset: int = 512):
    """
    Shift every feature column of `x` into its own index range.

    Column `c` is increased by `1 + c * offset`; an input with a single dimension is treated as having one column.
    Returns the shifted values as a new object (the addition is not in place).
    """
    if len(x.shape) > 1:
        num_columns = x.shape[1]
    else:
        num_columns = 1
    column_shift = 1 + np.arange(0, num_columns * offset, offset, dtype=np.int64)
    return x + column_shift
24
+
25
+
26
def preprocess_item(item, keep_features=True):
    """
    Add the Graphormer model inputs to a single graph `item` and return it.

    Reads `edge_index` and `num_nodes` (plus `edge_attr` / `node_feat` when `keep_features` is set and they are
    present, and `y` when `labels` is missing) and writes `input_nodes`, `attn_bias`, `attn_edge_type`,
    `spatial_pos`, `in_degree`, `out_degree`, `input_edges` and, if absent, `labels` into `item`.
    """
    requires_backends(preprocess_item, ["cython"])

    # Edge features: use the provided ones, otherwise the same placeholder embedding for every edge
    edge_attr = (
        np.asarray(item["edge_attr"], dtype=np.int64)
        if keep_features and "edge_attr" in item.keys()
        else np.ones((len(item["edge_index"][0]), 1), dtype=np.int64)
    )

    # Node features: use the provided ones, otherwise the same placeholder embedding for every node
    node_feature = (
        np.asarray(item["node_feat"], dtype=np.int64)
        if keep_features and "node_feat" in item.keys()
        else np.ones((item["num_nodes"], 1), dtype=np.int64)
    )

    edge_index = np.asarray(item["edge_index"], dtype=np.int64)
    sources, targets = edge_index[0], edge_index[1]

    input_nodes = convert_to_single_emb(node_feature) + 1
    num_nodes = item["num_nodes"]

    # A flat list of edge attributes is promoted to a single feature column
    if len(edge_attr.shape) == 1:
        edge_attr = edge_attr[:, None]
    attn_edge_type = np.zeros([num_nodes, num_nodes, edge_attr.shape[-1]], dtype=np.int64)
    attn_edge_type[sources, targets] = convert_to_single_emb(edge_attr) + 1

    # Boolean [num_nodes, num_nodes] adjacency matrix
    adj = np.zeros([num_nodes, num_nodes], dtype=bool)
    adj[sources, targets] = True

    shortest_path_result, path = algos_graphormer.floyd_warshall(adj)
    max_dist = np.amax(shortest_path_result)
    input_edges = algos_graphormer.gen_edge_input(max_dist, path, attn_edge_type)

    # One extra row/column for the graph token
    attn_bias = np.zeros([num_nodes + 1, num_nodes + 1], dtype=np.single)

    # Store everything on the item; index-like outputs are shifted by one to leave 0 free for padding
    item["input_nodes"] = input_nodes + 1
    item["attn_bias"] = attn_bias
    item["attn_edge_type"] = attn_edge_type
    item["spatial_pos"] = shortest_path_result.astype(np.int64) + 1
    item["in_degree"] = np.sum(adj, axis=1).reshape(-1) + 1
    item["out_degree"] = item["in_degree"]  # for undirected graph
    item["input_edges"] = input_edges + 1
    if "labels" not in item:
        item["labels"] = item["y"]

    return item
71
+
72
+
73
class GraphormerDataCollator:
    """
    Pads a list of preprocessed graphs to a common size and stacks them into one batch of tensors.

    Args:
        spatial_pos_max: node pairs whose `spatial_pos` is at or above this value get an attention bias of `-inf`.
        on_the_fly_processing: when set, every feature is first passed through `preprocess_item`.

    Raises:
        ImportError: at construction time when Cython is not available.
    """

    _TENSOR_KEYS = ("attn_bias", "attn_edge_type", "spatial_pos", "in_degree", "input_nodes", "input_edges")

    def __init__(self, spatial_pos_max=20, on_the_fly_processing=False):
        if not is_cython_available():
            raise ImportError("Graphormer preprocessing needs Cython (pyximport)")

        self.spatial_pos_max = spatial_pos_max
        self.on_the_fly_processing = on_the_fly_processing

    @staticmethod
    def _as_new_tensor(value):
        """Return a tensor holding a copy of `value`, so later in-place edits never touch the caller's data."""
        if isinstance(value, torch.Tensor):
            return value.clone().detach()
        return torch.tensor(value)

    def __call__(self, features: List[dict]) -> Dict[str, Any]:
        """
        Collate `features` into a dict of zero-padded tensors.

        The input features are left untouched: all conversions and the `-inf` masking are done on copies.
        `out_degree` in the returned batch is the same tensor object as `in_degree`.
        """
        if self.on_the_fly_processing:
            features = [preprocess_item(i) for i in features]

        if not isinstance(features[0], Mapping):
            features = [vars(f) for f in features]
        batch = {}

        # Sizes of the padded batch; per-feature widths are read from the first graph
        max_node_num = max(len(i["input_nodes"]) for i in features)
        node_feat_size = len(features[0]["input_nodes"][0])
        edge_feat_size = len(features[0]["attn_edge_type"][0][0])
        max_dist = max(len(i["input_edges"][0][0]) for i in features)
        edge_input_size = len(features[0]["input_edges"][0][0][0])
        batch_size = len(features)

        batch["attn_bias"] = torch.zeros(batch_size, max_node_num + 1, max_node_num + 1, dtype=torch.float)
        batch["attn_edge_type"] = torch.zeros(batch_size, max_node_num, max_node_num, edge_feat_size, dtype=torch.long)
        batch["spatial_pos"] = torch.zeros(batch_size, max_node_num, max_node_num, dtype=torch.long)
        batch["in_degree"] = torch.zeros(batch_size, max_node_num, dtype=torch.long)
        batch["input_nodes"] = torch.zeros(batch_size, max_node_num, node_feat_size, dtype=torch.long)
        batch["input_edges"] = torch.zeros(
            batch_size, max_node_num, max_node_num, max_dist, edge_input_size, dtype=torch.long
        )

        for ix, f in enumerate(features):
            # Local tensor copies: the caller's feature dict keeps its original arrays
            t = {k: self._as_new_tensor(f[k]) for k in self._TENSOR_KEYS}

            # Block attention between nodes that are too far apart (row/column 0 belongs to the graph token)
            too_far = t["spatial_pos"] >= self.spatial_pos_max
            if too_far.any():
                t["attn_bias"][1:, 1:][too_far] = float("-inf")

            # Copy each graph into the top-left corner of its padded slot
            batch["attn_bias"][ix, : t["attn_bias"].shape[0], : t["attn_bias"].shape[1]] = t["attn_bias"]
            batch["attn_edge_type"][ix, : t["attn_edge_type"].shape[0], : t["attn_edge_type"].shape[1], :] = t[
                "attn_edge_type"
            ]
            batch["spatial_pos"][ix, : t["spatial_pos"].shape[0], : t["spatial_pos"].shape[1]] = t["spatial_pos"]
            batch["in_degree"][ix, : t["in_degree"].shape[0]] = t["in_degree"]
            batch["input_nodes"][ix, : t["input_nodes"].shape[0], :] = t["input_nodes"]
            batch["input_edges"][
                ix, : t["input_edges"].shape[0], : t["input_edges"].shape[1], : t["input_edges"].shape[2], :
            ] = t["input_edges"]

        batch["out_degree"] = batch["in_degree"]

        sample = features[0]["labels"]
        if len(sample) == 1:  # one task (regression or binary classification): flat vector of labels
            batch["labels"] = torch.from_numpy(np.concatenate([i["labels"] for i in features]))
        else:  # multi task classification, left to float to keep the NaNs
            batch["labels"] = torch.from_numpy(np.stack([i["labels"] for i in features], axis=0))

        return batch
docs/transformers/build/lib/transformers/models/deprecated/graphormer/configuration_graphormer.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 Microsoft, clefourrier and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Graphormer model configuration"""
16
+
17
+ from typing import Optional
18
+
19
+ from ....configuration_utils import PretrainedConfig
20
+ from ....utils import logging
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
class GraphormerConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`~GraphormerModel`]. It is used to instantiate an
    Graphormer model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the Graphormer
    [graphormer-base-pcqm4mv1](https://huggingface.co/graphormer-base-pcqm4mv1) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        num_classes (`int`, *optional*, defaults to 1):
            Number of target classes or labels, set to n for binary classification of n tasks.
        num_atoms (`int`, *optional*, defaults to 512*9):
            Number of node types in the graphs.
        num_edges (`int`, *optional*, defaults to 512*3):
            Number of edges types in the graph.
        num_in_degree (`int`, *optional*, defaults to 512):
            Number of in degrees types in the input graphs.
        num_out_degree (`int`, *optional*, defaults to 512):
            Number of out degrees types in the input graphs.
        num_spatial (`int`, *optional*, defaults to 512):
            Number of entries in the spatial position embedding table.
        num_edge_dis (`int`, *optional*, defaults to 128):
            Number of edge dis in the input graphs.
        multi_hop_max_dist (`int`, *optional*, defaults to 5):
            Maximum distance of multi hop edges between two nodes.
        spatial_pos_max (`int`, *optional*, defaults to 1024):
            Maximum distance between nodes in the graph attention bias matrices, used during preprocessing and
            collation.
        edge_type (`str`, *optional*, defaults to `"multi_hop"`):
            Type of edge relation chosen.
        max_nodes (`int`, *optional*, defaults to 512):
            Maximum number of nodes which can be parsed for the input graphs.
        share_input_output_embed (`bool`, *optional*, defaults to `False`):
            Shares the embedding layer between encoder and decoder - careful, True is not implemented.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of layers.
        embedding_dim (`int`, *optional*, defaults to 768):
            Dimension of the embedding layer in encoder. Also stored as `hidden_size`.
        ffn_embedding_dim (`int`, *optional*, defaults to 768):
            Dimension of the "intermediate" (often named feed-forward) layer in encoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads in the encoder.
        self_attention (`bool`, *optional*, defaults to `True`):
            Model is self attentive (False not implemented).
        activation_fn (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        dropout (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_dropout (`float`, *optional*, defaults to 0.1):
            The dropout probability for the attention weights.
        activation_dropout (`float`, *optional*, defaults to 0.1):
            The dropout probability for the activation of the linear transformer layer.
        layerdrop (`float`, *optional*, defaults to 0.0):
            The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
            for more details.
        bias (`bool`, *optional*, defaults to `True`):
            Uses bias in the attention module - unsupported at the moment.
        embed_scale (`float`, *optional*, defaults to None):
            Scaling factor for the node embeddings.
        num_trans_layers_to_freeze (`int`, *optional*, defaults to 0):
            Number of transformer layers to freeze.
        encoder_normalize_before (`bool`, *optional*, defaults to `False`):
            Normalize features before encoding the graph.
        pre_layernorm (`bool`, *optional*, defaults to `False`):
            Apply layernorm before self attention and the feed forward network. Without this, post layernorm will be
            used.
        apply_graphormer_init (`bool`, *optional*, defaults to `False`):
            Apply a custom graphormer initialisation to the model before training.
        freeze_embeddings (`bool`, *optional*, defaults to `False`):
            Freeze the embedding layer, or train it along the model.
        q_noise (`float`, *optional*, defaults to 0.0):
            Amount of quantization noise (see "Training with Quantization Noise for Extreme Model Compression"). (For
            more detail, see fairseq's documentation on quant_noise).
        qn_block_size (`int`, *optional*, defaults to 8):
            Size of the blocks for subsequent quantization with iPQ (see q_noise).
        kdim (`int`, *optional*, defaults to None):
            Dimension of the key in the attention, if different from the other values.
        vdim (`int`, *optional*, defaults to None):
            Dimension of the value in the attention, if different from the other values.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        traceable (`bool`, *optional*, defaults to `False`):
            Changes return value of the encoder's inner_state to stacked tensors.

        Example:
            ```python
            >>> from transformers import GraphormerForGraphClassification, GraphormerConfig

            >>> # Initializing a Graphormer graphormer-base-pcqm4mv1 style configuration
            >>> configuration = GraphormerConfig()

            >>> # Initializing a model from the graphormer-base-pcqm4mv1 style configuration
            >>> model = GraphormerForGraphClassification(configuration)

            >>> # Accessing the model configuration
            >>> configuration = model.config
            ```
    """

    model_type = "graphormer"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        num_classes: int = 1,
        num_atoms: int = 512 * 9,
        num_edges: int = 512 * 3,
        num_in_degree: int = 512,
        num_out_degree: int = 512,
        num_spatial: int = 512,
        num_edge_dis: int = 128,
        multi_hop_max_dist: int = 5,  # sometimes is 20
        spatial_pos_max: int = 1024,
        edge_type: str = "multi_hop",
        max_nodes: int = 512,
        share_input_output_embed: bool = False,
        num_hidden_layers: int = 12,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 768,
        num_attention_heads: int = 32,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        layerdrop: float = 0.0,
        encoder_normalize_before: bool = False,
        pre_layernorm: bool = False,
        apply_graphormer_init: bool = False,
        activation_fn: str = "gelu",
        embed_scale: Optional[float] = None,
        freeze_embeddings: bool = False,
        num_trans_layers_to_freeze: int = 0,
        traceable: bool = False,
        q_noise: float = 0.0,
        qn_block_size: int = 8,
        kdim: Optional[int] = None,
        vdim: Optional[int] = None,
        bias: bool = True,
        self_attention: bool = True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        **kwargs,
    ):
        self.num_classes = num_classes
        self.num_atoms = num_atoms
        self.num_in_degree = num_in_degree
        self.num_out_degree = num_out_degree
        self.num_edges = num_edges
        self.num_spatial = num_spatial
        self.num_edge_dis = num_edge_dis
        self.edge_type = edge_type
        self.multi_hop_max_dist = multi_hop_max_dist
        self.spatial_pos_max = spatial_pos_max
        self.max_nodes = max_nodes
        self.num_hidden_layers = num_hidden_layers
        self.embedding_dim = embedding_dim
        # `hidden_size` mirrors `embedding_dim` so both names resolve to the same value
        self.hidden_size = embedding_dim
        self.ffn_embedding_dim = ffn_embedding_dim
        self.num_attention_heads = num_attention_heads
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.layerdrop = layerdrop
        self.encoder_normalize_before = encoder_normalize_before
        self.pre_layernorm = pre_layernorm
        self.apply_graphormer_init = apply_graphormer_init
        self.activation_fn = activation_fn
        self.embed_scale = embed_scale
        self.freeze_embeddings = freeze_embeddings
        self.num_trans_layers_to_freeze = num_trans_layers_to_freeze
        self.share_input_output_embed = share_input_output_embed
        self.traceable = traceable
        self.q_noise = q_noise
        self.qn_block_size = qn_block_size

        # These parameters are here for future extensions
        # atm, the model only supports self attention
        self.kdim = kdim
        self.vdim = vdim
        self.self_attention = self_attention
        self.bias = bias

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )
218
+
219
+
220
+ __all__ = ["GraphormerConfig"]
docs/transformers/build/lib/transformers/models/deprecated/graphormer/modeling_graphormer.py ADDED
@@ -0,0 +1,911 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 Microsoft, clefourrier The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch Graphormer model."""
16
+
17
+ import math
18
+ from typing import Iterable, Iterator, List, Optional, Tuple, Union
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
23
+
24
+ from ....activations import ACT2FN
25
+ from ....modeling_outputs import (
26
+ BaseModelOutputWithNoAttention,
27
+ SequenceClassifierOutput,
28
+ )
29
+ from ....modeling_utils import PreTrainedModel
30
+ from ....utils import logging
31
+ from .configuration_graphormer import GraphormerConfig
32
+
33
+
34
+ logger = logging.get_logger(__name__)
35
+
36
+ _CHECKPOINT_FOR_DOC = "graphormer-base-pcqm4mv1"
37
+ _CONFIG_FOR_DOC = "GraphormerConfig"
38
+
39
+
40
def quant_noise(module: nn.Module, p: float, block_size: int):
    """
    From:
    https://github.com/facebookresearch/fairseq/blob/dd0079bde7f678b0cd0715cbd0ae68d661b7226d/fairseq/modules/quant_noise.py

    Wraps modules and applies quantization noise to the weights for subsequent quantization with Iterative Product
    Quantization as described in "Training with Quantization Noise for Extreme Model Compression"

    Args:
        - module: nn.Module
        - p: amount of Quantization Noise
        - block_size: size of the blocks for subsequent quantization with iPQ

    Returns:
        The same `module` object. When `p > 0` a forward pre-hook has been registered on it; when `p <= 0` it is
        returned untouched.

    Raises:
        NotImplementedError: if `module` is not a Linear, Embedding or Conv2d.
        AssertionError: if the weight dimensions are not a multiple of `block_size`.

    Remarks:
        - Module weights must have the right sizes wrt the block size
        - Only Linear, Embedding and Conv2d modules are supported for the moment
        - For more detail on how to quantize by blocks with convolutional weights, see "And the Bit Goes Down:
          Revisiting the Quantization of Neural Networks"
        - We implement the simplest form of noise here as stated in the paper which consists in randomly dropping
          blocks
        - The hook overwrites `module.weight.data` on every training forward pass
    """

    # if no quantization noise, don't register hook
    if p <= 0:
        return module

    # supported modules
    if not isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)):
        raise NotImplementedError("Module unsupported for quant_noise.")

    # test whether module.weight has the right sizes wrt block_size
    is_conv = module.weight.ndim == 4

    # 2D matrix
    if not is_conv:
        if module.weight.size(1) % block_size != 0:
            raise AssertionError("Input features must be a multiple of block sizes")

    # 4D matrix
    else:
        # 1x1 convolutions
        if module.kernel_size == (1, 1):
            if module.in_channels % block_size != 0:
                raise AssertionError("Input channels must be a multiple of block sizes")
        # regular convolutions
        else:
            k = module.kernel_size[0] * module.kernel_size[1]
            if k % block_size != 0:
                raise AssertionError("Kernel size must be a multiple of block size")

    def _forward_pre_hook(mod, input):
        # no noise for evaluation
        if not mod.training:
            return

        weight = mod.weight
        if not is_conv:
            in_features = weight.size(1)
            out_features = weight.size(0)

            # split weight matrix into blocks and randomly drop selected blocks
            mask = torch.zeros(in_features // block_size * out_features, device=weight.device)
            mask.bernoulli_(p)
            mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
        elif mod.kernel_size == (1, 1):
            in_channels = mod.in_channels
            out_channels = mod.out_channels

            # split weight matrix into blocks and randomly drop selected blocks
            mask = torch.zeros(
                int(in_channels // block_size * out_channels),
                device=weight.device,
            )
            mask.bernoulli_(p)
            # Give the mask the weight's own (out, in, 1, 1) layout. A 2D (out, in) mask does not line up
            # with the 4D kernel and would be broadcast against it instead of masking it element-wise.
            mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels, 1, 1)
        else:
            # one draw per (out, in) pair, repeated over the whole kernel window
            mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device)
            mask.bernoulli_(p)
            mask = mask.unsqueeze(2).unsqueeze(3).repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])

        # scale weights and apply mask
        mask = mask.to(torch.bool)  # x.bool() is not currently supported in TorchScript
        # NOTE(review): p == 1 divides by zero here — confirm callers never pass it
        s = 1 / (1 - p)
        mod.weight.data = s * weight.masked_fill(mask, 0)

    module.register_forward_pre_hook(_forward_pre_hook)
    return module
130
+
131
+
132
class LayerDropModuleList(nn.ModuleList):
    """
    From:
    https://github.com/facebookresearch/fairseq/blob/dd0079bde7f678b0cd0715cbd0ae68d661b7226d/fairseq/modules/layer_drop.py
    A [`torch.nn.ModuleList`] whose iteration randomly skips layers (LayerDrop, https://arxiv.org/abs/1909.11556).

    A fresh random draw is made each time the list is iterated. In training mode a layer is yielded only when its
    draw is strictly greater than `p`; in evaluation mode every layer is yielded.

    Usage:

    ```python
    layers = LayerDropModuleList(p=0.5, modules=[layer1, layer2, layer3])
    for layer in layers:  # this might iterate over layers 1 and 3
        x = layer(x)
    for layer in layers:  # this might iterate over all layers
        x = layer(x)
    for layer in layers:  # this might not iterate over any layers
        x = layer(x)
    ```

    Args:
        p (float): probability of dropping out each layer
        modules (iterable, optional): an iterable of modules to add
    """

    def __init__(self, p: float, modules: Optional[Iterable[nn.Module]] = None):
        super().__init__(modules)
        self.p = p

    def __iter__(self) -> Iterator[nn.Module]:
        # One uniform draw per layer, sampled up front (also in eval mode, where the draws are ignored)
        keep_draws = torch.empty(len(self)).uniform_()
        for position, layer in enumerate(super().__iter__()):
            if self.training and not (keep_draws[position] > self.p):
                continue
            yield layer
168
+
169
+
170
class GraphormerGraphNodeFeature(nn.Module):
    """
    Embed the nodes of each graph and prepend one learned graph-token vector per graph.
    """

    def __init__(self, config: GraphormerConfig):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.num_atoms = config.num_atoms

        hidden_size = config.hidden_size
        pad_id = config.pad_token_id

        # Embedding tables for node types and node degrees; the pad id maps to the padding row
        self.atom_encoder = nn.Embedding(config.num_atoms + 1, hidden_size, padding_idx=pad_id)
        self.in_degree_encoder = nn.Embedding(config.num_in_degree, hidden_size, padding_idx=pad_id)
        self.out_degree_encoder = nn.Embedding(config.num_out_degree, hidden_size, padding_idx=pad_id)

        # Single learned vector used as the graph token
        self.graph_token = nn.Embedding(1, hidden_size)

    def forward(
        self,
        input_nodes: torch.LongTensor,
        in_degree: torch.LongTensor,
        out_degree: torch.LongTensor,
    ) -> torch.Tensor:
        """
        Return the graph token followed by the node features, concatenated along dim 1.

        Each node feature is the sum of its atom embeddings (summed over dim -2) plus its in-degree and
        out-degree embeddings.
        """
        n_graph, _ = input_nodes.size()[:2]

        atom_embedding = self.atom_encoder(input_nodes).sum(dim=-2)  # [n_graph, n_node, n_hidden]
        node_feature = atom_embedding + self.in_degree_encoder(in_degree) + self.out_degree_encoder(out_degree)

        # Same graph token for every graph in the batch
        graph_token_feature = self.graph_token.weight.unsqueeze(0).repeat(n_graph, 1, 1)

        return torch.cat([graph_token_feature, node_feature], dim=1)
209
+
210
+
211
class GraphormerGraphAttnBias(nn.Module):
    """
    Build the per-head additive attention bias from spatial-position and edge encodings.
    """

    def __init__(self, config: GraphormerConfig):
        super().__init__()
        head_count = config.num_attention_heads
        self.num_heads = head_count
        self.multi_hop_max_dist = config.multi_hop_max_dist

        # We do not change edge feature embedding learning, as edge embeddings are represented as a
        # combination of the original features + shortest path
        self.edge_encoder = nn.Embedding(config.num_edges + 1, head_count, padding_idx=0)

        self.edge_type = config.edge_type
        if self.edge_type == "multi_hop":
            self.edge_dis_encoder = nn.Embedding(config.num_edge_dis * head_count * head_count, 1)

        self.spatial_pos_encoder = nn.Embedding(config.num_spatial, head_count, padding_idx=0)

        self.graph_token_virtual_distance = nn.Embedding(1, head_count)

    def forward(
        self,
        input_nodes: torch.LongTensor,
        attn_bias: torch.Tensor,
        spatial_pos: torch.LongTensor,
        input_edges: torch.LongTensor,
        attn_edge_type: torch.LongTensor,
    ) -> torch.Tensor:
        """
        Return `attn_bias` expanded to one slice per head and enriched with the spatial-position
        bias, the graph-token virtual distance and the edge bias; `attn_bias` itself is added once
        more at the end.
        """
        n_graph, n_node = input_nodes.size()[:2]
        heads = self.num_heads

        # [n_graph, n_head, n_node+1, n_node+1]
        bias = attn_bias.clone().unsqueeze(1).repeat(1, heads, 1, 1)

        # spatial pos
        # [n_graph, n_node, n_node, n_head] -> [n_graph, n_head, n_node, n_node]
        bias[:, :, 1:, 1:] = bias[:, :, 1:, 1:] + self.spatial_pos_encoder(spatial_pos).permute(0, 3, 1, 2)

        # reset spatial pos here: first column (rows 1:) and the whole first row get the virtual distance
        virtual_distance = self.graph_token_virtual_distance.weight.view(1, heads, 1)
        bias[:, :, 1:, 0] = bias[:, :, 1:, 0] + virtual_distance
        bias[:, :, 0, :] = bias[:, :, 0, :] + virtual_distance

        # edge feature
        if self.edge_type == "multi_hop":
            hop_count = spatial_pos.clone()
            hop_count[hop_count == 0] = 1  # set pad to 1
            # set 1 to 1, input_nodes > 1 to input_nodes - 1
            hop_count = torch.where(hop_count > 1, hop_count - 1, hop_count)
            if self.multi_hop_max_dist > 0:
                hop_count = hop_count.clamp(0, self.multi_hop_max_dist)
                input_edges = input_edges[:, :, :, : self.multi_hop_max_dist, :]

            # [n_graph, n_node, n_node, max_dist, n_head]
            edge_embedding = self.edge_encoder(input_edges).mean(-2)
            max_dist = edge_embedding.size(-2)

            # One [n_head, n_head] matrix per hop distance, applied with a batched matmul.
            hop_matrices = self.edge_dis_encoder.weight.reshape(-1, heads, heads)[:max_dist, :, :]
            per_hop = edge_embedding.permute(3, 0, 1, 2, 4).reshape(max_dist, -1, heads)
            per_hop = torch.bmm(per_hop, hop_matrices)
            edge_embedding = per_hop.reshape(max_dist, n_graph, n_node, n_node, heads).permute(1, 2, 3, 0, 4)

            edge_bias = (edge_embedding.sum(-2) / (hop_count.float().unsqueeze(-1))).permute(0, 3, 1, 2)
        else:
            # [n_graph, n_node, n_node, n_head] -> [n_graph, n_head, n_node, n_node]
            edge_bias = self.edge_encoder(attn_edge_type).mean(-2).permute(0, 3, 1, 2)

        bias[:, :, 1:, 1:] = bias[:, :, 1:, 1:] + edge_bias

        return bias + attn_bias.unsqueeze(1)  # reset
291
+
292
+
293
class GraphormerMultiheadAttention(nn.Module):
    """Multi-headed self-attention with an optional additive attention bias.

    See "Attention Is All You Need" for more details.
    """

    def __init__(self, config: GraphormerConfig):
        super().__init__()
        self.embedding_dim = config.embedding_dim
        self.kdim = config.kdim if config.kdim is not None else config.embedding_dim
        self.vdim = config.vdim if config.vdim is not None else config.embedding_dim
        self.qkv_same_dim = self.kdim == config.embedding_dim and self.vdim == config.embedding_dim

        self.num_heads = config.num_attention_heads
        self.attention_dropout_module = torch.nn.Dropout(p=config.attention_dropout, inplace=False)

        self.head_dim = config.embedding_dim // config.num_attention_heads
        if not (self.head_dim * config.num_attention_heads == self.embedding_dim):
            raise AssertionError("The embedding_dim must be divisible by num_heads.")
        self.scaling = self.head_dim**-0.5

        self.self_attention = True  # config.self_attention
        if not (self.self_attention):
            raise NotImplementedError("The Graphormer model only supports self attention for now.")
        if self.self_attention and not self.qkv_same_dim:
            raise AssertionError("Self-attention requires query, key and value to be of the same size.")

        self.k_proj = quant_noise(
            nn.Linear(self.kdim, config.embedding_dim, bias=config.bias),
            config.q_noise,
            config.qn_block_size,
        )
        self.v_proj = quant_noise(
            nn.Linear(self.vdim, config.embedding_dim, bias=config.bias),
            config.q_noise,
            config.qn_block_size,
        )
        self.q_proj = quant_noise(
            nn.Linear(config.embedding_dim, config.embedding_dim, bias=config.bias),
            config.q_noise,
            config.qn_block_size,
        )

        self.out_proj = quant_noise(
            nn.Linear(config.embedding_dim, config.embedding_dim, bias=config.bias),
            config.q_noise,
            config.qn_block_size,
        )

        self.onnx_trace = False

    def reset_parameters(self):
        """Re-initialise the four projection weights with Xavier-uniform and zero the output bias."""
        if self.qkv_same_dim:
            # Empirically observed the convergence to be much better with
            # the scaled initialization
            nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        else:
            nn.init.xavier_uniform_(self.k_proj.weight)
            nn.init.xavier_uniform_(self.v_proj.weight)
            nn.init.xavier_uniform_(self.q_proj.weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)

    def forward(
        self,
        query: torch.Tensor,
        key: Optional[torch.Tensor],
        value: Optional[torch.Tensor],
        attn_bias: Optional[torch.Tensor],
        key_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[torch.Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            query (torch.Tensor): input of shape `(tgt_len, batch, embedding_dim)`. Queries, keys and
                values are all projected from this tensor (self-attention).
            key (torch.Tensor, optional): only used to validate shapes; when given, its first two
                dimensions are taken as `(src_len, batch)`.
            value (torch.Tensor, optional): only used to validate shapes; required when `key` is given.
            attn_bias (torch.Tensor, optional): additive bias, reshaped to
                `(batch * num_heads, tgt_len, src_len)` and added to the raw attention scores.
            key_padding_mask (Bytetorch.Tensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: True).
            attn_mask (Bytetorch.Tensor, optional): typically used to
                implement causal attention, where the mask prevents the attention from looking forward in time
                (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default: return the average attention weights over all
                heads.

        Returns:
            `(attn, attn_weights)`; `attn_weights` is `None` unless weights were requested. With
            `before_softmax=True` the pair `(raw attention scores, projected values)` is returned instead.

        Raises:
            AssertionError: if any of the shape checks below fails.
        """
        if need_head_weights:
            need_weights = True

        tgt_len, bsz, embedding_dim = query.size()
        src_len = tgt_len
        if not (embedding_dim == self.embedding_dim):
            raise AssertionError(
                f"The query embedding dimension {embedding_dim} is not equal to the expected embedding_dim"
                f" {self.embedding_dim}."
            )

        if key is not None:
            src_len, key_bsz, _ = key.size()
            if not torch.jit.is_scripting():
                # Compare the leading dims as tuples. The previous expression
                # `not (src_len, bsz == value.shape[:2])` tested a non-empty tuple and never fired.
                if (key_bsz != bsz) or (value is None) or ((src_len, bsz) != tuple(value.shape[:2])):
                    raise AssertionError(
                        "The batch shape does not match the key or value shapes provided to the attention."
                    )

        # Self-attention: all three projections read from `query`.
        q = self.q_proj(query)
        k = self.k_proj(query)
        v = self.v_proj(query)

        q *= self.scaling

        # (len, bsz, dim) -> (bsz * num_heads, len, head_dim)
        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        if k.size(1) != src_len:
            raise AssertionError("The shape of the key generated in the attention is incorrect")

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            if key_padding_mask.size(0) != bsz or key_padding_mask.size(1) != src_len:
                raise AssertionError(
                    "The shape of the generated padding mask for the key does not match expected dimensions."
                )
        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        if list(attn_weights.size()) != [bsz * self.num_heads, tgt_len, src_len]:
            raise AssertionError("The attention weights generated do not match the expected dimensions.")

        if attn_bias is not None:
            attn_weights += attn_bias.view(bsz * self.num_heads, tgt_len, src_len)

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v

        attn_weights_float = torch.nn.functional.softmax(attn_weights, dim=-1)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = self.attention_dropout_module(attn_weights)

        attn = torch.bmm(attn_probs, v)
        if list(attn.size()) != [bsz * self.num_heads, tgt_len, self.head_dim]:
            raise AssertionError("The attention generated do not match the expected dimensions.")

        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embedding_dim)
        attn: torch.Tensor = self.out_proj(attn)

        attn_weights = None
        if need_weights:
            attn_weights = attn_weights_float.contiguous().view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights

    def apply_sparse_mask(self, attn_weights: torch.Tensor, tgt_len: int, src_len: int, bsz: int) -> torch.Tensor:
        """Hook for sparse masking; this implementation returns `attn_weights` unchanged."""
        return attn_weights
481
+
482
+
483
class GraphormerGraphEncoderLayer(nn.Module):
    """A single encoder block: self-attention then a two-layer feed-forward network, each with a residual."""

    def __init__(self, config: GraphormerConfig) -> None:
        super().__init__()

        # Plain hyper-parameters copied from the config
        embed_dim = config.embedding_dim
        self.embedding_dim = embed_dim
        self.num_attention_heads = config.num_attention_heads
        self.q_noise = config.q_noise
        self.qn_block_size = config.qn_block_size
        self.pre_layernorm = config.pre_layernorm

        self.dropout_module = torch.nn.Dropout(p=config.dropout, inplace=False)

        self.activation_dropout_module = torch.nn.Dropout(p=config.activation_dropout, inplace=False)

        # Building blocks
        self.activation_fn = ACT2FN[config.activation_fn]
        self.self_attn = GraphormerMultiheadAttention(config)

        # layer norm paired with the self-attention sub-layer
        self.self_attn_layer_norm = nn.LayerNorm(embed_dim)

        noise_kwargs = {"q_noise": config.q_noise, "qn_block_size": config.qn_block_size}
        self.fc1 = self.build_fc(embed_dim, config.ffn_embedding_dim, **noise_kwargs)
        self.fc2 = self.build_fc(config.ffn_embedding_dim, embed_dim, **noise_kwargs)

        # layer norm paired with the position-wise feed-forward sub-layer
        self.final_layer_norm = nn.LayerNorm(embed_dim)

    def build_fc(
        self, input_dim: int, output_dim: int, q_noise: float, qn_block_size: int
    ) -> Union[nn.Module, nn.Linear, nn.Embedding, nn.Conv2d]:
        """Create an `nn.Linear(input_dim, output_dim)` and pass it through `quant_noise`."""
        linear = nn.Linear(input_dim, output_dim)
        return quant_noise(linear, q_noise, qn_block_size)

    def forward(
        self,
        input_nodes: torch.Tensor,
        self_attn_bias: Optional[torch.Tensor] = None,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Run the attention and feed-forward sub-layers. Depending on `self.pre_layernorm`, each LayerNorm is
        applied before its sub-layer (pre-norm) or after the residual addition (post-norm).
        """
        pre_norm = self.pre_layernorm

        # --- self-attention sub-layer ---
        shortcut = input_nodes
        hidden = self.self_attn_layer_norm(input_nodes) if pre_norm else input_nodes
        hidden, attn = self.self_attn(
            query=hidden,
            key=hidden,
            value=hidden,
            attn_bias=self_attn_bias,
            key_padding_mask=self_attn_padding_mask,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        hidden = shortcut + self.dropout_module(hidden)
        if not pre_norm:
            hidden = self.self_attn_layer_norm(hidden)

        # --- feed-forward sub-layer ---
        shortcut = hidden
        if pre_norm:
            hidden = self.final_layer_norm(hidden)
        hidden = self.activation_dropout_module(self.activation_fn(self.fc1(hidden)))
        hidden = shortcut + self.dropout_module(self.fc2(hidden))
        if not pre_norm:
            hidden = self.final_layer_norm(hidden)

        return hidden, attn
567
+
568
+
569
class GraphormerGraphEncoder(nn.Module):
    """Embed a batch of graphs, build the attention bias and run the stack of encoder layers."""

    def __init__(self, config: GraphormerConfig):
        super().__init__()

        self.dropout_module = torch.nn.Dropout(p=config.dropout, inplace=False)
        self.layerdrop = config.layerdrop
        self.embedding_dim = config.embedding_dim
        self.apply_graphormer_init = config.apply_graphormer_init
        self.traceable = config.traceable

        self.graph_node_feature = GraphormerGraphNodeFeature(config)
        self.graph_attn_bias = GraphormerGraphAttnBias(config)

        self.embed_scale = config.embed_scale

        if config.q_noise > 0:
            self.quant_noise = quant_noise(
                nn.Linear(self.embedding_dim, self.embedding_dim, bias=False),
                config.q_noise,
                config.qn_block_size,
            )
        else:
            self.quant_noise = None

        if config.encoder_normalize_before:
            self.emb_layer_norm = nn.LayerNorm(self.embedding_dim)
        else:
            self.emb_layer_norm = None

        # Only created in the pre-layernorm configuration; `forward` below does not read it.
        if config.pre_layernorm:
            self.final_layer_norm = nn.LayerNorm(self.embedding_dim)

        if self.layerdrop > 0.0:
            self.layers = LayerDropModuleList(p=self.layerdrop)
        else:
            self.layers = nn.ModuleList([])
        self.layers.extend([GraphormerGraphEncoderLayer(config) for _ in range(config.num_hidden_layers)])

        # Apply initialization of model params after building the model
        if config.freeze_embeddings:
            raise NotImplementedError("Freezing embeddings is not implemented yet.")

        # Freeze the first `num_trans_layers_to_freeze` encoder layers.
        for layer in range(config.num_trans_layers_to_freeze):
            m = self.layers[layer]
            if m is not None:
                for p in m.parameters():
                    p.requires_grad = False

    def forward(
        self,
        input_nodes: torch.LongTensor,
        input_edges: torch.LongTensor,
        attn_bias: torch.Tensor,
        in_degree: torch.LongTensor,
        out_degree: torch.LongTensor,
        spatial_pos: torch.LongTensor,
        attn_edge_type: torch.LongTensor,
        perturb=None,
        last_state_only: bool = False,
        token_embeddings: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[Union[torch.Tensor, List[torch.LongTensor]], torch.Tensor]:
        """
        Encode a batch of graphs.

        Args:
            perturb: optional tensor added to every position except the leading graph token.
            last_state_only: when True only the final hidden state is kept in the returned states.
            token_embeddings: optional precomputed embeddings used instead of `self.graph_node_feature`.
                The tensor passed here is never modified.
            attn_mask: forwarded to every layer as `self_attn_mask`.

        Returns:
            `(inner_states, graph_rep)`. `inner_states` is a list of hidden states (sequence dimension
            first), stacked into a single tensor when `self.traceable` is set; `graph_rep` is position 0
            of the final hidden state.
        """
        # compute padding mask. This is needed for multi-head attention
        data_x = input_nodes
        n_graph, n_node = data_x.size()[:2]
        padding_mask = (data_x[:, :, 0]).eq(0)
        # The prepended graph token is never treated as padding.
        padding_mask_cls = torch.zeros(n_graph, 1, device=padding_mask.device, dtype=padding_mask.dtype)
        padding_mask = torch.cat((padding_mask_cls, padding_mask), dim=1)

        attn_bias = self.graph_attn_bias(input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type)

        if token_embeddings is not None:
            input_nodes = token_embeddings
        else:
            input_nodes = self.graph_node_feature(input_nodes, in_degree, out_degree)

        if perturb is not None:
            if token_embeddings is not None:
                # Work on a copy so the in-place add below cannot alter the caller's tensor.
                input_nodes = input_nodes.clone()
            input_nodes[:, 1:, :] += perturb

        if self.embed_scale is not None:
            input_nodes = input_nodes * self.embed_scale

        if self.quant_noise is not None:
            input_nodes = self.quant_noise(input_nodes)

        if self.emb_layer_norm is not None:
            input_nodes = self.emb_layer_norm(input_nodes)

        input_nodes = self.dropout_module(input_nodes)

        # Swap the first two dimensions: the layers run with the sequence dimension first.
        input_nodes = input_nodes.transpose(0, 1)

        inner_states = []
        if not last_state_only:
            inner_states.append(input_nodes)

        for layer in self.layers:
            input_nodes, _ = layer(
                input_nodes,
                self_attn_padding_mask=padding_mask,
                self_attn_mask=attn_mask,
                self_attn_bias=attn_bias,
            )
            if not last_state_only:
                inner_states.append(input_nodes)

        graph_rep = input_nodes[0, :, :]

        if last_state_only:
            inner_states = [input_nodes]

        if self.traceable:
            return torch.stack(inner_states), graph_rep
        else:
            return inner_states, graph_rep
684
+
685
+
686
class GraphormerDecoderHead(nn.Module):
    """Bias-free linear read-out followed by a single learned scalar bias.

    `num_classes` should be 1 for regression, or the number of classes for classification.
    """

    def __init__(self, embedding_dim: int, num_classes: int):
        super().__init__()
        # One scalar offset shared by every output (it broadcasts in `forward`), initialised to 0.
        self.lm_output_learned_bias = nn.Parameter(torch.zeros(1))
        self.classifier = nn.Linear(embedding_dim, num_classes, bias=False)
        self.num_classes = num_classes

    def forward(self, input_nodes: torch.Tensor, **unused) -> torch.Tensor:
        """Project `input_nodes` with the classifier and add the learned bias; extra kwargs are ignored."""
        input_nodes = self.classifier(input_nodes)
        input_nodes = input_nodes + self.lm_output_learned_bias
        return input_nodes
698
+
699
+
700
class GraphormerPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = GraphormerConfig
    base_model_prefix = "graphormer"
    main_input_name_nodes = "input_nodes"
    main_input_name_edges = "input_edges"

    def normal_(self, data: torch.Tensor):
        """Fill `data` in place with samples from N(0, 0.02), drawing them on the CPU."""
        # with FSDP, module params will be on CUDA, so we cast them back to CPU
        # so that the RNG is consistent with and without FSDP
        data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))

    def init_graphormer_params(self, module: Union[nn.Linear, nn.Embedding, GraphormerMultiheadAttention]):
        """
        Initialize the weights specific to the Graphormer Model.
        """
        if isinstance(module, nn.Linear):
            self.normal_(module.weight.data)
            if module.bias is not None:
                module.bias.data.zero_()
        if isinstance(module, nn.Embedding):
            self.normal_(module.weight.data)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        if isinstance(module, GraphormerMultiheadAttention):
            self.normal_(module.q_proj.weight.data)
            self.normal_(module.k_proj.weight.data)
            self.normal_(module.v_proj.weight.data)

    def _init_weights(
        self,
        module: Union[
            nn.Linear, nn.Conv2d, nn.Embedding, nn.LayerNorm, GraphormerMultiheadAttention, GraphormerGraphEncoder
        ],
    ):
        """
        Initialize the weights
        """
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # We might be missing part of the Linear init, dependant on the layer num
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                # Keep the padding row at zero.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, GraphormerMultiheadAttention):
            module.q_proj.weight.data.normal_(mean=0.0, std=0.02)
            module.k_proj.weight.data.normal_(mean=0.0, std=0.02)
            module.v_proj.weight.data.normal_(mean=0.0, std=0.02)
            # NOTE: reset_parameters() re-initialises the q/k/v weights right after the normal init above.
            module.reset_parameters()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, GraphormerGraphEncoder):
            if module.apply_graphormer_init:
                module.apply(self.init_graphormer_params)
        # A second `elif isinstance(module, nn.LayerNorm)` branch used to follow here; it could never
        # run because the LayerNorm case is already handled above, so it was removed.
766
+
767
+
768
class GraphormerModel(GraphormerPreTrainedModel):
    """The Graphormer model is a graph-encoder model.

    It goes from a graph to its representation. If you want to use the model for a downstream classification task, use
    GraphormerForGraphClassification instead. For any other downstream task, feel free to add a new class, or combine
    this model with a downstream model of your choice, following the example in GraphormerForGraphClassification.
    """

    def __init__(self, config: GraphormerConfig):
        super().__init__(config)
        # NOTE(review): this instance attribute has the same name as the `max_nodes` method defined
        # below and therefore hides it on instances; `model.max_nodes` is the configured value.
        self.max_nodes = config.max_nodes

        self.graph_encoder = GraphormerGraphEncoder(config)

        self.share_input_output_embed = config.share_input_output_embed
        self.lm_output_learned_bias = None

        # Remove head is set to true during fine-tuning
        self.load_softmax = not getattr(config, "remove_head", False)

        self.lm_head_transform_weight = nn.Linear(config.embedding_dim, config.embedding_dim)
        self.activation_fn = ACT2FN[config.activation_fn]
        self.layer_norm = nn.LayerNorm(config.embedding_dim)

        self.post_init()

    def reset_output_layer_parameters(self):
        """Replace `lm_output_learned_bias` with a fresh single-element zero parameter."""
        self.lm_output_learned_bias = nn.Parameter(torch.zeros(1))

    def forward(
        self,
        input_nodes: torch.LongTensor,
        input_edges: torch.LongTensor,
        attn_bias: torch.Tensor,
        in_degree: torch.LongTensor,
        out_degree: torch.LongTensor,
        spatial_pos: torch.LongTensor,
        attn_edge_type: torch.LongTensor,
        perturb: Optional[torch.FloatTensor] = None,
        masked_tokens: None = None,
        return_dict: Optional[bool] = None,
        **unused,
    ) -> Union[Tuple[torch.LongTensor], BaseModelOutputWithNoAttention]:
        """
        Encode the graphs and transform the last hidden state (linear -> activation -> LayerNorm).

        Returns a `(last_hidden_state, hidden_states)` tuple when `return_dict` resolves to False,
        otherwise a `BaseModelOutputWithNoAttention`.

        Raises:
            NotImplementedError: if `masked_tokens` is not None.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        inner_states, graph_rep = self.graph_encoder(
            input_nodes, input_edges, attn_bias, in_degree, out_degree, spatial_pos, attn_edge_type, perturb=perturb
        )

        # last inner state, then revert Batch and Graph len
        input_nodes = inner_states[-1].transpose(0, 1)

        # project masked tokens only
        if masked_tokens is not None:
            raise NotImplementedError

        input_nodes = self.layer_norm(self.activation_fn(self.lm_head_transform_weight(input_nodes)))

        # project back to size of vocabulary
        # The encoder may not expose `embed_tokens`; look it up defensively so that enabling
        # `share_input_output_embed` does not raise AttributeError in that case.
        embed_tokens = getattr(self.graph_encoder, "embed_tokens", None)
        if self.share_input_output_embed and hasattr(embed_tokens, "weight"):
            input_nodes = torch.nn.functional.linear(input_nodes, embed_tokens.weight)

        if not return_dict:
            return tuple(x for x in [input_nodes, inner_states] if x is not None)
        return BaseModelOutputWithNoAttention(last_hidden_state=input_nodes, hidden_states=inner_states)

    def max_nodes(self):
        """Maximum output length supported by the encoder."""
        return self.max_nodes
837
+
838
+
839
class GraphormerForGraphClassification(GraphormerPreTrainedModel):
    """
    This model can be used for graph-level classification or regression tasks.

    It can be trained on
    - regression (by setting config.num_classes to 1); there should be one float-type label per graph
    - one task classification (by setting config.num_classes to the number of classes); there should be one integer
    label per graph
    - binary multi-task classification (by setting config.num_classes to the number of labels); there should be a list
    of integer labels for each graph.
    """

    def __init__(self, config: GraphormerConfig):
        super().__init__(config)
        self.encoder = GraphormerModel(config)
        self.embedding_dim = config.embedding_dim
        self.num_classes = config.num_classes
        self.classifier = GraphormerDecoderHead(self.embedding_dim, self.num_classes)
        self.is_encoder_decoder = True

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_nodes: torch.LongTensor,
        input_edges: torch.LongTensor,
        attn_bias: torch.Tensor,
        in_degree: torch.LongTensor,
        out_degree: torch.LongTensor,
        spatial_pos: torch.LongTensor,
        attn_edge_type: torch.LongTensor,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        **unused,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        """
        Encode the graphs, classify from position 0 of the last hidden state and, when `labels` is
        given, compute the loss that matches `num_classes` and the rank of `labels`.

        Entries of `labels` that are NaN are excluded from the loss.

        Returns the non-None items of `(loss, logits, hidden_states)` as a tuple when `return_dict`
        resolves to False, otherwise a `SequenceClassifierOutput`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_outputs = self.encoder(
            input_nodes,
            input_edges,
            attn_bias,
            in_degree,
            out_degree,
            spatial_pos,
            attn_edge_type,
            return_dict=True,
        )
        outputs, hidden_states = encoder_outputs["last_hidden_state"], encoder_outputs["hidden_states"]

        head_outputs = self.classifier(outputs)
        # Position 0 along dim 1 is used as the graph-level prediction.
        logits = head_outputs[:, 0, :].contiguous()

        loss = None
        if labels is not None:
            mask = ~torch.isnan(labels)

            if self.num_classes == 1:  # regression
                loss_fct = MSELoss()
                loss = loss_fct(logits[mask].squeeze(), labels[mask].squeeze().float())
            elif self.num_classes > 1 and len(labels.shape) == 1:  # One task classification
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits[mask].view(-1, self.num_classes), labels[mask].view(-1))
            else:  # Binary multi-task classification
                loss_fct = BCEWithLogitsLoss(reduction="sum")
                # Cast the targets to the logits dtype so the integer labels described in the class
                # docstring are accepted; floating-point labels of that dtype pass through unchanged.
                loss = loss_fct(logits[mask], labels[mask].to(logits.dtype))

        if not return_dict:
            return tuple(x for x in [loss, logits, hidden_states] if x is not None)
        return SequenceClassifierOutput(loss=loss, logits=logits, hidden_states=hidden_states, attentions=None)
909
+
910
+
911
+ __all__ = ["GraphormerForGraphClassification", "GraphormerModel", "GraphormerPreTrainedModel"]
docs/transformers/build/lib/transformers/models/deprecated/jukebox/__init__.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ....utils import _LazyModule
17
+ from ....utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_jukebox import *
22
+ from .modeling_jukebox import *
23
+ from .tokenization_jukebox import *
24
+ else:
25
+ import sys
26
+
27
+ _file = globals()["__file__"]
28
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/build/lib/transformers/models/deprecated/jukebox/configuration_jukebox.py ADDED
@@ -0,0 +1,613 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The OpenAI Team Authors and HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Jukebox configuration"""
16
+
17
+ import os
18
+ from typing import List, Union
19
+
20
+ from ....configuration_utils import PretrainedConfig
21
+ from ....utils import logging
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
# Basic attention schedules. Each list is indexed with `layer % len(list)` by the pattern functions below.
_RawColumnPreviousRowAttention = ["block_attn", "transpose_block_attn", "prev_block_attn"]
# NOTE(review): this entry is spelled "dense_attention" while `_PrimePrimeDenseAttention` uses "dense_attn";
# confirm which spelling the attention layer actually accepts.
_FullDenseAttention = ["dense_attention"]
_PrimePrimeDenseAttention = ["prime_attn", "prime_attn", "dense_attn"]

# 79-entry schedule: six block/transpose/prev triples followed by one "cross_attention" entry, then six
# repetitions of (three triples + one "cross_attention" entry). "cross_attention" therefore sits at
# indices 18, 28, 38, 48, 58, 68 and 78.
_LARGE_ATTENTION = [
    *(_RawColumnPreviousRowAttention * 6),
    "cross_attention",
    *((_RawColumnPreviousRowAttention * 3 + ["cross_attention"]) * 6),
]
111
+
112
+
113
def full_dense_attention(layer):
    """Return the single entry of `_FullDenseAttention`; the result does not depend on `layer`."""
    dense_kind = _FullDenseAttention[0]
    return dense_kind
115
+
116
+
117
def raw_column_previous_row_attention(layer):
    """Return the attention kind for `layer`, cycling through `_RawColumnPreviousRowAttention` every 3 layers."""
    position_in_cycle = layer % 3
    return _RawColumnPreviousRowAttention[position_in_cycle]
119
+
120
+
121
def large_separated_enc_dec_w_lyrics(layer):
    """
    Return the attention kind for `layer` from the `_LARGE_ATTENTION` schedule.

    The schedule repeats once `layer` exceeds its length. The period is taken from the list itself (79 entries)
    instead of a hard-coded number, so the two cannot drift apart.
    """
    return _LARGE_ATTENTION[layer % len(_LARGE_ATTENTION)]
123
+
124
+
125
def enc_dec_with_lyrics(layer):
    """
    Return the attention kind for `layer` of the encoder-decoder schedule.

    Every 16th layer (`layer % 16 == 15`) is read from `_PrimePrimeDenseAttention`; all other layers are read
    from `_RawColumnPreviousRowAttention`. In both cases the entry is selected with `layer % 3`.
    """
    uses_prime_schedule = layer % 16 == 15
    schedule = _PrimePrimeDenseAttention if uses_prime_schedule else _RawColumnPreviousRowAttention
    return schedule[layer % 3]
129
+
130
+
131
# Maps an attention pattern name to the function returning the attention kind of a given layer index.
ATTENTION_PATTERNS = dict(
    full_dense_attention=full_dense_attention,
    # Alternate row, column and previous row attn
    raw_column_previous_row_attention=raw_column_previous_row_attention,
    # Used by large separated_enc_dec model with lyrics
    large_separated_enc_dec_w_lyrics=large_separated_enc_dec_w_lyrics,
    # Used by encoder_decoder model with lyrics
    enc_dec_with_lyrics=enc_dec_with_lyrics,
)
137
+
138
+
139
class JukeboxPriorConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`JukeboxPrior`]. It is used to instantiate a
    `JukeboxPrior` according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the top level prior from the
    [openai/jukebox-1b-lyrics](https://huggingface.co/openai/jukebox-1b-lyrics) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        act_fn (`str`, *optional*, defaults to `"quick_gelu"`):
            Activation function.
        level (`int`, *optional*, defaults to 0):
            Level index of this prior; stored on the configuration as `level`.
        alignment_head (`int`, *optional*, defaults to 2):
            Head that is responsible of the alignment between lyrics and music. Only used to compute the lyric to
            audio alignment.
        alignment_layer (`int`, *optional*, defaults to 68):
            Index of the layer that is responsible of the alignment between lyrics and music. Only used to compute
            the lyric to audio alignment.
        attention_multiplier (`float`, *optional*, defaults to 0.25):
            Multiplier coefficient used to define the hidden dimension of the attention layers. 0.25 means that
            0.25*width of the model will be used.
        attention_pattern (`str`, *optional*, defaults to `"enc_dec_with_lyrics"`):
            Which attention pattern to use for the decoder.
        attn_dropout (`int`, *optional*, defaults to 0):
            Dropout probability for the post-attention layer dropout in the decoder.
        attn_res_scale (`bool`, *optional*, defaults to `False`):
            Whether or not to scale the residuals in the attention conditioner block.
        blocks (`int`, *optional*, defaults to 64):
            Number of blocks used in the `block_attn`. A sequence of length seq_len is factored as `[blocks, seq_len //
            blocks]` in the `JukeboxAttention` layer.
        conv_res_scale (`int`, *optional*):
            Whether or not to scale the residuals in the conditioner block. Since the top level prior does not have a
            conditioner, the default value is None and should not be modified.
        num_layers (`int`, *optional*, defaults to 72):
            Number of layers of the transformer architecture.
        emb_dropout (`int`, *optional*, defaults to 0):
            Embedding dropout used in the lyric decoder.
        encoder_config (`JukeboxPriorConfig`, *optional*):
            Configuration of the encoder which models the prior on the lyrics. When given, it is unpacked as keyword
            arguments into a new `JukeboxPriorConfig`.
        encoder_loss_fraction (`float`, *optional*, defaults to 0.4):
            Multiplication factor used in front of the lyric encoder loss.
        hidden_size (`int`, *optional*, defaults to 2048):
            Hidden dimension of the attention layers.
        init_scale (`float`, *optional*, defaults to 0.2):
            Initialization scales for the prior modules.
        is_encoder_decoder (`bool`, *optional*, defaults to `True`):
            Whether or not the prior is an encoder-decoder model. In case it is not, and `nb_relevant_lyric_tokens` is
            greater than 0, the `encoder` args should be specified for the lyric encoding.
        lyric_vocab_size (`int`, *optional*, defaults to 80):
            Size of the lyric vocabulary (presumably the number of distinct lyric tokens; confirm against the
            tokenizer).
        mask (`bool`, *optional*, defaults to `False`):
            Whether or not to mask the previous positions in the attention.
        max_duration (`int`, *optional*, defaults to 600):
            Maximum supported duration of the generated song in seconds.
        max_nb_genres (`int`, *optional*, defaults to 1):
            Maximum number of genres that can be used to condition the model.
        merged_decoder (`bool`, *optional*, defaults to `True`):
            Whether or not the decoder and the encoder inputs are merged. This is used for the separated
            encoder-decoder architecture.
        metadata_conditioning (`bool`, *optional*, defaults to `True`):
            Whether or not to condition on the artist and genre metadata.
        metadata_dims (`List[int]`, *optional*, defaults to `[604, 7898]`):
            Number of genres and the number of artists that were used to train the embedding layers of the prior
            models.
        min_duration (`int`, *optional*, defaults to 0):
            Minimum duration of the generated audio on which the model was trained.
        mlp_multiplier (`float`, *optional*, defaults to 1.0):
            Multiplier coefficient used to define the hidden dimension of the MLP layers. 0.25 means that 0.25*width of
            the model will be used.
        music_vocab_size (`int`, *optional*, defaults to 2048):
            Number of different music tokens. Should be similar to the `JukeboxVQVAEConfig.nb_discrete_codes`.
        n_ctx (`int`, *optional*, defaults to 6144):
            Number of context tokens for each prior. The context tokens are the music tokens that are attended to when
            generating music tokens.
        n_heads (`int`, *optional*, defaults to 2):
            Number of attention heads.
        nb_relevant_lyric_tokens (`int`, *optional*, defaults to 384):
            Number of lyric tokens that are used when sampling a single window of length `n_ctx`.
        res_conv_depth (`int`, *optional*, defaults to 3):
            Depth of the `JukeboxDecoderConvBock` used to upsample the previously sampled audio in the
            `JukeboxMusicTokenConditioner`.
        res_conv_width (`int`, *optional*, defaults to 128):
            Width of the `JukeboxDecoderConvBock` used to upsample the previously sampled audio in the
            `JukeboxMusicTokenConditioner`.
        res_convolution_multiplier (`int`, *optional*, defaults to 1):
            Multiplier used to scale the `hidden_dim` of the `JukeboxResConv1DBlock`.
        res_dilation_cycle (`int`, *optional*):
            Dilation cycle used to define the `JukeboxMusicTokenConditioner`. Usually similar to the ones used in the
            corresponding level of the VQVAE. The first prior does not use it as it is not conditioned on upper level
            tokens.
        res_dilation_growth_rate (`int`, *optional*, defaults to 1):
            Dilation growth rate used between each convolutional block of the `JukeboxMusicTokenConditioner`.
        res_downs_t (`List[int]`, *optional*, defaults to `[3, 2, 2]`):
            Downsampling rates used in the audio conditioning network.
        res_strides_t (`List[int]`, *optional*, defaults to `[2, 2, 2]`):
            Striding used in the audio conditioning network.
        resid_dropout (`int`, *optional*, defaults to 0):
            Residual dropout used in the attention pattern.
        sampling_rate (`int`, *optional*, defaults to 44100):
            Sampling rate used for training.
        spread (`int`, *optional*):
            Spread used in the `summary_spread_attention` pattern.
        timing_dims (`int`, *optional*, defaults to 64):
            Dimension of the timing embedding.
        zero_out (`bool`, *optional*, defaults to `False`):
            Whether or not to zero out convolution weights when initializing.
    """

    model_type = "jukebox_prior"
    # NOTE(review): the targets below ("n_positions", "n_head") are not attributes set by `__init__`, which stores
    # `n_ctx` and `n_heads`; confirm whether these aliases are intended.
    attribute_map = {
        "max_position_embeddings": "n_positions",
        "num_attention_heads": "n_head",
    }

    def __init__(
        self,
        act_fn="quick_gelu",
        level=0,
        alignment_head=2,
        alignment_layer=68,
        attention_multiplier=0.25,
        attention_pattern="enc_dec_with_lyrics",
        attn_dropout=0,
        attn_res_scale=False,
        blocks=64,
        conv_res_scale=None,
        num_layers=72,
        emb_dropout=0,
        encoder_config=None,
        encoder_loss_fraction=0.4,
        hidden_size=2048,
        init_scale=0.2,
        is_encoder_decoder=True,
        lyric_vocab_size=80,
        mask=False,
        max_duration=600,
        max_nb_genres=1,
        merged_decoder=True,
        metadata_conditioning=True,
        metadata_dims=[604, 7898],
        min_duration=0,
        mlp_multiplier=1.0,
        music_vocab_size=2048,
        n_ctx=6144,
        n_heads=2,
        nb_relevant_lyric_tokens=384,
        res_conv_depth=3,
        res_conv_width=128,
        res_convolution_multiplier=1,
        res_dilation_cycle=None,
        res_dilation_growth_rate=1,
        res_downs_t=[3, 2, 2],
        res_strides_t=[2, 2, 2],
        resid_dropout=0,
        sampling_rate=44100,
        spread=None,
        timing_dims=64,
        zero_out=False,
        **kwargs,
    ):
        self.act_fn = act_fn
        self.alignment_head = alignment_head
        self.alignment_layer = alignment_layer
        self.attention_multiplier = attention_multiplier
        self.attention_pattern = attention_pattern
        self.attn_dropout = attn_dropout
        self.attn_res_scale = attn_res_scale
        self.blocks = blocks
        self.conv_res_scale = conv_res_scale
        self.num_layers = num_layers
        self.emb_dropout = emb_dropout
        self.music_vocab_size = music_vocab_size
        if encoder_config is not None:
            self.encoder_config = JukeboxPriorConfig(**encoder_config)
        else:
            self.encoder_config = None
        self.encoder_loss_fraction = encoder_loss_fraction
        self.init_scale = init_scale
        self.is_encoder_decoder = is_encoder_decoder
        self.lyric_vocab_size = lyric_vocab_size
        self.level = level
        self.mask = mask
        self.max_duration = max_duration
        self.max_nb_genres = max_nb_genres
        self.merged_decoder = merged_decoder
        self.metadata_conditioning = metadata_conditioning
        # The list defaults in the signature are created once and shared by every call. Store copies so that an
        # in-place edit of one configuration cannot leak into the defaults or into other configurations.
        self.metadata_dims = list(metadata_dims) if metadata_dims is not None else None
        self.min_duration = min_duration
        self.mlp_multiplier = mlp_multiplier
        self.n_ctx = n_ctx
        self.n_heads = n_heads
        self.nb_relevant_lyric_tokens = nb_relevant_lyric_tokens
        self.res_conv_depth = res_conv_depth
        self.res_conv_width = res_conv_width
        self.res_convolution_multiplier = res_convolution_multiplier
        self.res_dilation_cycle = res_dilation_cycle
        self.res_dilation_growth_rate = res_dilation_growth_rate
        self.res_downs_t = list(res_downs_t) if res_downs_t is not None else None
        self.res_strides_t = list(res_strides_t) if res_strides_t is not None else None
        self.resid_dropout = resid_dropout
        self.sampling_rate = sampling_rate
        self.spread = spread
        self.timing_dims = timing_dims
        self.hidden_size = hidden_size
        self.zero_out = zero_out

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], level=0, **kwargs
    ) -> "PretrainedConfig":
        """
        Load a prior configuration, optionally extracting it from a full `JukeboxConfig` dictionary.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                Forwarded to `get_config_dict` to obtain the configuration dictionary.
            level (`int`, *optional*, defaults to 0):
                Which prior to extract when the loaded dictionary has `model_type == "jukebox"`.

        Returns:
            The configuration built by `cls.from_dict` from the selected dictionary.
        """
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # get the prior config dict if we are loading from JukeboxConfig
        if config_dict.get("model_type") == "jukebox":
            prior_key = f"prior_{level}"
            if prior_key not in config_dict and "prior_config_list" in config_dict:
                # `JukeboxConfig.to_dict` serializes the priors as a list (in level order) under
                # "prior_config_list" rather than under separate "prior_<level>" keys.
                config_dict = config_dict["prior_config_list"][level]
            else:
                config_dict = config_dict[prior_key]

        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)
366
+
367
+
368
class JukeboxVQVAEConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`JukeboxVQVAE`]. It is used to instantiate a
    `JukeboxVQVAE` according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the VQVAE from
    [openai/jukebox-1b-lyrics](https://huggingface.co/openai/jukebox-1b-lyrics) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        act_fn (`str`, *optional*, defaults to `"relu"`):
            Activation function of the model.
        nb_discrete_codes (`int`, *optional*, defaults to 2048):
            Number of codes of the VQVAE.
        commit (`float`, *optional*, defaults to 0.02):
            Commit loss multiplier.
        conv_input_shape (`int`, *optional*, defaults to 1):
            Number of audio channels.
        conv_res_scale (`bool`, *optional*, defaults to `False`):
            Whether or not to scale the residuals of the `JukeboxResConv1DBlock`.
        embed_dim (`int`, *optional*, defaults to 64):
            Embedding dimension of the codebook vectors.
        hop_fraction (`List[float]`, *optional*, defaults to `[0.125, 0.5, 0.5]`):
            Fraction of non-intersecting window used when continuing the sampling process.
        levels (`int`, *optional*, defaults to 3):
            Number of hierarchical levels used in the VQVAE.
        lmu (`float`, *optional*, defaults to 0.99):
            Used in the codebook update, exponential moving average coefficient. For more detail refer to Appendix A.1
            of the original [VQVAE paper](https://arxiv.org/pdf/1711.00937v2.pdf)
        multipliers (`List[int]`, *optional*, defaults to `[2, 1, 1]`):
            Depth and width multipliers used for each level. Used on the `res_conv_width` and `res_conv_depth`.
        res_conv_depth (`int`, *optional*, defaults to 4):
            Depth of the encoder and decoder block. If no `multipliers` are used, this is the same for each level.
        res_conv_width (`int`, *optional*, defaults to 32):
            Width of the encoder and decoder block. If no `multipliers` are used, this is the same for each level.
        res_convolution_multiplier (`int`, *optional*, defaults to 1):
            Scaling factor of the hidden dimension used in the `JukeboxResConv1DBlock`.
        res_dilation_cycle (`int`, *optional*):
            Dilation cycle value used in the `JukeboxResnet`. If an int is used, each new Conv1 block will have a depth
            reduced by a power of `res_dilation_cycle`.
        res_dilation_growth_rate (`int`, *optional*, defaults to 3):
            Resnet dilation growth rate used in the VQVAE (dilation_growth_rate ** depth)
        res_downs_t (`List[int]`, *optional*, defaults to `[3, 2, 2]`):
            Downsampling rate for each level of the hierarchical VQ-VAE.
        res_strides_t (`List[int]`, *optional*, defaults to `[2, 2, 2]`):
            Stride used for each level of the hierarchical VQ-VAE.
        sample_length (`int`, *optional*, defaults to 1058304):
            Provides the max input shape of the VQVAE. Is used to compute the input shape of each level.
        init_scale (`float`, *optional*, defaults to 0.2):
            Initialization scale.
        zero_out (`bool`, *optional*, defaults to `False`):
            Whether or not to zero out convolution weights when initializing.
    """

    model_type = "jukebox_vqvae"

    def __init__(
        self,
        act_fn="relu",
        nb_discrete_codes=2048,
        commit=0.02,
        conv_input_shape=1,
        conv_res_scale=False,
        embed_dim=64,
        hop_fraction=[0.125, 0.5, 0.5],
        levels=3,
        lmu=0.99,
        multipliers=[2, 1, 1],
        res_conv_depth=4,
        res_conv_width=32,
        res_convolution_multiplier=1,
        res_dilation_cycle=None,
        res_dilation_growth_rate=3,
        res_downs_t=[3, 2, 2],
        res_strides_t=[2, 2, 2],
        sample_length=1058304,
        init_scale=0.2,
        zero_out=False,
        **kwargs,
    ):
        # The list defaults in the signature are created once and shared by every call. Store copies so that an
        # in-place edit of one configuration cannot leak into the defaults or into other configurations.
        self.hop_fraction = list(hop_fraction) if hop_fraction is not None else None
        self.conv_input_shape = conv_input_shape
        self.sample_length = sample_length

        # VQVAE parameters (all used)
        self.levels = levels
        self.embed_dim = embed_dim
        self.nb_discrete_codes = nb_discrete_codes
        self.res_conv_width = res_conv_width
        self.res_conv_depth = res_conv_depth
        self.res_convolution_multiplier = res_convolution_multiplier
        self.res_dilation_growth_rate = res_dilation_growth_rate
        self.res_dilation_cycle = res_dilation_cycle
        self.multipliers = list(multipliers) if multipliers is not None else None
        self.res_downs_t = list(res_downs_t) if res_downs_t is not None else None
        self.res_strides_t = list(res_strides_t) if res_strides_t is not None else None
        self.lmu = lmu
        self.commit = commit
        self.conv_res_scale = conv_res_scale
        self.act_fn = act_fn
        self.init_scale = init_scale
        self.zero_out = zero_out

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        """
        Load a VQ-VAE configuration, optionally extracting it from a full `JukeboxConfig` dictionary.

        When the loaded dictionary has `model_type == "jukebox"`, its `"vqvae_config"` entry is used instead of the
        top-level dictionary.
        """
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # get the VQVAE config dict if we are loading from JukeboxConfig
        if config_dict.get("model_type") == "jukebox":
            config_dict = config_dict["vqvae_config"]

        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)
489
+
490
+
491
class JukeboxConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`JukeboxModel`].

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information. Instantiating a configuration with the defaults will
    yield a similar configuration to that of
    [openai/jukebox-1b-lyrics](https://huggingface.co/openai/jukebox-1b-lyrics) architecture.


    The downsampling and stride are used to determine downsampling of the input sequence. For example, downsampling =
    (5,3), and strides = (2, 2) will downsample the audio by 2^5 = 32 to get the first level of codes, and 2**8 = 256
    to get the second level codes. This is mostly true for training the top level prior and the upsamplers.

    Args:
        vqvae_config (`JukeboxVQVAEConfig`, *optional*):
            Configuration for the `JukeboxVQVAE` model. It is unpacked as keyword arguments into a
            `JukeboxVQVAEConfig`, so a dictionary of its arguments is expected; defaults are used when `None`.
        prior_config_list (`List[JukeboxPriorConfig]`, *optional*):
            List of the configs for each of the `JukeboxPrior` of the model. The original architecture uses 3 priors.
            Each entry is unpacked as keyword arguments into a `JukeboxPriorConfig`. When `None`, the priors are
            read from the `prior_0` ... `prior_{nb_priors - 1}` keyword arguments, falling back to default values.
        nb_priors (`int`, *optional*, defaults to 3):
            Number of prior models that will sequentially sample tokens. Each prior is conditional auto regressive
            (decoder) model, apart from the top prior, which can include a lyric encoder. The available models were
            trained using a top prior and 2 upsampler priors.
        sampling_rate (`int`, *optional*, defaults to 44100):
            Sampling rate of the raw audio.
        timing_dims (`int`, *optional*, defaults to 64):
            Dimensions of the JukeboxRangeEmbedding layer which is equivalent to traditional positional embedding
            layer. The timing embedding layer converts the absolute and relative position in the currently sampled
            audio to a tensor of length `timing_dims` that will be added to the music tokens.
        min_duration (`int`, *optional*, defaults to 0):
            Minimum duration of the audios to generate
        max_duration (`float`, *optional*, defaults to 600.0):
            Maximum duration of the audios to generate
        max_nb_genres (`int`, *optional*, defaults to 5):
            Maximum number of genres that can be used to condition a single sample.
        metadata_conditioning (`bool`, *optional*, defaults to `True`):
            Whether or not to use metadata conditioning, corresponding to the artist, the genre and the min/maximum
            duration.

    Example:

    ```python
    >>> from transformers import JukeboxModel, JukeboxConfig

    >>> # Initializing a Jukebox configuration
    >>> configuration = JukeboxConfig()

    >>> # Initializing a model from the configuration
    >>> model = JukeboxModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "jukebox"

    def __init__(
        self,
        vqvae_config=None,
        prior_config_list=None,
        nb_priors=3,
        sampling_rate=44100,
        timing_dims=64,
        min_duration=0,
        max_duration=600.0,
        max_nb_genres=5,
        metadata_conditioning=True,
        **kwargs,
    ):
        if vqvae_config is None:
            vqvae_config = {}
            logger.info("vqvae_config is None. initializing the JukeboxVQVAE with default values.")

        self.vqvae_config = JukeboxVQVAEConfig(**vqvae_config)
        if prior_config_list is not None:
            self.prior_configs = [JukeboxPriorConfig(**prior_config) for prior_config in prior_config_list]
        else:
            self.prior_configs = []
            for prior_idx in range(nb_priors):
                # Per-prior dictionaries are removed from kwargs so they are not forwarded to the parent class.
                prior_config = kwargs.pop(f"prior_{prior_idx}", None)
                if prior_config is None:
                    prior_config = {}
                    logger.info(
                        f"prior_{prior_idx}'s config is None. Initializing the JukeboxPriorConfig list with default"
                        " values."
                    )
                self.prior_configs.append(JukeboxPriorConfig(**prior_config))

        self.hop_fraction = self.vqvae_config.hop_fraction

        self.nb_priors = nb_priors

        # Metadata conditioning
        self.max_nb_genres = max_nb_genres
        self.sampling_rate = sampling_rate
        self.timing_dims = timing_dims
        self.min_duration = min_duration
        self.max_duration = max_duration
        self.metadata_conditioning = metadata_conditioning

        super().__init__(**kwargs)

    @classmethod
    def from_configs(cls, prior_configs: List[JukeboxPriorConfig], vqvae_config: JukeboxVQVAEConfig, **kwargs):
        r"""
        Instantiate a [`JukeboxConfig`] (or a derived class) from a list of prior configurations and a VQ-VAE
        configuration.

        Args:
            prior_configs (`List[JukeboxPriorConfig]`):
                Configurations of the priors; each one is converted with `to_dict`.
            vqvae_config (`JukeboxVQVAEConfig`):
                Configuration of the VQ-VAE; converted with `to_dict`.

        Returns:
            [`JukeboxConfig`]: An instance of a configuration object
        """
        prior_config_list = [config.to_dict() for config in prior_configs]
        # The keyword must be `vqvae_config`: `__init__` has no `vqvae_config_dict` parameter, so under that name
        # the dictionary would fall into `**kwargs` and the VQ-VAE would be built from default values instead.
        return cls(prior_config_list=prior_config_list, vqvae_config=vqvae_config.to_dict(), **kwargs)

    def to_dict(self):
        """Serialize to a dictionary, storing the prior configurations as a list under `"prior_config_list"`."""
        # Override the default to_dict to apply to_dict to the list of prior configs.
        result = super().to_dict()
        result["prior_config_list"] = [config.to_dict() for config in result.pop("prior_configs")]
        return result
611
+
612
+
613
+ __all__ = ["JukeboxConfig", "JukeboxPriorConfig", "JukeboxVQVAEConfig"]
docs/transformers/build/lib/transformers/models/deprecated/jukebox/convert_jukebox.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert Jukebox checkpoints"""
16
+
17
+ import argparse
18
+ import json
19
+ import os
20
+ from pathlib import Path
21
+
22
+ import requests
23
+ import torch
24
+
25
+ from transformers import JukeboxConfig, JukeboxModel
26
+ from transformers.utils import logging
27
+
28
+
29
+ logging.set_verbosity_info()
30
+ logger = logging.get_logger(__name__)
31
+
32
+
33
# Base URL that the relative checkpoint paths in MODEL_MAPPING are appended to.
PREFIX = "https://openaipublic.azureedge.net/jukebox/models/"

# Checkpoint files per model name. Both models list the same first three files from the "5b" folder and differ
# only in the folder of the last file ("1b_lyrics" vs "5b_lyrics").
MODEL_MAPPING = {
    f"jukebox-{variant}-lyrics": [
        "5b/vqvae.pth.tar",
        "5b/prior_level_0.pth.tar",
        "5b/prior_level_1.pth.tar",
        f"{variant}_lyrics/prior_level_2.pth.tar",
    ]
    for variant in ("1b", "5b")
}
48
+
49
+
50
def replace_key(key):
    """
    Rename one key of an original Jukebox checkpoint to the naming used by the converted model.

    The substitutions run in a fixed order. The first group rewrites `key` and continues; the rules in the second
    group return as soon as one of them matches, so at most one of them is applied to a given key.

    Args:
        key (`str`): Parameter name taken from the original state dict.

    Returns:
        `str`: The renamed key, or `key` unchanged when no rule matches.
    """
    # Conv weights/biases are only renamed for deeply nested keys (more than 10 dot-separated components).
    if len(key.split(".")) > 10:
        conv_suffixes = (
            (".model.1.bias", ".conv1d_1.bias"),
            (".model.1.weight", ".conv1d_1.weight"),
            (".model.3.bias", ".conv1d_2.bias"),
            (".model.3.weight", ".conv1d_2.weight"),
        )
        for old_suffix, new_suffix in conv_suffixes:
            if key.endswith(old_suffix):
                key = key.replace(old_suffix, new_suffix)
                break

    if "conditioner_blocks.0." in key:
        key = key.replace("conditioner_blocks.0", "conditioner_blocks")

    if "prime_prior" in key:
        key = key.replace("prime_prior", "encoder")

    if ".emb." in key and "total" not in key and "absolute" not in key and "relative" not in key:
        key = key.replace(".emb.", ".")

    # replace vqvae.X.k with vqvae.X.codebook
    # Only a trailing ".k" component is a codebook. Testing for a bare trailing "k" and replacing every ".k"
    # would rewrite unrelated components that merely start with "k", and would stop the renaming early for any
    # key whose last character happens to be "k".
    if key.endswith(".k"):
        return key[: -len(".k")] + ".codebook"
    if "y_emb." in key:
        return key.replace("y_emb.", "metadata_embedding.")

    if "x_emb.emb." in key:
        key = key.replace("0.x_emb.emb", "embed_tokens")

    # Layer norms: the specific name is checked before the generic ".ln" / "_ln" patterns.
    if "prime_state_ln" in key:
        return key.replace("prime_state_ln", "encoder.final_layer_norm")
    if ".ln" in key:
        return key.replace(".ln", ".layer_norm")
    if "_ln" in key:
        return key.replace("_ln", "_layer_norm")

    if "prime_state_proj" in key:
        return key.replace("prime_state_proj", "encoder.proj_in")
    if "prime_x_out" in key:
        return key.replace("prime_x_out", "encoder.lm_head")
    if "prior.x_out" in key:
        return key.replace("x_out", "fc_proj_out")
    if "x_emb" in key:
        return key.replace("x_emb", "embed_tokens")

    return key
94
+
95
+
96
+ def fix_jukebox_keys(state_dict, model_state_dict, key_prefix, mapping):
97
+ new_dict = {}
98
+ import re
99
+
100
+ re_encoder_block_conv_in = re.compile(r"encoders.(\d*).level_blocks.(\d*).model.(\d*).(\d).(bias|weight)")
101
+ re_encoder_block_resnet = re.compile(
102
+ r"encoders.(\d*).level_blocks.(\d*).model.(\d*).(\d).model.(\d*).model.(\d*).(bias|weight)"
103
+ )
104
+ re_encoder_block_proj_out = re.compile(r"encoders.(\d*).level_blocks.(\d*).model.(\d*).(bias|weight)")
105
+
106
+ re_decoder_block_conv_out = re.compile(r"decoders.(\d*).level_blocks.(\d*).model.(\d*).(\d).(bias|weight)")
107
+ re_decoder_block_resnet = re.compile(
108
+ r"decoders.(\d*).level_blocks.(\d*).model.(\d*).(\d).model.(\d*).model.(\d*).(bias|weight)"
109
+ )
110
+ re_decoder_block_proj_in = re.compile(r"decoders.(\d*).level_blocks.(\d*).model.(\d*).(bias|weight)")
111
+
112
+ re_prior_cond_conv_out = re.compile(r"conditioner_blocks.(\d*).cond.model.(\d*).(\d).(bias|weight)")
113
+ re_prior_cond_resnet = re.compile(
114
+ r"conditioner_blocks.(\d*).cond.model.(\d*).(\d).model.(\d*).model.(\d*).(bias|weight)"
115
+ )
116
+ re_prior_cond_proj_in = re.compile(r"conditioner_blocks.(\d*).cond.model.(\d*).(bias|weight)")
117
+
118
+ for original_key, value in state_dict.items():
119
+ # rename vqvae.encoder keys
120
+ if re_encoder_block_conv_in.fullmatch(original_key):
121
+ regex_match = re_encoder_block_conv_in.match(original_key)
122
+ groups = regex_match.groups()
123
+ block_index = int(groups[2]) * 2 + int(groups[3])
124
+ re_new_key = f"encoders.{groups[0]}.level_blocks.{groups[1]}.downsample_block.{block_index}.{groups[-1]}"
125
+ key = re_encoder_block_conv_in.sub(re_new_key, original_key)
126
+
127
+ elif re_encoder_block_resnet.fullmatch(original_key):
128
+ regex_match = re_encoder_block_resnet.match(original_key)
129
+ groups = regex_match.groups()
130
+ block_index = int(groups[2]) * 2 + int(groups[3])
131
+ conv_index = {"1": 1, "3": 2}[groups[-2]]
132
+ prefix = f"encoders.{groups[0]}.level_blocks.{groups[1]}.downsample_block.{block_index}."
133
+ resnet_block = f"resnet_block.{groups[-3]}.conv1d_{conv_index}.{groups[-1]}"
134
+ re_new_key = prefix + resnet_block
135
+ key = re_encoder_block_resnet.sub(re_new_key, original_key)
136
+
137
+ elif re_encoder_block_proj_out.fullmatch(original_key):
138
+ regex_match = re_encoder_block_proj_out.match(original_key)
139
+ groups = regex_match.groups()
140
+ re_new_key = f"encoders.{groups[0]}.level_blocks.{groups[1]}.proj_out.{groups[-1]}"
141
+ key = re_encoder_block_proj_out.sub(re_new_key, original_key)
142
+
143
+ # rename vqvae.decoder keys
144
+ elif re_decoder_block_conv_out.fullmatch(original_key):
145
+ regex_match = re_decoder_block_conv_out.match(original_key)
146
+ groups = regex_match.groups()
147
+ block_index = int(groups[2]) * 2 + int(groups[3]) - 2
148
+ re_new_key = f"decoders.{groups[0]}.level_blocks.{groups[1]}.upsample_block.{block_index}.{groups[-1]}"
149
+ key = re_decoder_block_conv_out.sub(re_new_key, original_key)
150
+
151
+ elif re_decoder_block_resnet.fullmatch(original_key):
152
+ regex_match = re_decoder_block_resnet.match(original_key)
153
+ groups = regex_match.groups()
154
+ block_index = int(groups[2]) * 2 + int(groups[3]) - 2
155
+ conv_index = {"1": 1, "3": 2}[groups[-2]]
156
+ prefix = f"decoders.{groups[0]}.level_blocks.{groups[1]}.upsample_block.{block_index}."
157
+ resnet_block = f"resnet_block.{groups[-3]}.conv1d_{conv_index}.{groups[-1]}"
158
+ re_new_key = prefix + resnet_block
159
+ key = re_decoder_block_resnet.sub(re_new_key, original_key)
160
+
161
+ elif re_decoder_block_proj_in.fullmatch(original_key):
162
+ regex_match = re_decoder_block_proj_in.match(original_key)
163
+ groups = regex_match.groups()
164
+ re_new_key = f"decoders.{groups[0]}.level_blocks.{groups[1]}.proj_in.{groups[-1]}"
165
+ key = re_decoder_block_proj_in.sub(re_new_key, original_key)
166
+
167
+ # rename prior cond.model to upsampler.upsample_block and resnet
168
+ elif re_prior_cond_conv_out.fullmatch(original_key):
169
+ regex_match = re_prior_cond_conv_out.match(original_key)
170
+ groups = regex_match.groups()
171
+ block_index = int(groups[1]) * 2 + int(groups[2]) - 2
172
+ re_new_key = f"conditioner_blocks.upsampler.upsample_block.{block_index}.{groups[-1]}"
173
+ key = re_prior_cond_conv_out.sub(re_new_key, original_key)
174
+
175
+ elif re_prior_cond_resnet.fullmatch(original_key):
176
+ regex_match = re_prior_cond_resnet.match(original_key)
177
+ groups = regex_match.groups()
178
+ block_index = int(groups[1]) * 2 + int(groups[2]) - 2
179
+ conv_index = {"1": 1, "3": 2}[groups[-2]]
180
+ prefix = f"conditioner_blocks.upsampler.upsample_block.{block_index}."
181
+ resnet_block = f"resnet_block.{groups[-3]}.conv1d_{conv_index}.{groups[-1]}"
182
+ re_new_key = prefix + resnet_block
183
+ key = re_prior_cond_resnet.sub(re_new_key, original_key)
184
+
185
+ elif re_prior_cond_proj_in.fullmatch(original_key):
186
+ regex_match = re_prior_cond_proj_in.match(original_key)
187
+ groups = regex_match.groups()
188
+ re_new_key = f"conditioner_blocks.upsampler.proj_in.{groups[-1]}"
189
+ key = re_prior_cond_proj_in.sub(re_new_key, original_key)
190
+
191
+ # keep original key
192
+ else:
193
+ key = original_key
194
+
195
+ key = replace_key(key)
196
+
197
+ if f"{key_prefix}.{key}" not in model_state_dict or key is None:
198
+ print(f"failed converting {original_key} to {key}, does not match")
199
+
200
+ # handle missmatched shape
201
+ elif value.shape != model_state_dict[f"{key_prefix}.{key}"].shape:
202
+ val = model_state_dict[f"{key_prefix}.{key}"]
203
+ print(f"{original_key}-> {key} : \nshape {val.shape} and {value.shape}, do not match")
204
+ key = original_key
205
+
206
+ mapping[key] = original_key
207
+ new_dict[key] = value
208
+
209
+ return new_dict
210
+
211
+
212
@torch.no_grad()
def convert_openai_checkpoint(model_name=None, pytorch_dump_folder_path=None):
    """
    Copy/paste/tweak model's weights to our Jukebox structure.

    The checkpoint files listed in `MODEL_MAPPING` for the requested model are downloaded into
    `pytorch_dump_folder_path` (files already present are reused), their keys are renamed with
    `fix_jukebox_keys`, the resulting state dicts are loaded into a freshly built `JukeboxModel`, and the model
    plus a `mapping.json` file (new key -> original key) are written to `pytorch_dump_folder_path`.

    Args:
        model_name (`str`):
            Name of the model to convert. Everything before the last `/` is ignored when looking the model up in
            `MODEL_MAPPING`; the full name is passed to `JukeboxConfig.from_pretrained`.
        pytorch_dump_folder_path (`str`):
            Directory used both as download cache and as output directory.

    Returns:
        `list`: the converted state dicts of the priors, in checkpoint order (the VQ-VAE state dict, which is the
        first checkpoint, is removed from the list before it is returned).

    Raises:
        `requests.HTTPError`: if a checkpoint download does not return a success status.
    """
    # Use the same (organisation-stripped) name for the download loop and for the conversion loop.
    model_to_convert = MODEL_MAPPING[model_name.split("/")[-1]]

    for file in model_to_convert:
        local_path = f"{pytorch_dump_folder_path}/{file.split('/')[-1]}"
        if not os.path.isfile(local_path):
            r = requests.get(f"{PREFIX}{file}", allow_redirects=True)
            # Fail here rather than saving an HTTP error page as if it were a checkpoint.
            r.raise_for_status()
            os.makedirs(f"{pytorch_dump_folder_path}/", exist_ok=True)
            with open(local_path, "wb") as checkpoint_file:
                checkpoint_file.write(r.content)

    config = JukeboxConfig.from_pretrained(model_name)
    model = JukeboxModel(config)

    weight_dict = []
    mapping = {}
    for i, dict_name in enumerate(model_to_convert):
        old_dic = torch.load(f"{pytorch_dump_folder_path}/{dict_name.split('/')[-1]}", weights_only=True)["model"]

        new_dic = {}
        for k, value in old_dic.items():
            if k.endswith(".b"):
                # Only the trailing "b" is expanded; a plain `replace` would rewrite every "b" in the key.
                new_dic[k[:-1] + "bias"] = value
            elif k.endswith(".w"):
                new_dic[k[:-1] + "weight"] = value
            elif "level_2" not in dict_name and "cond.model." in k:
                new_dic[k.replace(".blocks.", ".model.")] = value
            else:
                new_dic[k] = value

        # The first checkpoint is the VQ-VAE; checkpoint i > 0 is stored under `priors.{3 - i}`.
        key_prefix = "vqvae" if i == 0 else f"priors.{3 - i}"
        new_dic = fix_jukebox_keys(new_dic, model.state_dict(), key_prefix, mapping)
        weight_dict.append(new_dic)

    vqvae_state_dict = weight_dict.pop(0)
    model.vqvae.load_state_dict(vqvae_state_dict)
    # After the pop, `weight_dict[j]` was converted with prefix `priors.{2 - j}`, hence the reversed index.
    for i in range(len(weight_dict)):
        model.priors[i].load_state_dict(weight_dict[2 - i])

    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
    with open(f"{pytorch_dump_folder_path}/mapping.json", "w") as txtfile:
        json.dump(mapping, txtfile)

    print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)

    return weight_dict
261
+
262
+
263
if __name__ == "__main__":
    # Command-line entry point: both options are string-valued and fall back to the defaults below.
    cli_options = (
        ("--model_name", "jukebox-5b-lyrics", "Name of the model you'd like to convert."),
        ("--pytorch_dump_folder_path", "jukebox-5b-lyrics-converted", "Path to the output PyTorch model directory."),
    )

    parser = argparse.ArgumentParser()
    for option_flag, option_default, option_help in cli_options:
        parser.add_argument(option_flag, default=option_default, type=str, help=option_help)

    args = parser.parse_args()
    convert_openai_checkpoint(args.model_name, args.pytorch_dump_folder_path)
docs/transformers/build/lib/transformers/models/deprecated/jukebox/modeling_jukebox.py ADDED
The diff for this file is too large to render. See raw diff
 
docs/transformers/build/lib/transformers/models/deprecated/jukebox/tokenization_jukebox.py ADDED
@@ -0,0 +1,407 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The Open AI Team Authors and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for OpenAI Jukebox."""
16
+
17
+ import json
18
+ import os
19
+ import re
20
+ import unicodedata
21
+ from json.encoder import INFINITY
22
+ from typing import Any, Dict, List, Optional, Tuple, Union
23
+
24
+ import numpy as np
25
+ import regex
26
+
27
+ from ....tokenization_utils import AddedToken, PreTrainedTokenizer
28
+ from ....tokenization_utils_base import BatchEncoding
29
+ from ....utils import TensorType, is_flax_available, is_tf_available, is_torch_available, logging
30
+ from ....utils.generic import _is_jax, _is_numpy
31
+
32
+
33
# Module-level logger, named after this module.
logger = logging.get_logger(__name__)

# File name used for each of the three vocabularies; `JukeboxTokenizer.save_vocabulary` writes one JSON file per
# entry and the keys match the `*_file` arguments of `JukeboxTokenizer.__init__`.
VOCAB_FILES_NAMES = {
    "artists_file": "artists.json",
    "lyrics_file": "lyrics.json",
    "genres_file": "genres.json",
}
40
+
41
+
42
class JukeboxTokenizer(PreTrainedTokenizer):
    """
    Constructs a Jukebox tokenizer. Jukebox can be conditioned on 3 different inputs :
        - Artists, unique ids are associated to each artist from the provided dictionary.
        - Genres, unique ids are associated to each genre from the provided dictionary.
        - Lyrics, character based tokenization. Must be initialized with the list of characters that are inside the
        vocabulary.

    This tokenizer does not require training. Calling it produces one sequence of ids per entry of `version`; every
    sequence holds the artist id and the genre ids (padded with `-1` up to `n_genres`), and the first sequence
    additionally holds the lyric character ids.

    ```python
    >>> from transformers import JukeboxTokenizer

    >>> tokenizer = JukeboxTokenizer.from_pretrained("openai/jukebox-1b-lyrics")
    >>> tokenizer("Alan Jackson", "Country Rock", "old town road")["input_ids"]
    [tensor([[   0,    0,    0, 6785,  546,   41,   38,   30,   76,   46,   41,   49,
               40,   76,   44,   41,   27,   30]]), tensor([[  0,   0,   0, 145,   0]]), tensor([[  0,   0,   0, 145,   0]])]
    ```

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer
    to this superclass for more information regarding those methods.

    Args:
        artists_file (`str`):
            Path to the vocabulary file which contains a mapping between artists and ids. The default file supports
            both "v2" and "v3"
        genres_file (`str`):
            Path to the vocabulary file which contain a mapping between genres and ids.
        lyrics_file (`str`):
            Path to the vocabulary file which contains the accepted characters for the lyrics tokenization.
        version (`List[str]`, `optional`, default to `["v3", "v2", "v2"]`) :
            List of the tokenizer versions. The `5b-lyrics`'s top level prior model was trained using `v3` instead of
            `v2`.
        n_genres (`int`, `optional`, defaults to 5):
            Maximum number of genres to use for composition.
        max_n_lyric_tokens (`int`, `optional`, defaults to 512):
            Maximum number of lyric tokens to keep.
        unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    # Characters kept by `_normalize`; everything else becomes "_". Built once instead of on every call.
    _NAME_CHARS = frozenset("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789.")
    _UNDERSCORE_RUNS = re.compile(r"_+")

    def __init__(
        self,
        artists_file,
        genres_file,
        lyrics_file,
        version=None,
        max_n_lyric_tokens=512,
        n_genres=5,
        unk_token="<|endoftext|>",
        **kwargs,
    ):
        # `None` stands for the historical default; a fresh list is built per instance so that instances never
        # share (and never mutate) one default list object.
        version = ["v3", "v2", "v2"] if version is None else version
        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
        self.version = version
        self.max_n_lyric_tokens = max_n_lyric_tokens
        self.n_genres = n_genres
        self._added_tokens_decoder = {0: unk_token}

        with open(artists_file, encoding="utf-8") as vocab_handle:
            self.artists_encoder = json.load(vocab_handle)

        with open(genres_file, encoding="utf-8") as vocab_handle:
            self.genres_encoder = json.load(vocab_handle)

        with open(lyrics_file, encoding="utf-8") as vocab_handle:
            self.lyrics_encoder = json.load(vocab_handle)

        oov = r"[^A-Za-z0-9.,:;!?\-'\"()\[\] \t\n]+"
        # In v2, we had a n_vocab=80 and in v3 we missed + and so n_vocab=79 of characters.
        if len(self.lyrics_encoder) == 79:
            oov = oov.replace(r"\-'", r"\-+'")

        self.out_of_vocab = regex.compile(oov)
        self.artists_decoder = {v: k for k, v in self.artists_encoder.items()}
        self.genres_decoder = {v: k for k, v in self.genres_encoder.items()}
        self.lyrics_decoder = {v: k for k, v in self.lyrics_encoder.items()}
        super().__init__(
            unk_token=unk_token,
            n_genres=n_genres,
            version=version,
            max_n_lyric_tokens=max_n_lyric_tokens,
            **kwargs,
        )

    @property
    def vocab_size(self):
        """Total number of entries over the artist, genre and lyric vocabularies."""
        return len(self.artists_encoder) + len(self.genres_encoder) + len(self.lyrics_encoder)

    def get_vocab(self):
        """Returns the three encoders (token -> id) keyed by `"artists_encoder"`, `"genres_encoder"` and
        `"lyrics_encoder"`."""
        return {
            "artists_encoder": self.artists_encoder,
            "genres_encoder": self.genres_encoder,
            "lyrics_encoder": self.lyrics_encoder,
        }

    def _convert_token_to_id(self, list_artists, list_genres, list_lyrics):
        """Converts the artist, genre and lyrics tokens to their index using the vocabulary.

        Unknown artists, genres and characters are mapped to id 0. Each list of genre ids is right-padded with `-1`
        up to `self.n_genres` entries (longer lists are left as they are). Only `list_lyrics[0]` is converted; the
        two other lyric entries of the result are empty lists.

        The input lists are not modified, so the same tokens can be converted more than once.
        """
        artists_id = [self.artists_encoder.get(artist, 0) for artist in list_artists]

        genres_ids = []
        for genre_tokens in list_genres:
            ids = [self.genres_encoder.get(genre, 0) for genre in genre_tokens]
            genres_ids.append(ids + [-1] * (self.n_genres - len(ids)))

        lyric_ids = [[self.lyrics_encoder.get(character, 0) for character in list_lyrics[0]], [], []]
        return artists_id, genres_ids, lyric_ids

    def _tokenize(self, lyrics):
        """
        Returns `list(lyrics)`.

        When called from `tokenize`, `lyrics` is the `(text, [], [])` tuple built by `prepare_for_tokenization`, so
        the result is the list `[text, [], []]`; the split of `text` into characters happens when the characters are
        looked up in `_convert_token_to_id`. Added tokens are not handled here.
        """
        return list(lyrics)

    def tokenize(self, artist, genre, lyrics, **kwargs):
        """
        Converts three strings in a 3 sequence of tokens using the tokenizer
        """
        artist, genre, lyrics = self.prepare_for_tokenization(artist, genre, lyrics)
        lyrics = self._tokenize(lyrics)
        return artist, genre, lyrics

    def prepare_for_tokenization(
        self, artists: List[str], genres: List[str], lyrics: str, is_split_into_words: bool = False
    ) -> Tuple[List[str], List[List[str]], Tuple[str, list, list]]:
        """
        Performs any necessary transformations before tokenization.

        `artists` and `genres` are updated in place, one entry per element of `self.version`: for `"v3"` the entry is
        lower-cased, for any other version it is normalized with `_normalize` and suffixed with `".v2"` (genres are
        first split on `"_"`). Each genre entry becomes a list of genre names.

        Calling this method also resets `self.out_of_vocab` and, when `self.version[0] == "v2"`, replaces the lyrics
        encoder/decoder with the built-in character vocabulary.

        Args:
            artists (`List[str]`):
                One artist name per tokenizer version. Modified in place.
            genres (`List[str]`):
                One genre string per tokenizer version. Modified in place.
            lyrics (`str`):
                The lyrics to prepare.
            is_split_into_words (`bool`, *optional*, defaults to `False`):
                Unused by this tokenizer; kept for signature compatibility.

        Returns:
            `tuple`: `(artists, genres, (lyrics, [], []))` where `lyrics` has its accents stripped, backslashes
            replaced by newlines and out-of-vocabulary characters removed.
        """
        for idx, version in enumerate(self.version):
            if version == "v3":
                artists[idx] = artists[idx].lower()
                genres[idx] = [genres[idx].lower()]
            else:
                artists[idx] = self._normalize(artists[idx]) + ".v2"
                genres[idx] = [
                    self._normalize(genre) + ".v2" for genre in genres[idx].split("_")
                ]  # split is for the full dictionary with combined genres

        if self.version[0] == "v2":
            self.out_of_vocab = regex.compile(r"[^A-Za-z0-9.,:;!?\-'\"()\[\] \t\n]+")
            vocab = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789.,:;!?-+'\"()[] \t\n"
            # Ids start at 1; id 0 is reserved for "<unk>".
            self.vocab = {character: index for index, character in enumerate(vocab, start=1)}
            self.vocab["<unk>"] = 0
            self.n_vocab = len(vocab) + 1
            self.lyrics_encoder = self.vocab
            self.lyrics_decoder = {v: k for k, v in self.vocab.items()}
            self.lyrics_decoder[0] = ""
        else:
            self.out_of_vocab = regex.compile(r"[^A-Za-z0-9.,:;!?\-+'\"()\[\] \t\n]+")

        lyrics = self._run_strip_accents(lyrics)
        lyrics = lyrics.replace("\\", "\n")
        lyrics = self.out_of_vocab.sub("", lyrics), [], []
        return artists, genres, lyrics

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        # NFD splits accented characters into base character + combining mark ("Mn"), which is then dropped.
        decomposed = unicodedata.normalize("NFD", text)
        return "".join(char for char in decomposed if unicodedata.category(char) != "Mn")

    def _normalize(self, text: str) -> str:
        """
        Normalizes the input text. This process is for the genres and the artist

        The text is lower-cased, every character other than an ASCII letter, digit or `.` is replaced by `_`, runs
        of `_` are collapsed and leading/trailing `_` are removed.

        Args:
            text (`str`):
                Artist or Genre string to normalize
        """
        text = "".join([c if c in self._NAME_CHARS else "_" for c in text.lower()])
        return self._UNDERSCORE_RUNS.sub("_", text).strip("_")

    def convert_lyric_tokens_to_string(self, lyrics: List[str]) -> str:
        """Joins the lyric tokens with a single space between them."""
        return " ".join(lyrics)

    def convert_to_tensors(
        self, inputs, tensor_type: Optional[Union[str, TensorType]] = None, prepend_batch_axis: bool = False
    ):
        """
        Convert the inner content to tensors.

        Args:
            inputs:
                The (nested) list of ids to convert.
            tensor_type (`str` or [`~utils.TensorType`], *optional*):
                The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If
                unset, no modification is done.
            prepend_batch_axis (`int`, *optional*, defaults to `False`):
                Whether or not to add the batch dimension during the conversion.

        Raises:
            `ImportError`: if the framework required by `tensor_type` is not installed.
            `ValueError`: if `inputs` cannot be turned into a tensor.
        """
        # Honour the documented contract: without a tensor type the inputs are handed back untouched. Without this
        # guard `TensorType(None)` would be evaluated below.
        if tensor_type is None:
            return inputs

        # Convert to TensorType
        if not isinstance(tensor_type, TensorType):
            tensor_type = TensorType(tensor_type)

        # Get a function reference for the correct framework
        if tensor_type == TensorType.TENSORFLOW:
            if not is_tf_available():
                raise ImportError(
                    "Unable to convert output to TensorFlow tensors format, TensorFlow is not installed."
                )
            import tensorflow as tf

            as_tensor = tf.constant
            is_tensor = tf.is_tensor
        elif tensor_type == TensorType.PYTORCH:
            if not is_torch_available():
                raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
            import torch

            as_tensor = torch.tensor
            is_tensor = torch.is_tensor
        elif tensor_type == TensorType.JAX:
            if not is_flax_available():
                raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")
            import jax.numpy as jnp  # noqa: F811

            as_tensor = jnp.array
            is_tensor = _is_jax
        else:
            as_tensor = np.asarray
            is_tensor = _is_numpy

        # Do the tensor conversion in batch
        try:
            if prepend_batch_axis:
                inputs = [inputs]

            if not is_tensor(inputs):
                inputs = as_tensor(inputs)
        except Exception as err:
            # `Exception` (not a bare `except`) so that KeyboardInterrupt/SystemExit are not turned into ValueError.
            raise ValueError(
                "Unable to create tensor, you should probably activate truncation and/or padding "
                "with 'padding=True' 'truncation=True' to have batched tensors with the same length."
            ) from err

        return inputs

    def __call__(self, artist, genres, lyrics="", return_tensors="pt") -> BatchEncoding:
        """Convert the raw string to a list of token ids

        Args:
            artist (`str`):
                Name of the artist.
            genres (`str`):
                List of genres that will be mixed to condition the audio
            lyrics (`str`, *optional*, defaults to `""`):
                Lyrics used to condition the generation
            return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `"pt"`):
                Tensor type forwarded to `convert_to_tensors`.

        Returns:
            [`BatchEncoding`]: `"input_ids"` holds one entry per element of `self.version`, each built from
            `[0, 0, 0]`, the artist id, the genre ids and the lyric ids of that entry (only the first entry carries
            lyric ids). `"attention_masks"` is built from the last lyric-id list, which `_convert_token_to_id`
            always leaves empty.
        """
        input_ids = [0, 0, 0]
        artist = [artist] * len(self.version)
        genres = [genres] * len(self.version)

        artists_tokens, genres_tokens, lyrics_tokens = self.tokenize(artist, genres, lyrics)
        artists_id, genres_ids, full_tokens = self._convert_token_to_id(artists_tokens, genres_tokens, lyrics_tokens)

        attention_masks = [-INFINITY] * len(full_tokens[-1])
        # NOTE(review): `full_tokens` always has 3 entries, so this assumes `len(self.version) <= 3`.
        input_ids = [
            self.convert_to_tensors(
                [input_ids + [artists_id[i]] + genres_ids[i] + full_tokens[i]], tensor_type=return_tensors
            )
            for i in range(len(self.version))
        ]
        return BatchEncoding({"input_ids": input_ids, "attention_masks": attention_masks})

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Saves the tokenizer's vocabulary dictionary to the provided save_directory.

        Args:
            save_directory (`str`):
                Path of an existing directory in which the three vocabulary files are written.

            filename_prefix (`Optional[str]`, *optional*):
                A prefix to add to the names of the files saved by the tokenizer.

        Returns:
            `Tuple[str]`: paths of the artists, genres and lyrics files, in that order. If `save_directory` is not a
            directory an error is logged and `None` is returned.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return

        prefix = filename_prefix + "-" if filename_prefix else ""
        encoders = (
            ("artists_file", self.artists_encoder),
            ("genres_file", self.genres_encoder),
            ("lyrics_file", self.lyrics_encoder),
        )

        saved_files = []
        for file_key, encoder in encoders:
            vocab_path = os.path.join(save_directory, prefix + VOCAB_FILES_NAMES[file_key])
            with open(vocab_path, "w", encoding="utf-8") as f:
                f.write(json.dumps(encoder, ensure_ascii=False))
            saved_files.append(vocab_path)

        return tuple(saved_files)

    def _convert_id_to_token(self, artists_index, genres_index, lyric_index):
        """
        Converts an index (integer) in a token (str) using the vocab.

        Ids that are missing from a decoder are converted to `None`.

        Args:
            artists_index (`int`):
                Index of the artist in its corresponding dictionary.
            genres_index (`Union[List[int], int]`):
                Index of the genre in its corresponding dictionary.
            lyric_index (`List[int]`):
                List of character indices, which each correspond to a character.
        """
        artist = self.artists_decoder.get(artists_index)
        genres = [self.genres_decoder.get(genre) for genre in genres_index]
        lyrics = [self.lyrics_decoder.get(character) for character in lyric_index]
        return artist, genres, lyrics
405
+
406
+
407
# Public API of this module: only the tokenizer class is exported.
__all__ = ["JukeboxTokenizer"]
docs/transformers/build/lib/transformers/models/deprecated/mctct/__init__.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ....utils import _LazyModule
17
+ from ....utils.import_utils import define_import_structure
18
+
19
+
20
# Under static type checking the submodules are imported eagerly so their names are visible to the checker.
if TYPE_CHECKING:
    from .configuration_mctct import *
    from .feature_extraction_mctct import *
    from .modeling_mctct import *
    from .processing_mctct import *
else:
    import sys

    # At runtime this package's entry in `sys.modules` is replaced by a `_LazyModule` built from the import
    # structure derived from this file.
    # NOTE(review): presumably this defers the submodule imports until first attribute access - confirm against
    # `_LazyModule` / `define_import_structure`.
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/build/lib/transformers/models/deprecated/mctct/configuration_mctct.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """M-CTC-T model configuration"""
16
+
17
+ from ....configuration_utils import PretrainedConfig
18
+ from ....utils import logging
19
+
20
+
21
# Module-level logger, named after this module.
logger = logging.get_logger(__name__)
22
+
23
+
24
class MCTCTConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MCTCTModel`]. It is used to instantiate an
    M-CTC-T model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the M-CTC-T
    [speechbrain/m-ctc-t-large](https://huggingface.co/speechbrain/m-ctc-t-large) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 8065):
            Vocabulary size of the M-CTC-T model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`MCTCTModel`].
        hidden_size (`int`, *optional*, defaults to 1536):
            Dimension of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 36):
            Number of hidden layers in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 6144):
            Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 4):
            Number of attention heads for each attention layer in the Transformer encoder.
        attention_head_dim (`int`, *optional*, defaults to 384):
            Dimensions of each attention head for each attention layer in the Transformer encoder.
        max_position_embeddings (`int`, *optional*, defaults to 920):
            The maximum sequence length that this model might ever be used with (after log-mel spectrogram extraction).
        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the layer normalization layers.
        layerdrop (`float`, *optional*, defaults to 0.3):
            The probability of dropping an encoder layer during training. The default 0.3 value is used in the original
            implementation.
        hidden_act (`str` or `function`, *optional*, defaults to `"relu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.3):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.3):
            The dropout ratio for the attention probabilities.
        pad_token_id (`int`, *optional*, defaults to 1):
            The tokenizer index of the pad token.
        bos_token_id (`int`, *optional*, defaults to 0):
            The tokenizer index of the bos token.
        eos_token_id (`int`, *optional*, defaults to 2):
            The tokenizer index of the eos token.
        conv_glu_dim (`int`, *optional*, defaults to 1):
            The dimension of the output of the `Conv1dSubsampler` layer in which GLU is applied on. Though the original
            Flashlight code uses the value of 2, here it's adapted to 1 due to transposition differences.
        conv_dropout (`float`, *optional*, defaults to 0.3):
            The probability of randomly dropping the `Conv1dSubsampler` layer during training.
        num_conv_layers (`int`, *optional*, defaults to 1):
            Number of convolution layers before applying transformer encoder layers.
        conv_kernel (`Sequence[int]`, *optional*, defaults to `(7,)`):
            The kernel size of the 1D convolution applied before transformer layers. `len(conv_kernel)` must be equal
            to `num_conv_layers`.
        conv_stride (`Sequence[int]`, *optional*, defaults to `(3,)`):
            The stride length of the 1D convolution applied before transformer layers. `len(conv_stride)` must be equal
            to `num_conv_layers`.
        input_feat_per_channel (`int`, *optional*, defaults to 80):
            Feature dimensions of the channels of the input to the Conv1D layer.
        input_channels (`int`, *optional*, defaults to 1):
            Number of input channels of the input to the Conv1D layer.
        conv_channels (`List[int]`, *optional*):
            Channel sizes of intermediate Conv1D layers.
        ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`):
            Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
            instance of [`MCTCTForCTC`].
        ctc_zero_infinity (`bool`, *optional*, defaults to `False`):
            Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
            occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
            of [`MCTCTForCTC`].

    Example:

    ```python
    >>> from transformers import MCTCTConfig, MCTCTModel

    >>> # Initializing a M-CTC-T mctct-large style configuration
    >>> configuration = MCTCTConfig()

    >>> # Initializing a model (with random weights) from the mctct-large style configuration
    >>> model = MCTCTModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "mctct"

    def __init__(
        self,
        vocab_size=8065,
        hidden_size=1536,
        num_hidden_layers=36,
        intermediate_size=6144,
        num_attention_heads=4,
        attention_head_dim=384,
        max_position_embeddings=920,
        layer_norm_eps=1e-5,
        layerdrop=0.3,
        hidden_act="relu",
        initializer_range=0.02,
        hidden_dropout_prob=0.3,
        attention_probs_dropout_prob=0.3,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        conv_glu_dim=1,
        conv_dropout=0.3,
        num_conv_layers=1,
        conv_kernel=(7,),
        conv_stride=(3,),
        input_feat_per_channel=80,
        input_channels=1,
        conv_channels=None,
        ctc_loss_reduction="sum",
        ctc_zero_infinity=False,
        **kwargs,
    ):
        # The special token ids are forwarded to the base class and assigned again on this instance below.
        super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        self.max_position_embeddings = max_position_embeddings
        self.layer_norm_eps = layer_norm_eps
        self.layerdrop = layerdrop
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.conv_glu_dim = conv_glu_dim
        self.conv_dropout = conv_dropout
        self.num_conv_layers = num_conv_layers
        self.input_feat_per_channel = input_feat_per_channel
        self.input_channels = input_channels
        self.conv_channels = conv_channels
        self.ctc_loss_reduction = ctc_loss_reduction
        self.ctc_zero_infinity = ctc_zero_infinity

        # prevents config testing fail with exporting to json
        # (the defaults are tuples; they are stored as lists)
        self.conv_kernel = list(conv_kernel)
        self.conv_stride = list(conv_stride)

        # Only the length of `conv_kernel` is validated here; `conv_stride` is stored without a length check.
        if len(self.conv_kernel) != self.num_conv_layers:
            raise ValueError(
                "Configuration for convolutional module is incorrect. "
                "It is required that `len(config.conv_kernel)` == `config.num_conv_layers` "
                f"but is `len(config.conv_kernel) = {len(self.conv_kernel)}`, "
                f"`config.num_conv_layers = {self.num_conv_layers}`."
            )
182
+
183
+
184
# Public API of this module: only the configuration class is exported.
__all__ = ["MCTCTConfig"]
docs/transformers/build/lib/transformers/models/deprecated/mctct/feature_extraction_mctct.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Feature extractor class for M-CTC-T
17
+ """
18
+
19
+ from typing import List, Optional, Union
20
+
21
+ import numpy as np
22
+
23
+ from ....audio_utils import mel_filter_bank, optimal_fft_length, spectrogram, window_function
24
+ from ....feature_extraction_sequence_utils import SequenceFeatureExtractor
25
+ from ....feature_extraction_utils import BatchFeature
26
+ from ....file_utils import PaddingStrategy, TensorType
27
+ from ....utils import logging
28
+
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+
33
class MCTCTFeatureExtractor(SequenceFeatureExtractor):
    r"""
    Constructs a M-CTC-T feature extractor.

    This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
    most of the main methods. Users should refer to this superclass for more information regarding those methods. This
    code has been adapted from Flashlight's C++ code. For more information about the implementation, one can refer to
    this [notebook](https://colab.research.google.com/drive/1GLtINkkhzms-IsdcGy_-tVCkv0qNF-Gt#scrollTo=pMCRGMmUC_an)
    that takes the user step-by-step in the implementation.

    Args:
        feature_size (`int`, defaults to 80):
            The feature dimension of the extracted features. This is the number of mel filters used to build the
            filter bank.
        sampling_rate (`int`, defaults to 16000):
            The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
        padding_value (`float`, defaults to 0.0):
            The value that is used to fill the padding values.
        hop_length (`int`, defaults to 10):
            Shift between consecutive windows, expressed in milliseconds (it is converted to samples as
            `hop_length * sampling_rate // 1000`). Otherwise referred to as "shift" in many papers.
        win_length (`int`, defaults to 25):
            Number of ms per window
        win_function (`str`, defaults to `"hamming_window"`):
            Name of the window function used for windowing; it is forwarded as `name` to
            `audio_utils.window_function`.
        frame_signal_scale (`float`, defaults to 32768.0):
            Constant multiplied in creating the frames before applying DFT.
        preemphasis_coeff (`float`, defaults to 0.97):
            Constant multiplied in applying Pre-emphasis before DFT.
        mel_floor (`float` defaults to 1.0):
            Minimum value of mel frequency banks.
        normalize_means (`bool`, *optional*, defaults to `True`):
            Whether or not to zero-mean normalize the extracted features.
        normalize_vars (`bool`, *optional*, defaults to `True`):
            Whether or not to unit-variance normalize the extracted features.
        return_attention_mask (`bool`, *optional*, defaults to `False`):
            Stored on the instance as `return_attention_mask`.
    """

    model_input_names = ["input_features", "attention_mask"]

    def __init__(
        self,
        feature_size=80,
        sampling_rate=16000,
        padding_value=0.0,
        hop_length=10,
        win_length=25,
        win_function="hamming_window",
        frame_signal_scale=32768.0,
        preemphasis_coeff=0.97,
        mel_floor=1.0,
        normalize_means=True,
        normalize_vars=True,
        return_attention_mask=False,
        **kwargs,
    ):
        super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)

        self.feature_size = feature_size
        self.sampling_rate = sampling_rate
        self.padding_value = padding_value
        self.hop_length = hop_length
        self.win_length = win_length
        self.frame_signal_scale = frame_signal_scale
        self.preemphasis_coeff = preemphasis_coeff
        self.mel_floor = mel_floor
        self.normalize_means = normalize_means
        self.normalize_vars = normalize_vars
        self.win_function = win_function
        self.return_attention_mask = return_attention_mask

        # `win_length` / `hop_length` are given in milliseconds; convert them to a number of samples.
        self.sample_size = win_length * sampling_rate // 1000
        self.sample_stride = hop_length * sampling_rate // 1000

        self.n_fft = optimal_fft_length(self.sample_size)
        self.n_freqs = (self.n_fft // 2) + 1

    def _extract_mfsc_features(self, one_waveform: np.ndarray) -> np.ndarray:
        """
        Extracts MFSC Features for one waveform vector (unbatched). Adapted from Flashlight's C++ MFSC code.

        Returns the transpose of what `spectrogram` produces for the scaled waveform.
        """
        if self.win_function == "hamming_window":
            window = window_function(window_length=self.sample_size, name=self.win_function, periodic=False)
        else:
            window = window_function(window_length=self.sample_size, name=self.win_function)

        fbanks = mel_filter_bank(
            num_frequency_bins=self.n_freqs,
            num_mel_filters=self.feature_size,
            min_frequency=0.0,
            max_frequency=self.sampling_rate / 2.0,
            sampling_rate=self.sampling_rate,
        )

        mfsc_features = spectrogram(
            one_waveform * self.frame_signal_scale,
            window=window,
            frame_length=self.sample_size,
            hop_length=self.sample_stride,
            fft_length=self.n_fft,
            center=False,
            preemphasis=self.preemphasis_coeff,
            mel_filters=fbanks,
            mel_floor=self.mel_floor,
            log_mel="log",
        )
        return mfsc_features.T

    def _normalize_one(self, x, input_length, padding_value):
        """
        Normalize one feature matrix along its first axis, using only the first `input_length` rows for the
        statistics, then overwrite the remaining rows with `padding_value` and cast to float32.
        """
        if self.normalize_means:
            mean = x[:input_length].mean(axis=0)
            x = np.subtract(x, mean)
        if self.normalize_vars:
            std = x[:input_length].std(axis=0)
            # A bin that is constant over the valid rows has a zero deviation; dividing by it would fill the
            # whole column with NaN/inf, so such bins are left unscaled instead.
            x = np.divide(x, np.where(std == 0, 1.0, std))

        if input_length < x.shape[0]:
            x[input_length:] = padding_value

        # make sure array is in float32
        x = x.astype(np.float32)

        return x

    def normalize(
        self, input_features: List[np.ndarray], attention_mask: Optional[np.ndarray] = None
    ) -> List[np.ndarray]:
        """
        Apply `_normalize_one` to every element of `input_features`.

        The number of valid rows of each element is the sum of its `attention_mask` row when a mask is given,
        otherwise the full first dimension of the element.
        """
        lengths = attention_mask.sum(-1) if attention_mask is not None else [x.shape[0] for x in input_features]
        return [self._normalize_one(x, n, self.padding_value) for x, n in zip(input_features, lengths)]

    def __call__(
        self,
        raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
        padding: Union[bool, str, PaddingStrategy] = False,
        max_length: Optional[int] = None,
        truncation: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        sampling_rate: Optional[int] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Main method to featurize and prepare for the model one or several sequence(s). It returns the
        log-mel spectrogram of the input audio, as implemented in the original Flashlight MFSC feature extraction code.

        Args:
            raw_speech (`torch.Tensor`, `np.ndarray`, `List[float]`, `List[torch.Tensor]`, `List[np.ndarray]`, `List[List[float]]`):
                The sequence or batch of sequences to be padded. Each sequence can be a tensor, a numpy array, a list
                of float values, a list of tensors, a list of numpy arrays or a list of list of float values. Must be
                mono channel audio, not stereo, i.e. single float per timestep.
            padding (`bool`, `str` or [`~file_utils.PaddingStrategy`], *optional*, defaults to `False`):
                Select a strategy to pad the returned sequences (according to the model's padding side and padding
                index) among:

                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
                  sequence if provided).
                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
                  acceptable input length for the model if that argument is not provided.
                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
                  lengths).
            max_length (`int`, *optional*):
                Maximum length of the returned list and optionally padding length (see above).
            truncation (`bool`):
                Activates truncation to cut input sequences longer than *max_length* to *max_length*.
            pad_to_multiple_of (`int`, *optional*):
                If set will pad the sequence to a multiple of the provided value.

                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
                `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
            return_attention_mask (`bool`, *optional*):
                Whether to return the attention mask. NOTE: this argument is not consulted by the body below, which
                always asks `self.pad` for the mask because the normalization step needs it.

                [What are attention masks?](../glossary#attention-mask)

            return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
                If set, will return tensors instead of list of python integers. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return Numpy `np.ndarray` objects.
            sampling_rate (`int`, *optional*):
                The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
                `sampling_rate` at the forward call to prevent silent errors.
            **kwargs:
                Additional keyword arguments forwarded to `self.pad`.

        Raises:
            ValueError: If `sampling_rate` is given and differs from `self.sampling_rate`, or if `raw_speech` is a
                numpy array with more than two dimensions.
        """

        if sampling_rate is not None:
            if sampling_rate != self.sampling_rate:
                raise ValueError(
                    f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
                    f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with"
                    f" {self.sampling_rate} and not {sampling_rate}."
                )
        else:
            logger.warning(
                "It is strongly recommended to pass the ``sampling_rate`` argument to this function. "
                "Failing to do so can result in silent errors that might be hard to debug."
            )

        is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
        if is_batched_numpy and len(raw_speech.shape) > 2:
            raise ValueError(f"Only mono-channel audio is supported for input to {self}")
        is_batched = is_batched_numpy or (
            isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
        )

        if is_batched:
            raw_speech = [np.asarray(speech, dtype=np.float32) for speech in raw_speech]
        elif not isinstance(raw_speech, np.ndarray):
            raw_speech = np.asarray(raw_speech, dtype=np.float32)
        elif raw_speech.dtype == np.dtype(np.float64):
            # dtypes are compared by value: an identity check misses equivalent float64 dtype objects
            raw_speech = raw_speech.astype(np.float32)

        # always return batch
        if not is_batched:
            raw_speech = [raw_speech]

        # extract fbank features
        features = [self._extract_mfsc_features(one_waveform) for one_waveform in raw_speech]

        # convert into correct format for padding
        encoded_inputs = BatchFeature({"input_features": features})

        padded_inputs = self.pad(
            encoded_inputs,
            padding=padding,
            max_length=max_length,
            truncation=truncation,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=True,
            **kwargs,
        )
        # make sure list is in array format
        input_features = padded_inputs.get("input_features")
        if isinstance(input_features[0], list):
            padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]

        attention_mask = padded_inputs.get("attention_mask")
        if attention_mask is not None:
            padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.int32) for array in attention_mask]

        if self.normalize_means or self.normalize_vars:
            # The mask is only used to delimit the valid frames when padding was actually requested.
            attention_mask = (
                np.array(attention_mask, dtype=np.int32)
                if self._get_padding_strategies(padding, max_length=max_length) is not PaddingStrategy.DO_NOT_PAD
                and padding
                else None
            )
            padded_inputs["input_features"] = self.normalize(
                padded_inputs["input_features"], attention_mask=attention_mask
            )

        if return_tensors is not None:
            padded_inputs = padded_inputs.convert_to_tensors(return_tensors)

        return padded_inputs
289
+
290
+
291
+ __all__ = ["MCTCTFeatureExtractor"]
docs/transformers/build/lib/transformers/models/deprecated/mctct/modeling_mctct.py ADDED
@@ -0,0 +1,791 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch M-CTC-T model."""
16
+
17
+ import math
18
+ from typing import Optional, Tuple, Union
19
+
20
+ import torch
21
+ import torch.utils.checkpoint
22
+ from torch import nn
23
+
24
+ from ....activations import ACT2FN
25
+ from ....file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
26
+ from ....integrations.deepspeed import is_deepspeed_zero3_enabled
27
+ from ....integrations.fsdp import is_fsdp_managed_module
28
+ from ....modeling_attn_mask_utils import _prepare_4d_attention_mask
29
+ from ....modeling_outputs import BaseModelOutput, CausalLMOutput
30
+ from ....modeling_utils import (
31
+ PreTrainedModel,
32
+ apply_chunking_to_forward,
33
+ find_pruneable_heads_and_indices,
34
+ prune_linear_layer,
35
+ )
36
+ from ....utils import logging
37
+ from .configuration_mctct import MCTCTConfig
38
+
39
+
40
+ logger = logging.get_logger(__name__)
41
+
42
+ _HIDDEN_STATES_START_POSITION = 1
43
+
44
+ _CONFIG_FOR_DOC = "MCTCTConfig"
45
+
46
+ # Base docstring
47
+ _CHECKPOINT_FOR_DOC = "speechbrain/m-ctc-t-large"
48
+ _EXPECTED_OUTPUT_SHAPE = [1, 195, 1536]
49
+
50
+ # CTC docstring
51
+ _CTC_EXPECTED_OUTPUT = '"Mr. Quilter is the apostle of the middle classes, and we\'re glad to welcome his gospel."'
52
+ _CTC_EXPECTED_LOSS = 1885.65
53
+
54
+
55
class MCTCTConv1dSubsampler(nn.Module):
    """
    Convolutional subsampler: a stack of 1D convolution (along temporal dimension) followed by non-linear activation
    via gated linear units (https://arxiv.org/abs/1911.08460)

    Reads `conv_glu_dim`, `conv_dropout`, `num_conv_layers`, `input_feat_per_channel`, `input_channels`,
    `conv_channels`, `hidden_size`, `conv_kernel` and `conv_stride` from `config`.

    Raises:
        ValueError: If more than one convolution layer is requested while `config.conv_channels` is `None`.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.glu_dim = config.conv_glu_dim

        self.dropout = nn.Dropout(config.conv_dropout)

        self.num_layers = config.num_conv_layers
        self.in_channels = config.input_feat_per_channel * config.input_channels

        if self.num_layers > 1:
            if config.conv_channels is None:
                raise ValueError(
                    "Need to specify `conv_channels` configuration in `MCTCTConfig` to use multiple convolution"
                    " layers."
                )

            self.mid_channels = config.conv_channels
        else:
            self.mid_channels = None

        self.out_channels = config.hidden_size * 2  # considering GLU halving
        self.kernel_size = config.conv_kernel
        self.stride = config.conv_stride

        # NOTE: MCTCT by construction only uses one convolution kernel. I've made this flexible to allow for
        # multiple layers of convolutions, but not sure if this model definition should just restrict it
        # to one layer. This becomes especially relevant when considering the padding like line 1 of forward().
        # NOTE(review): for i > 0 the input width is read from `self.mid_channels[i]`, whereas the previous
        # layer is built with `self.mid_channels[i - 1]` output channels (before GLU) -- confirm the intended
        # layout before relying on more than one layer.
        self.conv_layers = nn.ModuleList(
            nn.Conv1d(
                self.in_channels if i == 0 else self.mid_channels[i],
                self.mid_channels[i] if i < self.num_layers - 1 else self.out_channels,
                kernel_size=k,
                stride=self.stride[i],
                padding="valid",
            )
            for i, k in enumerate(self.kernel_size)
        )

    def forward(self, input_features):
        """
        Zero-pad `input_features` along its second-to-last axis, then run every convolution followed by GLU and
        dropout. The input is transposed on axes 1 and 2 before the convolutions and transposed back afterwards.
        """
        # NOTE: in reference to the NOTE in __init__, right now it just calculates padding as if
        # there will be just one conv layer.
        # Total amount added on each side, e.g. kernel sizes (7,) -> 3.
        padding = sum(size // 2 for size in self.kernel_size)

        input_features = torch.nn.functional.pad(input_features, (0, 0, padding, padding), "constant", 0)
        hidden_states = input_features.transpose(1, 2).contiguous()  # -> Batch x Frame x Time
        for conv in self.conv_layers:
            hidden_states = conv(hidden_states)
            hidden_states = nn.functional.glu(hidden_states, dim=self.glu_dim)
            hidden_states = self.dropout(hidden_states)

        hidden_states = hidden_states.transpose(1, 2).contiguous()  # -> Batch x Time x Frame
        return hidden_states
114
+
115
+
116
class MCTCTEmbeddings(nn.Module):
    """
    Construct the embeddings from word and token_type embeddings.

    A `position_embeddings` table and a `position_ids` buffer are created as well, but `forward` does not add
    position embeddings to its output.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        # self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.LayerNorm = MCTCTLayerNorm()
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory; registered as a non-persistent buffer
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        self.register_buffer(
            "token_type_ids",
            torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
            persistent=False,
        )

    def forward(
        self, input_features=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
    ):
        """
        Return `dropout(LayerNorm(inputs_embeds + token_type_embeddings))`.

        `inputs_embeds` is looked up from `input_features` through `word_embeddings` when it is not given.
        `position_ids` and `past_key_values_length` are accepted for interface compatibility only: they do not
        influence the result.
        """
        input_shape = input_features.size() if input_features is not None else inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
        # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
        # issue #5664
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                token_type_ids = buffered_token_type_ids.expand(input_shape[0], seq_length)
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_features)

        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
172
+
173
+
174
class MCTCTSelfAttention(nn.Module):
    """
    Multi-head self-attention with a learned relative-position term added to the raw attention scores.

    Raises:
        ValueError: If `config.hidden_size` is not a multiple of `config.num_attention_heads` and the config
            has no `embedding_size` attribute.
    """

    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.attention_head_dim
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=False)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

        self.max_position_embeddings = config.max_position_embeddings
        self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x):
        """Split the last axis of `x` into (heads, head_size) and move the heads axis to position 1."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def reshape_fortran(self, x, shape):
        """Reshape `x` to `shape` by reversing the axis order before and after a regular `reshape`."""
        if len(x.shape) > 0:
            x = x.permute(*reversed(range(len(x.shape))))
        return x.reshape(*reversed(shape)).permute(*reversed(range(len(shape))))

    def relative_position_embedding_rotate(self, scores):
        """
        Realign the per-distance scores so that the result has one entry per (query, key) pair.

        The zero padding is created with the dtype of `scores`, so the dtype of the input is preserved.
        """
        # NOTE: should re-evaluate whether this re-implementation was truly necessary
        # or the reason why my complete re-haul worked was due to some other part
        # of the code. Adding this and the reshape fortran code seems very undesirable.
        scores = scores.permute(0, 2, 3, 1)  # e.g. [10, 1839, 14, 4]

        batch, hidden_state, seq_len, heads = scores.shape

        # e.g. [10, 1853, 14, 4]
        zero_pad = torch.zeros((batch, seq_len, seq_len, heads), device=scores.device, dtype=scores.dtype)
        scores = torch.cat((scores, zero_pad), dim=1)

        # e.g. [10, 25942, 1, 4]
        scores = self.reshape_fortran(scores, [batch, (hidden_state + seq_len) * seq_len, 1, heads])

        # e.g. [10, 25928, 1, 4]
        scores = scores[:, : (seq_len + hidden_state - 1) * seq_len]

        # e.g. [10, 1852, 14, 4]
        scores = self.reshape_fortran(scores, [batch, hidden_state + seq_len - 1, seq_len, heads])

        halfpoint = hidden_state // 2
        scores = scores[:, halfpoint : halfpoint + seq_len].transpose(1, 2)  # e.g. [10, 14, 14, 4]

        return scores.permute(0, 3, 1, 2)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
    ):
        """
        Return `(context_layer,)`, or `(context_layer, attention_probs)` when `output_attentions` is truthy.

        `attention_mask` is added to the raw scores before the softmax; `head_mask` multiplies the
        probabilities after dropout.
        """
        mixed_query_layer = self.query(hidden_states)
        # the scaling is applied to the queries, so it also affects the relative-position term below
        mixed_query_layer = mixed_query_layer / math.sqrt(self.attention_head_size)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        # relative key position embeddings
        positional_embedding = self.distance_embedding.weight
        relative_position_scores = torch.einsum("lh, bche -> bcle", positional_embedding, query_layer.transpose(2, 3))

        relative_position_scores = self.relative_position_embedding_rotate(relative_position_scores)
        attention_scores = attention_scores + relative_position_scores

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in MCTCTModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).flatten(start_dim=-2)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
280
+
281
+
282
class MCTCTLayerNorm(nn.Module):
    """
    Affine transform with a single learnable scale and a single learnable offset shared by all elements.

    Despite its name, no mean/variance normalization is performed here.
    """

    def __init__(self):
        super().__init__()
        self.singleton_weight = nn.Parameter(torch.ones(1))
        self.singleton_bias = nn.Parameter(torch.zeros(1))

    def forward(self, hidden_states):
        """Return `hidden_states * singleton_weight + singleton_bias`."""
        return (hidden_states * self.singleton_weight) + self.singleton_bias
290
+
291
+
292
class MCTCTSelfOutput(nn.Module):
    """Bias-free projection of the attention output, followed by dropout and a residual `LayerNorm`."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Return `LayerNorm(dropout(dense(hidden_states)) + input_tensor)`."""
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
305
+
306
+
307
class MCTCTAttention(nn.Module):
    """Self-attention (`self.self`) followed by its output projection (`self.output`), with head pruning."""

    def __init__(self, config):
        super().__init__()
        self.self = MCTCTSelfAttention(config)
        self.output = MCTCTSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention `heads` from the projections; does nothing for an empty collection."""
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
    ):
        """
        Return a tuple whose first element is the projected attention output, followed by whatever extra
        elements `self.self` returned (the attention probabilities when `output_attentions` is set).
        """
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them

        return outputs
349
+
350
+
351
class MCTCTIntermediate(nn.Module):
    """Bias-free expansion from `hidden_size` to `intermediate_size`, followed by the configured activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        # `hidden_act` is either a key of `ACT2FN` or something that is used as the activation directly
        hidden_act = config.hidden_act
        self.intermediate_act_fn = ACT2FN[hidden_act] if isinstance(hidden_act, str) else hidden_act

    def forward(self, hidden_states):
        """Return `intermediate_act_fn(dense(hidden_states))`."""
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states
364
+
365
+
366
class MCTCTOutput(nn.Module):
    """Bias-free projection from `intermediate_size` back to `hidden_size`, dropout and residual `LayerNorm`."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Return `LayerNorm(dropout(dense(hidden_states)) + input_tensor)`."""
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
378
+
379
+
380
class MCTCTLayer(nn.Module):
    """One encoder layer: an attention block followed by a (chunkable) feed-forward block."""

    def __init__(self, config: MCTCTConfig):
        super().__init__()

        # axis along which the feed-forward computation is chunked
        self.seq_len_dim = 1
        self.chunk_size_feed_forward = config.chunk_size_feed_forward

        self.intermediate = MCTCTIntermediate(config)
        self.attention = MCTCTAttention(config)
        self.is_decoder = config.is_decoder
        self.output = MCTCTOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
    ):
        """
        Return a tuple whose first element is the layer output, followed by the extra elements returned by the
        attention block (the attention weights when `output_attentions` is set).
        """
        self_attention_outputs = self.attention(
            hidden_states, attention_mask, head_mask, output_attentions=output_attentions
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )

        outputs = (layer_output,) + outputs

        return outputs

    def feed_forward_chunk(self, attention_output):
        """Apply `self.intermediate`, then `self.output` with `attention_output` as the residual input."""
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
417
+
418
+
419
class MCTCTPreTrainedModel(PreTrainedModel):
    """
    Abstract base class that provides weight initialisation plus the shared helpers used to translate input lengths
    and attention masks through the convolutional subsampler.
    """

    config_class = MCTCTConfig
    base_model_prefix = "mctct"
    main_input_name = "input_features"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        std = self.config.initializer_range

        def _normal_weight_zero_bias(target):
            # Normal(0, std) weights and, when present, a zeroed bias.
            target.weight.data.normal_(mean=0.0, std=std)
            if target.bias is not None:
                target.bias.data.zero_()

        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            _normal_weight_zero_bias(module)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, MCTCTLayerNorm):
            module.singleton_weight.data.fill_(1.0)
            module.singleton_bias.data.zero_()

        # NOTE(review): this second check also matches nn.Linear, so linear layers are drawn twice. Kept as-is
        # because removing the redundant draw would change seeded initial values.
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            _normal_weight_zero_bias(module)

    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
        """
        Computes the output length of the convolutional layers
        """
        dilation = 1
        layer_specs = zip(range(self.config.num_conv_layers), self.config.conv_kernel, self.config.conv_stride)
        for _, kernel_size, conv_stride in layer_specs:
            pad = kernel_size // 2
            # Standard 1D convolution output-length formula, applied layer by layer.
            input_lengths = input_lengths + 2 * pad - dilation * (kernel_size - 1) - 1
            input_lengths = torch.div(input_lengths, conv_stride, rounding_mode="trunc") + 1

        return input_lengths

    def _get_feature_vector_attention_mask(self, feature_vector_length, attention_mask):
        """Build a `(batch, feature_vector_length)` mask matching the subsampled sequence."""
        # generate creates 3D attention mask, because of the shape of input_features
        # convert it to 2D if thats the case
        if len(attention_mask.shape) > 2:
            attention_mask = attention_mask[:, :, -1]

        output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
        batch_size = attention_mask.size()[0]
        subsampled_mask = torch.zeros(
            (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
        )

        # Mark the last valid position of every row, then a reversed cumulative sum turns that single marker
        # into ones for every position up to and including it.
        batch_indices = torch.arange(batch_size, device=subsampled_mask.device)
        subsampled_mask[(batch_indices, output_lengths - 1)] = 1
        return subsampled_mask.flip([-1]).cumsum(-1).flip([-1]).long()
486
+
487
+
488
# Class-level documentation fragment handed to `add_start_docstrings` by the model classes in this file.
MCTCT_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`MCTCTConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Argument documentation for the `forward` methods. `{0}` is a `str.format` placeholder for the input shape.
# NOTE(review): the `input_features` entry talks about vocabulary indices and a `LongTensor`, which looks like
# text-model boilerplate rather than a description of audio features — confirm before relying on it.
MCTCT_INPUTS_DOCSTRING = r"""
    Args:
        input_features (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`Wav2Vec2CTCTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""
529
+
530
+
531
class MCTCTEncoder(MCTCTPreTrainedModel):
    # Pipeline: input layer-norm -> convolutional subsampler -> dropout -> stack of `MCTCTLayer` blocks.

    def __init__(self, config: MCTCTConfig):
        super().__init__(config)
        self.hidden_dropout_prob = config.hidden_dropout_prob

        self.layer_norm = MCTCTLayerNorm()
        self.conv = MCTCTConv1dSubsampler(config)
        self.layers = nn.ModuleList([MCTCTLayer(config) for _ in range(config.num_hidden_layers)])

        self.gradient_checkpointing = False

    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: torch.Tensor,
        head_mask: torch.Tensor,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[Tuple, BaseModelOutput]:
        # Explicit `None` falls back to the configuration defaults.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        inputs_embeds = self.conv(self.layer_norm(input_features))

        if attention_mask is not None:
            # subsample attention mask to the post-convolution length ...
            attention_mask = self._get_feature_vector_attention_mask(inputs_embeds.shape[1], attention_mask)
            # ... then expand it: [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)

        hidden_states = nn.functional.dropout(inputs_embeds, p=self.hidden_dropout_prob, training=self.training)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None and head_mask.size()[0] != len(self.layers):
            raise ValueError(
                f"The head_mask should be specified for {len(self.layers)} layers, "
                f"but it is for {head_mask.size()[0]}."
            )

        synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states += (hidden_states,)

            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            # The random draw happens on every iteration, also outside training.
            dropout_probability = torch.rand([])
            skip_the_layer = bool(self.training and (dropout_probability < self.config.layerdrop))

            if synced_gpus or not skip_the_layer:
                # under fsdp or deepspeed zero3 all gpus must run in sync
                if self.gradient_checkpointing and self.training:
                    layer_head_mask = head_mask[idx] if head_mask is not None else None
                    layer_outputs = self._gradient_checkpointing_func(
                        encoder_layer.__call__,
                        hidden_states,
                        attention_mask,
                        layer_head_mask,
                        output_attentions,
                    )
                else:
                    # NOTE(review): `head_mask` is validated above and forwarded on the checkpointing path, but it
                    # is not passed here — confirm whether that is intentional before changing it.
                    layer_outputs = encoder_layer(
                        hidden_states=hidden_states,
                        attention_mask=attention_mask,
                        output_attentions=output_attentions,
                    )

                hidden_states = layer_outputs[0]

            if skip_the_layer:
                # A dropped layer contributes no attention weights.
                layer_outputs = (None, None)

            if output_attentions:
                all_attentions += (layer_outputs[1],)

        if output_hidden_states:
            encoder_states += (hidden_states,)

        if return_dict:
            return BaseModelOutput(
                last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
            )
        return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
625
+
626
+
627
@add_start_docstrings(
    "The bare M-CTC-T Model transformer outputting raw hidden-states without any specific head on top.",
    MCTCT_START_DOCSTRING,
)
class MCTCTModel(MCTCTPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.encoder = MCTCTEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(MCTCT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        # Unset flags fall back to the configuration.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        if input_features is None:
            raise ValueError("You have to specify input_features.")

        encoder_outputs = self.encoder(
            input_features,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]

        if return_dict:
            return BaseModelOutput(
                last_hidden_state=sequence_output,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        # Tuple form: the sequence output followed by whatever else the encoder returned.
        return (sequence_output,) + encoder_outputs[1:]
685
+
686
+
687
@add_start_docstrings(
    """MCTCT Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
    MCTCT_START_DOCSTRING,
)
class MCTCTForCTC(MCTCTPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.mctct = MCTCTModel(config)

        # The CTC head projects onto the vocabulary, so its size must be known up front.
        if config.vocab_size is None:
            raise ValueError(
                f"You are trying to instantiate {self.__class__} with a configuration that "
                "does not define the vocabulary size of the language model head. Please "
                "instantiate the model as follows: `MCTCTForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
                "or define `vocab_size` of your model's configuration."
            )
        output_hidden_size = config.hidden_size

        self.ctc_head = nn.Linear(output_hidden_size, config.vocab_size)

        # Initialize weights and apply final processing
        self.post_init()

    # The inputs docstring carries a `{0}` shape placeholder; fill it in the same way `MCTCTModel.forward` does so
    # the rendered documentation does not show a literal `({0})`.
    @add_start_docstrings_to_model_forward(MCTCT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=CausalLMOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_CTC_EXPECTED_OUTPUT,
        expected_loss=_CTC_EXPECTED_LOSS,
    )
    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
            Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
            the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
            All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
            config.vocab_size - 1]`.
        """
        # Valid label ids are 0 .. vocab_size - 1, so a label equal to vocab_size is already out of range.
        if labels is not None and labels.max() >= self.config.vocab_size:
            raise ValueError(f"Label values must be < vocab_size: {self.config.vocab_size}")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.mctct(
            input_features,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]

        logits = self.ctc_head(hidden_states)

        loss = None
        if labels is not None:
            # retrieve loss input_lengths from attention_mask; without a mask every frame counts as valid.
            # The fallback mask is created on the same device as the inputs.
            attention_mask = (
                attention_mask
                if attention_mask is not None
                else torch.ones(input_features.shape[:-1], dtype=torch.long, device=input_features.device)
            )
            input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
            # assuming that padded tokens are filled with -100
            # when not being attended to
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            # ctc_loss doesn't support fp16; it also expects time-major log-probabilities, hence the transpose.
            log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)

            with torch.backends.cudnn.flags(enabled=False):
                loss = nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    input_lengths,
                    target_lengths,
                    blank=self.config.pad_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )
789
+
790
+
791
+ __all__ = ["MCTCTForCTC", "MCTCTModel", "MCTCTPreTrainedModel"]
docs/transformers/build/lib/transformers/models/deprecated/mctct/processing_mctct.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Speech processor class for M-CTC-T
17
+ """
18
+
19
+ import warnings
20
+ from contextlib import contextmanager
21
+
22
+ from ....processing_utils import ProcessorMixin
23
+
24
+
25
class MCTCTProcessor(ProcessorMixin):
    r"""
    Constructs a MCTCT processor which wraps a MCTCT feature extractor and a MCTCT tokenizer into a single processor.

    [`MCTCTProcessor`] offers all the functionalities of [`MCTCTFeatureExtractor`] and [`AutoTokenizer`]. See the
    [`~MCTCTProcessor.__call__`] and [`~MCTCTProcessor.decode`] for more information.

    Args:
        feature_extractor (`MCTCTFeatureExtractor`):
            An instance of [`MCTCTFeatureExtractor`]. The feature extractor is a required input.
        tokenizer (`AutoTokenizer`):
            An instance of [`AutoTokenizer`]. The tokenizer is a required input.
    """

    feature_extractor_class = "MCTCTFeatureExtractor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, feature_extractor, tokenizer):
        super().__init__(feature_extractor, tokenizer)
        # `current_processor` is what `__call__`/`pad` delegate to while inside `as_target_processor`.
        self.current_processor = self.feature_extractor
        self._in_target_context_manager = False

    def __call__(self, *args, **kwargs):
        """
        When used in normal mode, this method forwards all its arguments to MCTCTFeatureExtractor's
        [`~MCTCTFeatureExtractor.__call__`] and returns its output. If used in the context
        [`~MCTCTProcessor.as_target_processor`] this method forwards all its arguments to AutoTokenizer's
        [`~AutoTokenizer.__call__`]. Please refer to the docstring of the above two methods for more information.

        Returns the feature extractor output when only audio is given, the tokenizer output when only text is
        given, and the feature extractor output with an added `"labels"` entry (the tokenizer's `"input_ids"`)
        when both are given.

        Raises:
            ValueError: If neither audio nor text is supplied.
        """
        # For backward compatibility
        if self._in_target_context_manager:
            return self.current_processor(*args, **kwargs)

        if "raw_speech" in kwargs:
            warnings.warn("Using `raw_speech` as a keyword argument is deprecated. Use `audio` instead.")
            audio = kwargs.pop("raw_speech")
        else:
            audio = kwargs.pop("audio", None)
        sampling_rate = kwargs.pop("sampling_rate", None)
        text = kwargs.pop("text", None)
        # A first positional argument takes precedence over the keyword form.
        if len(args) > 0:
            audio = args[0]
            args = args[1:]

        if audio is None and text is None:
            raise ValueError("You need to specify either an `audio` or `text` input to process.")

        if audio is not None:
            inputs = self.feature_extractor(audio, *args, sampling_rate=sampling_rate, **kwargs)
        if text is not None:
            encodings = self.tokenizer(text, **kwargs)

        if text is None:
            return inputs
        elif audio is None:
            return encodings
        else:
            inputs["labels"] = encodings["input_ids"]
            return inputs

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to AutoTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please refer
        to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def pad(self, *args, **kwargs):
        """
        When used in normal mode, this method forwards all its arguments to MCTCTFeatureExtractor's
        [`~MCTCTFeatureExtractor.pad`] and returns its output. If used in the context
        [`~MCTCTProcessor.as_target_processor`] this method forwards all its arguments to PreTrainedTokenizer's
        [`~PreTrainedTokenizer.pad`]. Please refer to the docstring of the above two methods for more information.
        """
        # For backward compatibility
        if self._in_target_context_manager:
            return self.current_processor.pad(*args, **kwargs)

        input_features = kwargs.pop("input_features", None)
        labels = kwargs.pop("labels", None)
        # A first positional argument takes precedence over the keyword form.
        if len(args) > 0:
            input_features = args[0]
            args = args[1:]

        if input_features is not None:
            input_features = self.feature_extractor.pad(input_features, *args, **kwargs)
        if labels is not None:
            labels = self.tokenizer.pad(labels, **kwargs)

        if labels is None:
            return input_features
        elif input_features is None:
            return labels
        else:
            input_features["labels"] = labels["input_ids"]
            return input_features

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to AutoTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to the
        docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @contextmanager
    def as_target_processor(self):
        """
        Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning MCTCT.

        The feature extractor is restored as the current processor when the `with` block exits, including when the
        block raises.
        """
        warnings.warn(
            "`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your "
            "labels by using the argument `text` of the regular `__call__` method (either in the same call as "
            "your audio inputs, or in a separate call."
        )
        self._in_target_context_manager = True
        self.current_processor = self.tokenizer
        try:
            yield
        finally:
            # Without this, an exception inside the `with` body would leave the processor stuck in tokenizer mode.
            self.current_processor = self.feature_extractor
            self._in_target_context_manager = False
144
+
145
+
146
+ __all__ = ["MCTCTProcessor"]
docs/transformers/build/lib/transformers/models/deprecated/mega/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ....utils import _LazyModule
17
+ from ....utils.import_utils import define_import_structure
18
+
19
+
20
if TYPE_CHECKING:
    # Only evaluated by static type checkers: expose the public names of the submodules directly.
    from .configuration_mega import *
    from .modeling_mega import *
else:
    import sys

    # At runtime, replace this package's entry in `sys.modules` with a `_LazyModule` built from the import
    # structure derived from this file's path (presumably so submodules are only imported on first access —
    # see `_LazyModule` for the exact behaviour).
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/build/lib/transformers/models/deprecated/mega/configuration_mega.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The Mega Authors and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """MEGA configuration"""
16
+
17
+ from collections import OrderedDict
18
+ from typing import Mapping
19
+
20
+ from ....configuration_utils import PretrainedConfig
21
+ from ....onnx import OnnxConfig
22
+ from ....utils import logging
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
class MegaConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MegaModel`]. It is used to instantiate a Mega
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the Mega
    [mnaylor/mega-base-wikitext](https://huggingface.co/mnaylor/mega-base-wikitext) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 30522):
            Vocabulary size of the Mega model. Defines the number of different tokens that can be represented by the
            `input_ids` passed when calling [`MegaModel`].
        hidden_size (`int`, *optional*, defaults to 128):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 4):
            Number of hidden layers in the Mega encoder.
        intermediate_size (`int`, *optional*, defaults to 256):
            Dimensionality of the hidden size (self-attention value projection) within the Mega encoder
        ema_projection_size (`int`, *optional*, defaults to 16):
            Dimensionality of the MegaMultiDimensionDampedEma
        bidirectional (`bool`, *optional*, defaults to `True`):
            Whether the MegaMultiDimensionDampedEma used in Mega's self-attention should work bidirectionally (`True`)
            or unidirectionally (`False`). Bidirectional EMA is incompatible with causal decoding, so this should be
            False if you intend to use the model as a decoder.
        shared_representation_size (`int`, *optional*, defaults to 64):
            Dimensionality of the linear projection for shared representation of self-attention queries and keys
        use_chunking (`bool`, *optional*, defaults to `False`):
            Whether to chunk inputs for linear self-attention complexity (described as Mega-chunk in the paper)
        chunk_size (`int`, *optional*, defaults to -1):
            If `use_chunking` is set to `True`, determines the size of the chunks to apply to the input sequence. If
            chunking is used, input sequences must be padded to a multiple of `chunk_size`
        truncation (`int`, *optional*):
            If specified, the sequence length for which to truncate MegaMultiDimensionDampedEma
        normalize_before_mega (`bool`, *optional*, defaults to `True`):
            Whether to normalize before (`True`) or after (`False`) passing through Mega encoder blocks
        normalization_type (`str`, *optional*, defaults to `"scalenorm"`):
            Type of normalization to use in Mega encoder blocks. Choose one of `"scalenorm"`, `"layernorm"`,
            `"rmsnorm"`, `"batchnorm"`, or `"syncbatchnorm"` (GPU required for syncbatchnorm)
        norm_affine (`bool`, *optional*, defaults to `True`):
            If `True`, applies a parameterized affine transformation to inputs during normalization
        activation (`str`, *optional*, defaults to `"silu"`):
            Activation function to apply within Mega encoder blocks. Choose one of `"silu"`, `"relu"`, `"linear"`,
            `"gelu"`, or `"gelu_accurate"`
        attention_activation (`str`, *optional*, defaults to `"softmax"`):
            Activation function to apply for single-headed self-attention (a la Transformer). Choose one of
            `"softmax"`, `"laplace"`, or `"relu2"`
        dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability for EMA self-attention
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        use_feature_dropout (`bool`, *optional*, defaults to `False`):
            Whether to use feature-based (`True`) or standard dropout (`False`)
        use_normalized_ffn (`bool`, *optional*, defaults to `True`):
            Whether to use the normalized feed-forward sub-layer in Mega blocks (`True`) or pass Mega encoder output
            as-is (`False`)
        nffn_hidden_size (`int`, *optional*, defaults to 256):
            If using the normalized feed-forward network (NFFN) layer within Mega (`use_normalized_ffn = True`), this
            is the hidden size of the NFFN
        normalize_before_ffn (`bool`, *optional*, defaults to `True`):
            Whether to normalize before (`True`) or after (`False`) the feed-forward portion of NFFN
        nffn_activation_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the NFFN component.
        max_positions (`int`, *optional*, defaults to 2048):
            The maximum sequence length to use for positional representations. For `"simple"` relative positional bias,
            this is a hard limit on input length; `"rotary"` relative positional bias will extrapolate to longer
            sequences
        add_token_type_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to account for token types in embeddings. Left as optional to maintain compatibility with original
            implementation while adding support for token types.
        type_vocab_size (`int`, *optional*, defaults to 2):
            The vocabulary size of the `token_type_ids` passed when calling [`MegaModel`]. Only used if
            `add_token_type_embeddings = True`
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        ema_delta_alpha_range (`float`, *optional*, defaults to 0.2):
            The standard deviation for initializing the delta (damping factor) and alpha (decay factor) parameters in
            MegaMultiDimensionDampedEma.
        ema_beta_range (`float`, *optional*, defaults to 0.02):
            The standard deviation for initializing the beta parameter (expansion matrix) in
            MegaMultiDimensionDampedEma.
        ema_gamma_omega_range (`float`, *optional*, defaults to 1.0):
            The standard deviation for initializing the gamma (projection matrix) and omega (residual weight)
            parameters in MegaMultiDimensionDampedEma.
        pad_token_id (`int`, *optional*, defaults to 1):
            Id of the padding token; forwarded to [`PretrainedConfig`].
        bos_token_id (`int`, *optional*, defaults to 0):
            Id of the beginning-of-sequence token; forwarded to [`PretrainedConfig`].
        eos_token_id (`int`, *optional*, defaults to 2):
            Id of the end-of-sequence token; forwarded to [`PretrainedConfig`].
        relative_positional_bias (`str`, *optional*, defaults to `"rotary"`):
            Type of relative positional encoding. Choose one of `"rotary"` or `"simple"`. If `"simple"` is selected,
            `max_positions` is used as a limit on input size, while `"rotary"` extrapolates beyond `max_positions`.
        is_decoder (`bool`, *optional*, defaults to `False`):
            Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        classifier_dropout (`float`, *optional*):
            The dropout ratio for the classification head.
        add_lm_hidden_dense_layer (`bool`, *optional*, defaults to `True`):
            Whether to include a hidden layer for projection between encoder outputs and LM heads (`True`) or pass
            hidden states directly to LM head (`False`). Remains optional for compatibility with original
            implementation

    Examples:

    ```python
    >>> from transformers import MegaConfig, MegaModel

    >>> # Initializing a Mega configuration
    >>> configuration = MegaConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = MegaModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "mega"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=128,
        num_hidden_layers=4,
        intermediate_size=256,
        ema_projection_size=16,
        bidirectional=True,
        shared_representation_size=64,
        use_chunking=False,
        chunk_size=-1,
        truncation=None,
        normalize_before_mega=True,
        normalization_type="scalenorm",
        norm_affine=True,
        activation="silu",
        attention_activation="softmax",
        dropout_prob=0.1,
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        use_feature_dropout=False,
        use_normalized_ffn=True,
        nffn_hidden_size=256,
        normalize_before_ffn=True,
        nffn_activation_dropout_prob=0.1,
        max_positions=2048,
        add_token_type_embeddings=False,
        type_vocab_size=2,
        initializer_range=0.02,
        ema_delta_alpha_range=0.2,
        ema_beta_range=0.02,
        ema_gamma_omega_range=1.0,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        relative_positional_bias="rotary",
        classifier_dropout=None,
        use_cache=True,
        add_lm_hidden_dense_layer=True,
        **kwargs,
    ):
        # The special-token ids and any extra keyword arguments are handled by the parent configuration class.
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

        # Every remaining argument is stored verbatim as an attribute of the same name.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.activation = activation
        self.attention_activation = attention_activation
        self.intermediate_size = intermediate_size
        self.ema_projection_size = ema_projection_size
        self.bidirectional = bidirectional
        self.shared_representation_size = shared_representation_size
        self.use_chunking = use_chunking
        self.chunk_size = chunk_size
        self.truncation = truncation
        self.normalize_before_mega = normalize_before_mega
        self.normalization_type = normalization_type
        self.norm_affine = norm_affine
        self.dropout_prob = dropout_prob
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.use_feature_dropout = use_feature_dropout
        self.use_normalized_ffn = use_normalized_ffn
        self.nffn_hidden_size = nffn_hidden_size
        self.normalize_before_ffn = normalize_before_ffn
        self.nffn_activation_dropout_prob = nffn_activation_dropout_prob
        self.max_positions = max_positions
        self.add_token_type_embeddings = add_token_type_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.ema_delta_alpha_range = ema_delta_alpha_range
        self.ema_beta_range = ema_beta_range
        self.ema_gamma_omega_range = ema_gamma_omega_range
        self.relative_positional_bias = relative_positional_bias
        self.use_cache = use_cache
        self.classifier_dropout = classifier_dropout
        self.add_lm_hidden_dense_layer = add_lm_hidden_dense_layer
        # Fixed value, not configurable through the constructor.
        self.num_attention_heads = 1  # not used but required by Hugging Face
226
+
227
+
228
class MegaOnnxConfig(OnnxConfig):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Ordered mapping from ONNX input name to its dynamic axes (with an extra `choice` axis for multiple choice)."""
        is_multiple_choice = self.task == "multiple-choice"
        axis_names = ("batch", "choice", "sequence") if is_multiple_choice else ("batch", "sequence")
        dynamic_axis = dict(enumerate(axis_names))
        # Both inputs share the same axes mapping object.
        return OrderedDict((input_name, dynamic_axis) for input_name in ("input_ids", "attention_mask"))
241
+
242
+
243
+ __all__ = ["MegaConfig", "MegaOnnxConfig"]
docs/transformers/build/lib/transformers/models/deprecated/mega/convert_mega_original_pytorch_checkpoint_to_pytorch.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """
17
+ Convert Mega pretrained checkpoint. Built to convert the Masked LM checkpoint located at
18
+ https://huggingface.co/mnaylor/mega-wikitext-103
19
+
20
+ Requirements:
21
+ - clone the Mega repo and install fairseq from there
22
+ 1. git clone https://github.com/facebookresearch/mega.git
23
+ 2. cd mega && pip install -e .
24
+ - clone the pretrained weights for the original implementation from the hugging face repo
25
+ * use this location as the path for pretrained weights
26
+ """
27
+
28
+ import argparse
29
+
30
+ # utilities to import the model weights and config file
31
+ import os
32
+ import pickle as pkl
33
+
34
+ # PyTorch + new model classes
35
+ import torch
36
+ from torch import nn
37
+
38
+ from transformers import AutoTokenizer, MegaConfig, MegaForMaskedLM
39
+
40
+
41
+ # import the EncoderLayer class used to pretrain
42
+ # !! NOTE !! this requires the version of fairseq that is built when you install the Mega source
43
+ try:
44
+ from fairseq.modules.mega_layer import MegaEncoderLayer
45
+ except ImportError:
46
+ raise ImportError("You need to install the version of fairseq from the Mega repo!")
47
+
48
+
49
+ # define the wrapper classes used to train the MLM (see colab notebook below)
50
+ # https://colab.research.google.com/drive/1qfUO6o5HRdxBblWlw058HVyvaEPhPpH8?usp=sharing
51
+ # MegaLM outputs hidden states
52
class MegaLM(nn.Module):
    "Base Mega encoder used for pretraining: embeds input IDs and runs them through a stack of Mega encoder layers"

    def __init__(self, mega_args, depth, vocab_size):
        super().__init__()
        self.mega_args = mega_args
        self.depth = depth
        self.embedding_layer = nn.Embedding(vocab_size, mega_args.encoder_embed_dim)
        self.encoders = nn.ModuleList([MegaEncoderLayer(mega_args) for _ in range(depth)])

    def forward(self, input_ids, attention_mask, batch_first=True, ignore_mask_value=0):
        """
        Embed `input_ids` and pass the embeddings through every encoder layer in order.

        `input_ids` and `attention_mask` are expected to come from a Hugging Face tokenizer as PyTorch tensors. The
        encoder hidden states are returned as (batch, time, embedding size) when `batch_first` is set, and as
        (time, batch, embedding size) otherwise.

        Other options:
          - batch_first: whether the batch dimension comes first in `input_ids` (default: True, which matches the
            Hugging Face tokenizer output)
          - ignore_mask_value: the value in `attention_mask` marking tokens to ignore (default: 0, which matches the
            Hugging Face tokenizer output)
        """
        # The Mega layers work on (time, batch, embedding size); Hugging Face tokenizers produce (batch, time).
        token_ids = input_ids.T if batch_first else input_ids

        # Mega keeps the mask as (batch, time) but flags ignored tokens with 1 and regular tokens with 0,
        # i.e. the opposite of the Hugging Face convention.
        padding_mask = 1 - attention_mask if ignore_mask_value == 0 else attention_mask

        hidden_states = self.embedding_layer(token_ids)
        for layer in self.encoders:
            # each layer maps (time, batch, embedding size) to the same shape
            hidden_states = layer(hidden_states, padding_mask)

        # (T, B, H) --> (B, T, H) when the caller asked for batch-first output
        return torch.transpose(hidden_states, 0, 1) if batch_first else hidden_states
99
+
100
+
101
+ # renamed from MegaForMaskedLM to avoid confusion with new module
102
class OriginalMegaForMaskedLM(nn.Module):
    "Masked language modeling wrapper around the original Mega encoder: encoder, dropout, then a linear head"

    def __init__(self, mega_args, depth, vocab_size):
        super().__init__()
        embed_dim = mega_args.encoder_embed_dim
        self.mega = MegaLM(mega_args, depth, vocab_size)
        self.mlm_head = nn.Linear(embed_dim, vocab_size)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, input_ids, attention_mask, batch_first=True, ignore_mask_value=0):
        """
        Run the Mega encoder followed by dropout and the masked LM head, returning one logit per vocabulary entry.

        With `batch_first` (the default, matching Hugging Face tokenizer output) the result has shape (Batch size,
        Sequence length, Vocab size); otherwise it is (S, B, V).
        """
        hidden_states = self.mega(input_ids, attention_mask, batch_first, ignore_mask_value)
        hidden_states = self.dropout(hidden_states)
        return self.mlm_head(hidden_states)
121
+
122
+
123
+ # code to convert the checkpoint located in the user-specified location
124
def _rename_original_key(module_name):
    """
    Map a parameter name from the original Mega encoder state dict to its Hugging Face name.

    Names containing "beta" or "gamma" aren't safe to use in the Hugging Face ecosystem (they are renamed upon
    `_load_pretrained`), and the EMA sub-layer, always called "move" in the original repo, is called "ema_gate" in
    the Hugging Face implementation. Returns `None` when the name is kept unchanged.
    """
    # beta is used in EMA, MovingAverageGatedAttention, and RotaryRelativePositionalBias
    if "beta" in module_name:
        if "move.beta" in module_name:
            return module_name.replace("move.beta", "ema_gate.ema_expansion_matrix")
        if "mega_layer.beta" in module_name:
            return module_name.replace("beta", "qk_bias")
        return module_name.replace("beta", "b_param")

    # gamma is handled the same way: EMA (`move.gamma`), the attention layer (`mega_layer.gamma`), anything else
    if "gamma" in module_name:
        if "move.gamma" in module_name:
            return module_name.replace("move.gamma", "ema_gate.kernel_projection_matrix")
        if "mega_layer.gamma" in module_name:
            return module_name.replace("gamma", "qk_weight")
        return module_name.replace("gamma", "g_param")

    # alpha, delta and omega are only renamed inside the EMA sub-layer, purely to improve readability
    ema_renames = (
        ("move.alpha", "ema_gate.decay_factor"),
        ("move.delta", "ema_gate.damping_factor"),
        ("move.omega", "ema_gate.residual_weight"),
    )
    for original_name, hf_name in ema_renames:
        if original_name in module_name:
            return module_name.replace(original_name, hf_name)

    return None


def convert_checkpoint_to_huggingface(pretrained_checkpoint_path, output_path, includes_tokenizer):
    """
    Convert an original Mega masked LM checkpoint into a Hugging Face `MegaForMaskedLM` and save it.

    Args:
        pretrained_checkpoint_path: directory containing `model_args.pkl`, `encoder_weights.pt` and
            `mlm_head_weights.pt` (and the tokenizer files when `includes_tokenizer` is set).
        output_path: location passed to `save_pretrained` for the converted model (and tokenizer).
        includes_tokenizer: when truthy, the tokenizer found in `pretrained_checkpoint_path` is loaded with
            `AutoTokenizer` and saved to `output_path` as well.

    Raises:
        RuntimeError: if the original and converted models disagree (beyond `atol=1e-3`) on a random input.
    """
    # NOTE(security): unpickling can execute arbitrary code -- only convert checkpoints from a trusted source.
    with open(os.path.join(pretrained_checkpoint_path, "model_args.pkl"), "rb") as f:
        mega_original_args = pkl.load(f)
    mega_args = mega_original_args["mega_args"]

    # load the original encoder
    original_mlm = OriginalMegaForMaskedLM(**mega_original_args).eval()

    # load its weights
    print(
        "Original Mega encoder:",
        original_mlm.mega.load_state_dict(
            torch.load(
                os.path.join(pretrained_checkpoint_path, "encoder_weights.pt"), map_location="cpu", weights_only=True
            )
        ),
    )
    # the MLM head weights are read once and loaded into both the original and the Hugging Face model
    mlm_head_state_dict = torch.load(
        os.path.join(pretrained_checkpoint_path, "mlm_head_weights.pt"), map_location="cpu", weights_only=True
    )
    print("Original Mega MLM layer:", original_mlm.mlm_head.load_state_dict(mlm_head_state_dict))

    # create a new config from the old one
    hf_config = MegaConfig(
        num_hidden_layers=mega_original_args["depth"],
        vocab_size=mega_original_args["vocab_size"],
        hidden_size=mega_args.encoder_embed_dim,
        shared_representation_size=mega_args.encoder_z_dim,
        intermediate_size=mega_args.encoder_hidden_dim,
        ema_projection_size=mega_args.encoder_n_dim,
        dropout_prob=mega_args.dropout,
        attention_probs_dropout_prob=mega_args.attention_dropout,
        hidden_dropout_prob=mega_args.hidden_dropout,
        activation=mega_args.activation_fn,
        attention_activation=mega_args.attention_activation_fn,
        bidirectional=mega_args.bidirectional,
        use_chunking=mega_args.encoder_chunk_size > 0,
        chunk_size=mega_args.encoder_chunk_size,
        truncation=mega_args.truncation_length,
        normalization_type=mega_args.normalization_type,
        normalize_before_mega=True,
        norm_affine=True,
        use_feature_dropout=mega_args.feature_dropout,
        relative_positional_bias=mega_args.rel_pos_bias,
        max_positions=mega_args.max_source_positions,
        nffn_hidden_size=mega_args.encoder_ffn_embed_dim,
        normalize_before_ffn=mega_args.normalize_before,
        # new arguments added for HF implementation
        nffn_activation_dropout_prob=0.0,
        add_token_type_embeddings=False,
        add_lm_hidden_dense_layer=False,
    )

    hf_mlm = MegaForMaskedLM(hf_config).eval()

    # the original checkpoint just uses nn.Embedding for the word embeddings
    # we use a wrapper module for embeddings to add support for positional embeddings
    hf_mlm.mega.embedding_layer.word_embeddings.weight = original_mlm.mega.embedding_layer.weight

    # modify the state dictionary of the original checkpoint to account for naming issues in the Hugging Face
    # ecosystem and for previously confusing parameter names
    original_state_dict = original_mlm.mega.encoders.state_dict()
    updated_keys = {}
    for module_name in original_state_dict:
        new_module_name = _rename_original_key(module_name)
        if new_module_name is not None:
            updated_keys[module_name] = new_module_name

    if len(updated_keys) != 0:
        print(f"Renaming these keys: {updated_keys.keys()}")
    else:
        print("No need to rename state dict entries")
    for old, new in updated_keys.items():
        original_state_dict[new] = original_state_dict.pop(old)

    # now attempt to load the state dictionary with updated names
    # note that we now call it `mega.layers` instead of `mega.encoders` due to hugging face style
    print("HF Mega encoder:", hf_mlm.mega.layers.load_state_dict(original_state_dict))

    # load the MLM head weights directly
    print("HF Mega MLM layer:", hf_mlm.mlm_head.load_state_dict(mlm_head_state_dict))

    # test on a randomly generated input sequence
    input_ids = torch.randint(0, hf_config.vocab_size, size=(4, 256))
    input_mask = torch.ones_like(input_ids)
    # mask a few tokens to make sure masking is applied appropriately :)
    input_mask[:, -10:] = 0

    # run forward passes
    original_output = original_mlm(input_ids, input_mask, batch_first=True, ignore_mask_value=0)
    hf_output = hf_mlm(input_ids, input_mask)[0]

    # print shapes and diff
    print(f"original output {original_output.shape}")
    print(f"hf output {hf_output.shape}")
    # take the absolute value first so that large negative deviations are reported too; expected to be ~0.0
    print(f"max diff: {(original_output - hf_output).abs().max()}")
    success = torch.allclose(original_output, hf_output, atol=1e-3)

    if success:
        print("Yay!")
        hf_mlm.save_pretrained(output_path)
    else:
        raise RuntimeError(f"Something's broken :(\nOriginal:\n{original_output}\n\nHF\n{hf_output}\n{hf_mlm}")

    if includes_tokenizer:
        print("Transferring tokenizer")
        tokenizer = AutoTokenizer.from_pretrained(pretrained_checkpoint_path)
        tokenizer.save_pretrained(output_path)
273
+
274
+
275
if __name__ == "__main__":
    # Command-line entry point: two required paths plus an optional flag for copying the tokenizer.
    cli = argparse.ArgumentParser()

    cli.add_argument(
        "--pretrained_checkpoint_path",
        type=str,
        default=None,
        required=True,
        help="Point to the directory containing your model weights using the official Mega repo",
    )
    cli.add_argument(
        "--output_path",
        type=str,
        default=None,
        required=True,
        help="Location to save the Hugging Face version",
    )
    cli.add_argument(
        "--includes_tokenizer",
        action="store_true",
        help="Use this flag if there is a Hugging Face tokenizer in the original checkpoint repo",
    )

    cli_args = cli.parse_args()

    convert_checkpoint_to_huggingface(
        cli_args.pretrained_checkpoint_path,
        cli_args.output_path,
        cli_args.includes_tokenizer,
    )
docs/transformers/build/lib/transformers/models/deprecated/mega/modeling_mega.py ADDED
The diff for this file is too large to render. See raw diff
 
docs/transformers/build/lib/transformers/models/deprecated/mmbt/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ....utils import _LazyModule
17
+ from ....utils.import_utils import define_import_structure
18
+
19
+
20
# Static type checkers get the submodule contents through direct star-imports.
if TYPE_CHECKING:
    from .configuration_mmbt import *
    from .modeling_mmbt import *
else:
    import sys

    # At runtime, replace this module's entry in `sys.modules` with a `_LazyModule` built from the import
    # structure of this file (presumably so submodules are only imported on first access -- TODO confirm).
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/build/lib/transformers/models/deprecated/mmbt/configuration_mmbt.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ # Copyright (c) HuggingFace Inc. team.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """MMBT configuration"""
17
+
18
+ from ....utils import logging
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
class MMBTConfig:
    """
    Configuration holder for a [`MMBTModel`]. It wraps the configuration of the underlying Transformer model and adds
    the MMBT-specific settings on top, defining the model architecture.

    Args:
        config ([`PreTrainedConfig`]):
            Config of the underlying Transformer models. Its values are copied over to use a single config.
        num_labels (`int`, *optional*):
            Size of final Linear layer for classification.
        modal_hidden_size (`int`, *optional*, defaults to 2048):
            Embedding dimension of the non-text modality encoder.
    """

    def __init__(self, config, num_labels=None, modal_hidden_size=2048):
        # Adopt the wrapped config's attribute dictionary itself (not a copy), so both objects share state.
        self.__dict__ = config.__dict__
        self.modal_hidden_size = modal_hidden_size
        # A falsy `num_labels` (None or 0) leaves any existing value untouched.
        if not num_labels:
            return
        self.num_labels = num_labels
43
+
44
+
45
+ __all__ = ["MMBTConfig"]
docs/transformers/build/lib/transformers/models/deprecated/mmbt/modeling_mmbt.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ # Copyright (c) HuggingFace Inc. team.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch MMBT model."""
17
+
18
+ import torch
19
+ from torch import nn
20
+ from torch.nn import CrossEntropyLoss, MSELoss
21
+
22
+ from ....modeling_outputs import BaseModelOutputWithPooling, SequenceClassifierOutput
23
+ from ....modeling_utils import ModuleUtilsMixin
24
+ from ....utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
25
+
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+ _CONFIG_FOR_DOC = "MMBTConfig"
30
+
31
+
32
class ModalEmbeddings(nn.Module):
    """Embeds a non-text modality with an encoder and reuses the embedding tables of a transformer."""

    def __init__(self, config, encoder, embeddings):
        super().__init__()
        self.config = config
        self.encoder = encoder
        # projects encoder features from the modality size to the transformer hidden size
        self.proj_embeddings = nn.Linear(config.modal_hidden_size, config.hidden_size)
        # the remaining embedding modules are taken from the transformer embeddings object as-is
        self.position_embeddings = embeddings.position_embeddings
        self.token_type_embeddings = embeddings.token_type_embeddings
        self.word_embeddings = embeddings.word_embeddings
        self.LayerNorm = embeddings.LayerNorm
        self.dropout = nn.Dropout(p=config.hidden_dropout_prob)

    def forward(self, input_modal, start_token=None, end_token=None, position_ids=None, token_type_ids=None):
        batch_size = input_modal.size(0)
        device = input_modal.device

        # Encode and project the modality, then optionally wrap it with start/end token embeddings.
        pieces = [self.proj_embeddings(self.encoder(input_modal))]
        if start_token is not None:
            pieces.insert(0, self.word_embeddings(start_token).unsqueeze(1))
        if end_token is not None:
            pieces.append(self.word_embeddings(end_token).unsqueeze(1))
        token_embeddings = pieces[0] if len(pieces) == 1 else torch.cat(pieces, dim=1)
        seq_length = token_embeddings.size(1)

        # Default positions are 0..seq_length-1 for every batch element; default token types are all zero.
        if position_ids is None:
            positions = torch.arange(seq_length, dtype=torch.long, device=device)
            position_ids = positions.unsqueeze(0).expand(batch_size, seq_length)
        if token_type_ids is None:
            token_type_ids = torch.zeros((batch_size, seq_length), dtype=torch.long, device=device)

        summed = (
            token_embeddings + self.position_embeddings(position_ids) + self.token_type_embeddings(token_type_ids)
        )
        return self.dropout(self.LayerNorm(summed))
75
+
76
+
77
+ MMBT_START_DOCSTRING = r"""
78
+ MMBT model was proposed in [Supervised Multimodal Bitransformers for Classifying Images and
79
+ Text](https://github.com/facebookresearch/mmbt) by Douwe Kiela, Suvrat Bhooshan, Hamed Firooz, Davide Testuggine.
80
+ It's a supervised multimodal bitransformer model that fuses information from text and other image encoders, and
81
+ obtain state-of-the-art performance on various multimodal classification benchmark tasks.
82
+
83
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
84
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
85
+ etc.)
86
+
87
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
88
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
89
+ and behavior.
90
+
91
+ Parameters:
92
+ config ([`MMBTConfig`]): Model configuration class with all the parameters of the model.
93
+ Initializing with a config file does not load the weights associated with the model, only the
94
+ configuration.
95
+ transformer (`nn.Module`): A text transformer that is used by MMBT.
96
+ It should have embeddings, encoder, and pooler attributes.
97
+ encoder (`nn.Module`): Encoder for the second modality.
98
+ It should take in a batch of modal inputs and return k, n dimension embeddings.
99
+ """
100
+
101
+ MMBT_INPUTS_DOCSTRING = r"""
102
+ Args:
103
+ input_modal (`torch.FloatTensor` of shape `(batch_size, ***)`):
104
+ The other modality data. It will be the shape that the encoder for that type expects. e.g. With an Image
105
+ Encoder, the shape would be (batch_size, channels, height, width)
106
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
107
+ Indices of input sequence tokens in the vocabulary. It does not expect [CLS] token to be added as it's
108
+ appended to the end of other modality embeddings. Indices can be obtained using [`AutoTokenizer`]. See
109
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
110
+
111
+ [What are input IDs?](../glossary#input-ids)
112
+ modal_start_tokens (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
113
+ Optional start token to be added to Other Modality Embedding. [CLS] Most commonly used for classification
114
+ tasks.
115
+ modal_end_tokens (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
116
+ Optional end token to be added to Other Modality Embedding. [SEP] Most commonly used.
117
+ attention_mask (*optional*) `torch.FloatTensor` of shape `(batch_size, sequence_length)`:
118
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
119
+
120
+ - 1 for tokens that are **not masked**,
121
+ - 0 for tokens that are **masked**.
122
+
123
+ [What are attention masks?](../glossary#attention-mask)
124
+ token_type_ids (*optional*) `torch.LongTensor` of shape `(batch_size, sequence_length)`:
125
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
126
+ 1]`:
127
+
128
+ - 0 corresponds to a *sentence A* token,
129
+ - 1 corresponds to a *sentence B* token.
130
+
131
+ [What are token type IDs?](../glossary#token-type-ids)
132
+ modal_token_type_ids (*optional*) `torch.LongTensor` of shape `(batch_size, modal_sequence_length)`:
133
+ Segment token indices to indicate different portions of the non-text modality. The embeddings from these
134
+ tokens will be summed with the respective token embeddings for the non-text modality.
135
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
136
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
137
+ config.max_position_embeddings - 1]`.
138
+
139
+ [What are position IDs?](../glossary#position-ids)
140
+ modal_position_ids (`torch.LongTensor` of shape `(batch_size, modal_sequence_length)`, *optional*):
141
+ Indices of positions of each input sequence tokens in the position embeddings for the non-text modality.
142
+ Selected in the range `[0, config.max_position_embeddings - 1]`.
143
+
144
+ [What are position IDs?](../glossary#position-ids)
145
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
146
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
147
+
148
+ - 1 indicates the head is **not masked**,
149
+ - 0 indicates the head is **masked**.
150
+
151
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, embedding_dim)`, *optional*):
152
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
153
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
154
+ model's internal embedding lookup matrix.
155
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
156
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
157
+ the model is configured as a decoder.
158
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
159
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
160
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
161
+
162
+ - 1 for tokens that are **not masked**,
163
+ - 0 for tokens that are **masked**.
164
+
165
+ output_attentions (`bool`, *optional*):
166
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
167
+ tensors for more detail.
168
+ output_hidden_states (`bool`, *optional*):
169
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
170
+ more detail.
171
+ return_dict (`bool`, *optional*):
172
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
173
+ """
174
+
175
+
176
+ @add_start_docstrings(
177
+ "The bare MMBT Model outputting raw hidden-states without any specific head on top.",
178
+ MMBT_START_DOCSTRING,
179
+ )
180
+ class MMBTModel(nn.Module, ModuleUtilsMixin):
181
+ def __init__(self, config, transformer, encoder):
182
+ super().__init__()
183
+ self.config = config
184
+ self.transformer = transformer
185
+ self.modal_encoder = ModalEmbeddings(config, encoder, transformer.embeddings)
186
+
187
+ @add_start_docstrings_to_model_forward(MMBT_INPUTS_DOCSTRING)
188
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
189
+ def forward(
190
+ self,
191
+ input_modal,
192
+ input_ids=None,
193
+ modal_start_tokens=None,
194
+ modal_end_tokens=None,
195
+ attention_mask=None,
196
+ token_type_ids=None,
197
+ modal_token_type_ids=None,
198
+ position_ids=None,
199
+ modal_position_ids=None,
200
+ head_mask=None,
201
+ inputs_embeds=None,
202
+ encoder_hidden_states=None,
203
+ encoder_attention_mask=None,
204
+ output_attentions=None,
205
+ output_hidden_states=None,
206
+ return_dict=None,
207
+ ):
208
+ r"""
209
+ Returns:
210
+
211
+ Examples:
212
+
213
+ ```python
214
+ # For example purposes. Not runnable.
215
+ transformer = BertModel.from_pretrained("google-bert/bert-base-uncased")
216
+ encoder = ImageEncoder(args)
217
+ mmbt = MMBTModel(config, transformer, encoder)
218
+ ```"""
219
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
220
+ output_hidden_states = (
221
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
222
+ )
223
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
224
+
225
+ if input_ids is not None and inputs_embeds is not None:
226
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
227
+ elif input_ids is not None:
228
+ input_txt_shape = input_ids.size()
229
+ elif inputs_embeds is not None:
230
+ input_txt_shape = inputs_embeds.size()[:-1]
231
+ else:
232
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
233
+
234
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
235
+
236
+ modal_embeddings = self.modal_encoder(
237
+ input_modal,
238
+ start_token=modal_start_tokens,
239
+ end_token=modal_end_tokens,
240
+ position_ids=modal_position_ids,
241
+ token_type_ids=modal_token_type_ids,
242
+ )
243
+
244
+ input_modal_shape = modal_embeddings.size()[:-1]
245
+
246
+ if token_type_ids is None:
247
+ token_type_ids = torch.ones(input_txt_shape, dtype=torch.long, device=device)
248
+
249
+ txt_embeddings = self.transformer.embeddings(
250
+ input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
251
+ )
252
+
253
+ embedding_output = torch.cat([modal_embeddings, txt_embeddings], 1)
254
+
255
+ input_shape = embedding_output.size()[:-1]
256
+
257
+ if attention_mask is None:
258
+ attention_mask = torch.ones(input_shape, device=device)
259
+ else:
260
+ attention_mask = torch.cat(
261
+ [torch.ones(input_modal_shape, device=device, dtype=torch.long), attention_mask], dim=1
262
+ )
263
+ if encoder_attention_mask is None:
264
+ encoder_attention_mask = torch.ones(input_shape, device=device)
265
+ else:
266
+ encoder_attention_mask = torch.cat(
267
+ [torch.ones(input_modal_shape, device=device), encoder_attention_mask], dim=1
268
+ )
269
+
270
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
271
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
272
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
273
+
274
+ encoder_outputs = self.transformer.encoder(
275
+ embedding_output,
276
+ attention_mask=extended_attention_mask,
277
+ head_mask=head_mask,
278
+ encoder_hidden_states=encoder_hidden_states,
279
+ encoder_attention_mask=encoder_extended_attention_mask,
280
+ output_attentions=output_attentions,
281
+ output_hidden_states=output_hidden_states,
282
+ return_dict=return_dict,
283
+ )
284
+
285
+ sequence_output = encoder_outputs[0]
286
+ pooled_output = self.transformer.pooler(sequence_output)
287
+
288
+ if not return_dict:
289
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
290
+
291
+ return BaseModelOutputWithPooling(
292
+ last_hidden_state=sequence_output,
293
+ pooler_output=pooled_output,
294
+ hidden_states=encoder_outputs.hidden_states,
295
+ attentions=encoder_outputs.attentions,
296
+ )
297
+
298
def get_input_embeddings(self):
    """Return the word-embedding module stored at `self.embeddings.word_embeddings`."""
    # NOTE(review): the forward pass visible in this file reaches the text embeddings through
    # `self.transformer.embeddings`; confirm that `self.embeddings` is actually set on this class.
    return self.embeddings.word_embeddings
300
+
301
def set_input_embeddings(self, value):
    """Replace the word-embedding module stored at `self.embeddings.word_embeddings` with `value`."""
    # NOTE(review): the forward pass visible in this file reaches the text embeddings through
    # `self.transformer.embeddings`; confirm that `self.embeddings` is actually set on this class.
    self.embeddings.word_embeddings = value
303
+
304
+
305
@add_start_docstrings(
    """
    MMBT Model with a sequence classification/regression head on top (a linear layer on top of the pooled output)
    """,
    MMBT_START_DOCSTRING,
    MMBT_INPUTS_DOCSTRING,
)
class MMBTForClassification(nn.Module):
    r"""
    **labels**: (*optional*) `torch.LongTensor` of shape `(batch_size,)`:
        Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
        config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
        `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

    Returns: *Tuple* comprising various elements depending on the configuration (config) and inputs: **loss**:
    (*optional*, returned when `labels` is provided) `torch.FloatTensor` of shape `(1,)`: Classification (or
    regression if config.num_labels==1) loss. **logits**:
        `torch.FloatTensor` of shape `(batch_size, config.num_labels)` Classification (or regression if
        config.num_labels==1) scores (before SoftMax).
    **hidden_states**: (*optional*, returned when `output_hidden_states=True`) list of `torch.FloatTensor` (one for
    the output of each layer + the output of the embeddings) of shape `(batch_size, sequence_length, hidden_size)`:
    Hidden-states of the model at the output of each layer plus the initial embedding outputs. **attentions**:
    (*optional*, returned when `output_attentions=True`) list of `torch.FloatTensor` (one for each layer) of shape
    `(batch_size, num_heads, sequence_length, sequence_length)`: Attentions weights after the attention softmax, used
    to compute the weighted average in the self-attention heads.

    Examples:

    ```python
    # For example purposes. Not runnable.
    transformer = BertModel.from_pretrained("google-bert/bert-base-uncased")
    encoder = ImageEncoder(args)
    model = MMBTForClassification(config, transformer, encoder)
    outputs = model(input_modal, input_ids, labels=labels)
    loss, logits = outputs[:2]
    ```"""

    def __init__(self, config, transformer, encoder):
        super().__init__()
        # `forward` reads `self.config.use_return_dict`. This class is a plain `nn.Module`, so the
        # configuration has to be stored explicitly or every forward call raises AttributeError.
        self.config = config
        self.num_labels = config.num_labels

        self.mmbt = MMBTModel(config, transformer, encoder)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(
        self,
        input_modal,
        input_ids=None,
        modal_start_tokens=None,
        modal_end_tokens=None,
        attention_mask=None,
        token_type_ids=None,
        modal_token_type_ids=None,
        position_ids=None,
        modal_position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        return_dict=None,
    ):
        # An explicit `return_dict` argument wins; otherwise fall back to the stored configuration.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.mmbt(
            input_modal=input_modal,
            input_ids=input_ids,
            modal_start_tokens=modal_start_tokens,
            modal_end_tokens=modal_end_tokens,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            modal_token_type_ids=modal_token_type_ids,
            position_ids=position_ids,
            modal_position_ids=modal_position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            return_dict=return_dict,
        )

        # Index 1 of the base model output is the pooled representation.
        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            # Tuple output: logits followed by whatever the base model returned after the pooled output,
            # with the loss prepended only when it was computed.
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
408
+
409
+
410
+ __all__ = ["MMBTForClassification", "MMBTModel", "ModalEmbeddings"]
docs/transformers/build/lib/transformers/models/deprecated/nat/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ....utils import _LazyModule
17
+ from ....utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_nat import *
22
+ from .modeling_nat import *
23
+ else:
24
+ import sys
25
+
26
+ _file = globals()["__file__"]
27
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/build/lib/transformers/models/deprecated/nat/configuration_nat.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Neighborhood Attention Transformer model configuration"""
16
+
17
+ from ....configuration_utils import PretrainedConfig
18
+ from ....utils import logging
19
+ from ....utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
class NatConfig(BackboneConfigMixin, PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`NatModel`]. It is used to instantiate a Nat model
    according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the Nat
    [shi-labs/nat-mini-in1k-224](https://huggingface.co/shi-labs/nat-mini-in1k-224) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        patch_size (`int`, *optional*, defaults to 4):
            The size (resolution) of each patch. NOTE: Only patch size of 4 is supported at the moment.
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        embed_dim (`int`, *optional*, defaults to 64):
            Dimensionality of patch embedding.
        depths (`List[int]`, *optional*, defaults to `[3, 4, 6, 5]`):
            Number of layers in each level of the encoder. The sequence is copied, so later changes to the
            caller's list do not affect the configuration.
        num_heads (`List[int]`, *optional*, defaults to `[2, 4, 8, 16]`):
            Number of attention heads in each layer of the Transformer encoder. The sequence is copied, so later
            changes to the caller's list do not affect the configuration.
        kernel_size (`int`, *optional*, defaults to 7):
            Neighborhood Attention kernel size.
        mlp_ratio (`float`, *optional*, defaults to 3.0):
            Ratio of MLP hidden dimensionality to embedding dimensionality.
        qkv_bias (`bool`, *optional*, defaults to `True`):
            Whether or not a learnable bias should be added to the queries, keys and values.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings and encoder.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        drop_path_rate (`float`, *optional*, defaults to 0.1):
            Stochastic depth rate.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`,
            `"selu"` and `"gelu_new"` are supported.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the layer normalization layers.
        layer_scale_init_value (`float`, *optional*, defaults to 0.0):
            The initial value for the layer scale. Disabled if <=0.
        out_features (`List[str]`, *optional*):
            If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
            (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
            corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
            same order as defined in the `stage_names` attribute.
        out_indices (`List[int]`, *optional*):
            If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
            many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
            If unset and `out_features` is unset, will default to the last stage. Must be in the
            same order as defined in the `stage_names` attribute.

    Example:

    ```python
    >>> from transformers import NatConfig, NatModel

    >>> # Initializing a Nat shi-labs/nat-mini-in1k-224 style configuration
    >>> configuration = NatConfig()

    >>> # Initializing a model (with random weights) from the shi-labs/nat-mini-in1k-224 style configuration
    >>> model = NatModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "nat"

    attribute_map = {
        "num_attention_heads": "num_heads",
        "num_hidden_layers": "num_layers",
    }

    def __init__(
        self,
        patch_size=4,
        num_channels=3,
        embed_dim=64,
        depths=[3, 4, 6, 5],
        num_heads=[2, 4, 8, 16],
        kernel_size=7,
        mlp_ratio=3.0,
        qkv_bias=True,
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        drop_path_rate=0.1,
        hidden_act="gelu",
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        layer_scale_init_value=0.0,
        out_features=None,
        out_indices=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.patch_size = patch_size
        self.num_channels = num_channels
        self.embed_dim = embed_dim
        # Copy the per-stage lists: the signature defaults are shared list objects, so storing them by
        # reference would let a mutation on one config leak into every later default-constructed config.
        self.depths = list(depths)
        self.num_layers = len(self.depths)
        self.num_heads = list(num_heads)
        self.kernel_size = kernel_size
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.drop_path_rate = drop_path_rate
        self.hidden_act = hidden_act
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
        # we set the hidden_size attribute in order to make Nat work with VisionEncoderDecoderModel
        # this indicates the channel dimension after the last stage of the model
        self.hidden_size = int(embed_dim * 2 ** (self.num_layers - 1))
        self.layer_scale_init_value = layer_scale_init_value
        self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, self.num_layers + 1)]
        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
        )
146
+
147
+
148
+ __all__ = ["NatConfig"]
docs/transformers/build/lib/transformers/models/deprecated/nat/modeling_nat.py ADDED
@@ -0,0 +1,953 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 SHI Labs and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch Neighborhood Attention Transformer model."""
16
+
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from torch import nn
24
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
25
+
26
+ from ....activations import ACT2FN
27
+ from ....modeling_outputs import BackboneOutput
28
+ from ....modeling_utils import PreTrainedModel
29
+ from ....pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
30
+ from ....utils import (
31
+ ModelOutput,
32
+ OptionalDependencyNotAvailable,
33
+ add_code_sample_docstrings,
34
+ add_start_docstrings,
35
+ add_start_docstrings_to_model_forward,
36
+ is_natten_available,
37
+ logging,
38
+ replace_return_docstrings,
39
+ requires_backends,
40
+ )
41
+ from ....utils.backbone_utils import BackboneMixin
42
+ from .configuration_nat import NatConfig
43
+
44
+
45
# Bind the two NATTEN functions used by this module. When `is_natten_available()` is false,
# same-named stand-ins are defined instead; each one raises OptionalDependencyNotAvailable
# as soon as it is called, whatever arguments it receives.
if is_natten_available():
    from natten.functional import natten2dav, natten2dqkrpb
else:

    def natten2dqkrpb(*args, **kwargs):
        raise OptionalDependencyNotAvailable()

    def natten2dav(*args, **kwargs):
        raise OptionalDependencyNotAvailable()
54
+
55
+
56
+ logger = logging.get_logger(__name__)
57
+
58
+ # General docstring
59
+ _CONFIG_FOR_DOC = "NatConfig"
60
+
61
+ # Base docstring
62
+ _CHECKPOINT_FOR_DOC = "shi-labs/nat-mini-in1k-224"
63
+ _EXPECTED_OUTPUT_SHAPE = [1, 7, 7, 512]
64
+
65
+ # Image classification docstring
66
+ _IMAGE_CLASS_CHECKPOINT = "shi-labs/nat-mini-in1k-224"
67
+ _IMAGE_CLASS_EXPECTED_OUTPUT = "tiger cat"
68
+
69
+
70
+ # drop_path and NatDropPath are from the timm library.
71
+
72
+
73
@dataclass
class NatEncoderOutput(ModelOutput):
    """
    Nat encoder's outputs, with potential hidden states and attentions.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    """

    # Every field defaults to None, so the output can be built with only the entries that were computed.
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
104
+
105
+
106
@dataclass
class NatModelOutput(ModelOutput):
    """
    Nat model's outputs that also contains a pooling of the last hidden states.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
            Average pooling of the last layer hidden-state.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    """

    # Every field defaults to None, so the output can be built with only the entries that were computed.
    last_hidden_state: Optional[torch.FloatTensor] = None
    pooler_output: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
140
+
141
+
142
@dataclass
class NatImageClassifierOutput(ModelOutput):
    """
    Nat outputs for image classification.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    """

    # Every field defaults to None, so the output can be built with only the entries that were computed.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
176
+
177
+
178
class NatEmbeddings(nn.Module):
    """
    Embedding stem: patch embeddings followed by layer normalization and dropout.
    """

    def __init__(self, config):
        super().__init__()

        self.patch_embeddings = NatPatchEmbeddings(config)

        self.norm = nn.LayerNorm(config.embed_dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor]:
        """Embed `pixel_values`, normalize the result and apply dropout."""
        normalized = self.norm(self.patch_embeddings(pixel_values))
        return self.dropout(normalized)
198
+
199
+
200
class NatPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, height, width, hidden_size)` to be consumed by a
    Transformer.

    The projection is two stacked 3x3 convolutions with stride 2, so each spatial dimension of the input is
    reduced twice before the channel axis is moved to the end.

    Raises:
        ValueError: at construction time if `config.patch_size` is not 4, and at call time if the channel
            dimension of `pixel_values` differs from `config.num_channels`.
    """

    def __init__(self, config):
        super().__init__()
        patch_size = config.patch_size
        num_channels, hidden_size = config.num_channels, config.embed_dim
        self.num_channels = num_channels

        if patch_size != 4:
            # TODO: Support arbitrary patch sizes.
            raise ValueError("Nat only supports patch size of 4 at the moment.")

        self.projection = nn.Sequential(
            nn.Conv2d(self.num_channels, hidden_size // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.Conv2d(hidden_size // 2, hidden_size, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
        )

    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> torch.Tensor:
        # Unpacking four values keeps the requirement that the input is a 4-D tensor.
        _, num_channels, _, _ = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        embeddings = self.projection(pixel_values)
        # Move channels last: (batch, channels, height, width) -> (batch, height, width, channels).
        embeddings = embeddings.permute(0, 2, 3, 1)

        return embeddings
234
+
235
+
236
class NatDownsampler(nn.Module):
    """
    Convolutional downsampling layer: a bias-free 3x3 convolution with stride 2 that doubles the channel
    count, followed by a normalization layer.

    Args:
        dim (`int`):
            Number of input channels.
        norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
            Normalization layer class.
    """

    def __init__(self, dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
        super().__init__()
        self.dim = dim
        out_channels = 2 * dim
        self.reduction = nn.Conv2d(dim, out_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        self.norm = norm_layer(out_channels)

    def forward(self, input_feature: torch.Tensor) -> torch.Tensor:
        # The convolution runs channels-first; the input and output of this layer are channels-last.
        channels_first = input_feature.permute(0, 3, 1, 2)
        reduced = self.reduction(channels_first)
        channels_last = reduced.permute(0, 2, 3, 1)
        return self.norm(channels_last)
257
+
258
+
259
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Stochastic depth: zero out whole samples of `input` at random and rescale the survivors.

    The input is returned untouched when `drop_prob` is 0 or `training` is false. Otherwise one random
    keep/drop decision is drawn per entry of the first dimension, and kept entries are divided by the keep
    probability `1 - drop_prob`.

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    """
    if drop_prob == 0.0 or not training:
        return input

    survival_rate = 1 - drop_prob
    # One mask value per sample, broadcast over every remaining dimension.
    mask_shape = (input.shape[0],) + (1,) * (input.ndim - 1)
    mask = survival_rate + torch.rand(mask_shape, dtype=input.dtype, device=input.device)
    mask.floor_()  # values in [survival_rate, 1 + survival_rate) become 0 or 1
    return input.div(survival_rate) * mask
277
+
278
+
279
class NatDropPath(nn.Module):
    """Module wrapper around `drop_path` (Stochastic Depth) that follows the module's training flag."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # `self.training` is toggled by `train()` / `eval()`.
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return f"p={self.drop_prob}"
291
+
292
+
293
class NeighborhoodAttention(nn.Module):
    """
    Multi-head neighborhood attention over a channels-last feature map, with learnable relative positional
    biases (`rpb`).

    Raises:
        ValueError: if `dim` is not a multiple of `num_heads`.
    """

    def __init__(self, config, dim, num_heads, kernel_size):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(
                f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
            )

        self.num_attention_heads = num_heads
        self.attention_head_size = int(dim / num_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.kernel_size = kernel_size

        # rpb is learnable relative positional biases; same concept is used Swin.
        bias_extent = 2 * self.kernel_size - 1
        self.rpb = nn.Parameter(torch.zeros(num_heads, bias_extent, bias_extent))

        self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Split the last dimension into (heads, head_size) and move the head axis to position 1."""
        per_head = x.view(*x.size()[:-1], self.num_attention_heads, self.attention_head_size)
        return per_head.permute(0, 3, 1, 2, 4)

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        queries = self.transpose_for_scores(self.query(hidden_states))
        keys = self.transpose_for_scores(self.key(hidden_states))
        values = self.transpose_for_scores(self.value(hidden_states))

        # Scale the queries rather than the attention weights: the weights are typically the bigger tensor,
        # and the result is identical because the scalar commutes with the matrix product.
        queries = queries / math.sqrt(self.attention_head_size)

        # Raw neighborhood-attention scores between queries and keys, with the relative positional biases added.
        # NOTE(review): the trailing `1` is presumably the dilation; confirm against the natten API.
        scores = natten2dqkrpb(queries, keys, self.rpb, self.kernel_size, 1)

        # Softmax turns scores into probabilities. The dropout drops entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        probs = self.dropout(nn.functional.softmax(scores, dim=-1))

        context = natten2dav(probs, values, self.kernel_size, 1)
        # Move the head axis back next to head_size and merge the two into one channel dimension.
        context = context.permute(0, 2, 3, 1, 4).contiguous()
        context = context.view(context.size()[:-2] + (self.all_head_size,))

        if output_attentions:
            return (context, probs)
        return (context,)
352
+
353
+
354
class NeighborhoodAttentionOutput(nn.Module):
    """Output projection of the attention block: a square linear layer followed by dropout."""

    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # `input_tensor` is accepted but not used here.
        projected = self.dense(hidden_states)
        return self.dropout(projected)
365
+
366
+
367
class NeighborhoodAttentionModule(nn.Module):
    """Couples `NeighborhoodAttention` with its output projection and tracks which heads were pruned."""

    def __init__(self, config, dim, num_heads, kernel_size):
        super().__init__()
        self.self = NeighborhoodAttention(config, dim, num_heads, kernel_size)
        self.output = NeighborhoodAttentionOutput(config, dim)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads from the query/key/value and output projections."""
        if len(heads) == 0:
            return

        attention = self.self
        heads, index = find_pruneable_heads_and_indices(
            heads, attention.num_attention_heads, attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        attention.query = prune_linear_layer(attention.query, index)
        attention.key = prune_linear_layer(attention.key, index)
        attention.value = prune_linear_layer(attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        attention.num_attention_heads = attention.num_attention_heads - len(heads)
        attention.all_head_size = attention.attention_head_size * attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        attention_results = self.self(hidden_states, output_attentions)
        projected = self.output(attention_results[0], hidden_states)
        # Anything after the first element (the attention probabilities, when requested) is passed through.
        return (projected,) + attention_results[1:]
401
+
402
+
403
class NatIntermediate(nn.Module):
    """Expanding half of the feed-forward block: `dim -> int(config.mlp_ratio * dim)` plus activation."""

    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
        activation = config.hidden_act
        # A string is looked up in ACT2FN; anything else is used as the activation directly.
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        expanded = self.dense(hidden_states)
        return self.intermediate_act_fn(expanded)
416
+
417
+
418
class NatOutput(nn.Module):
    """Contracting half of the feed-forward block: `int(config.mlp_ratio * dim) -> dim` plus dropout."""

    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        contracted = self.dense(hidden_states)
        return self.dropout(contracted)
428
+
429
+
430
class NatLayer(nn.Module):
    """A single NAT block: layer norm + neighborhood attention, then layer norm + MLP, each with a residual.

    Each residual branch goes through `drop_path` and, when `config.layer_scale_init_value > 0`, is multiplied
    by one row of the learnable `layer_scale_parameters` (shape `(2, dim)`).

    Args:
        config: Model configuration. This class reads `chunk_size_feed_forward`, `kernel_size`,
            `layer_norm_eps` and `layer_scale_init_value`; the sub-modules read further fields.
        dim: Size of the last dimension of the hidden states.
        num_heads: Number of attention heads passed to `NeighborhoodAttentionModule`.
        drop_path_rate: Rate passed to `NatDropPath`; values `<= 0.0` select `nn.Identity`.
    """

    def __init__(self, config, dim, num_heads, drop_path_rate=0.0):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.kernel_size = config.kernel_size
        self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.attention = NeighborhoodAttentionModule(config, dim, num_heads, kernel_size=self.kernel_size)
        self.drop_path = NatDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
        self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.intermediate = NatIntermediate(config, dim)
        self.output = NatOutput(config, dim)
        self.layer_scale_parameters = (
            nn.Parameter(config.layer_scale_init_value * torch.ones((2, dim)), requires_grad=True)
            if config.layer_scale_init_value > 0
            else None
        )

    def maybe_pad(self, hidden_states, height, width):
        """Pad `hidden_states` on the bottom/right so both spatial sizes are at least `kernel_size`.

        Returns:
            A tuple `(hidden_states, pad_values)` where `pad_values` is the 6-tuple given to
            `nn.functional.pad` (all zeros when no padding was needed).
        """
        window_size = self.kernel_size
        pad_values = (0, 0, 0, 0, 0, 0)
        if height < window_size or width < window_size:
            pad_right = max(0, window_size - width)
            pad_bottom = max(0, window_size - height)
            # `nn.functional.pad` consumes pairs starting from the last dimension:
            # (last dim: untouched, width: right only, height: bottom only).
            pad_values = (0, 0, 0, pad_right, 0, pad_bottom)
            hidden_states = nn.functional.pad(hidden_states, pad_values)
        return hidden_states, pad_values

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Only the two middle sizes are needed later, to crop any padding away again.
        _, height, width, _ = hidden_states.size()
        shortcut = hidden_states

        hidden_states = self.layernorm_before(hidden_states)
        # pad hidden_states if they are smaller than kernel size
        hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)

        attention_outputs = self.attention(hidden_states, output_attentions=output_attentions)
        attention_output = attention_outputs[0]

        was_padded = pad_values[3] > 0 or pad_values[5] > 0
        if was_padded:
            # Crop back to the unpadded size so the residual addition below lines up.
            attention_output = attention_output[:, :height, :width, :].contiguous()

        if self.layer_scale_parameters is not None:
            attention_output = self.layer_scale_parameters[0] * attention_output

        hidden_states = shortcut + self.drop_path(attention_output)

        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.output(self.intermediate(layer_output))

        if self.layer_scale_parameters is not None:
            layer_output = self.layer_scale_parameters[1] * layer_output

        layer_output = hidden_states + self.drop_path(layer_output)

        layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
        return layer_outputs
495
+
496
+
497
class NatStage(nn.Module):
    """One level of the encoder: `depth` `NatLayer` blocks followed by an optional downsampling module.

    Args:
        config: Model configuration, forwarded to every `NatLayer`.
        dim: Channel size used by the layers of this stage.
        depth: Number of `NatLayer` blocks.
        num_heads: Number of attention heads for every block of this stage.
        drop_path_rate: Indexable of per-block drop-path rates (`drop_path_rate[i]` is given to block `i`).
        downsample: Callable invoked as `downsample(dim=dim, norm_layer=nn.LayerNorm)` to build the
            downsampling module, or `None` for no downsampling.
    """

    def __init__(self, config, dim, depth, num_heads, drop_path_rate, downsample):
        super().__init__()
        self.config = config
        self.dim = dim
        self.layers = nn.ModuleList(
            [
                NatLayer(
                    config=config,
                    dim=dim,
                    num_heads=num_heads,
                    drop_path_rate=drop_path_rate[i],
                )
                for i in range(depth)
            ]
        )

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=nn.LayerNorm)
        else:
            self.downsample = None

        self.pointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Run all blocks, then downsample.

        Returns:
            `(hidden_states, hidden_states_before_downsampling)`, followed by the extra outputs of the
            last block (its attentions) when `output_attentions` is truthy.
        """
        # Seed with a 1-tuple so a stage without blocks still has well-defined (empty) extra outputs.
        layer_outputs = (hidden_states,)
        for layer_module in self.layers:
            layer_outputs = layer_module(hidden_states, output_attentions)
            hidden_states = layer_outputs[0]

        hidden_states_before_downsampling = hidden_states
        if self.downsample is not None:
            hidden_states = self.downsample(hidden_states_before_downsampling)

        stage_outputs = (hidden_states, hidden_states_before_downsampling)

        if output_attentions:
            # Only the outputs of the final block are appended.
            stage_outputs += layer_outputs[1:]
        return stage_outputs
541
+
542
+
543
class NatEncoder(nn.Module):
    """Sequence of `NatStage` levels; level `i` works on `config.embed_dim * 2**i` channels."""

    def __init__(self, config):
        super().__init__()
        self.num_levels = len(config.depths)
        self.config = config

        depths = config.depths
        # One drop-path rate per block, evenly spaced between 0 and `config.drop_path_rate`.
        dpr = [rate.item() for rate in torch.linspace(0, config.drop_path_rate, sum(depths), device="cpu")]

        stages = []
        for level_idx in range(self.num_levels):
            first_block = sum(depths[:level_idx])
            last_block = sum(depths[: level_idx + 1])
            is_last_level = level_idx >= self.num_levels - 1
            stages.append(
                NatStage(
                    config=config,
                    dim=int(config.embed_dim * 2**level_idx),
                    depth=depths[level_idx],
                    num_heads=config.num_heads[level_idx],
                    drop_path_rate=dpr[first_block:last_block],
                    # Every level except the last one is followed by a downsampler.
                    downsample=None if is_last_level else NatDownsampler,
                )
            )
        self.levels = nn.ModuleList(stages)

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        output_hidden_states_before_downsampling: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, NatEncoderOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_reshaped_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if output_hidden_states:
            all_hidden_states += (hidden_states,)
            # rearrange b h w c -> b c h w
            all_reshaped_hidden_states += (hidden_states.permute(0, 3, 1, 2),)

        for stage in self.levels:
            stage_outputs = stage(hidden_states, output_attentions)
            hidden_states, hidden_states_before_downsampling = stage_outputs[0], stage_outputs[1]

            if output_hidden_states:
                # Record either the pre- or the post-downsampling state of this level.
                if output_hidden_states_before_downsampling:
                    recorded = hidden_states_before_downsampling
                else:
                    recorded = hidden_states
                all_hidden_states += (recorded,)
                # rearrange b h w c -> b c h w
                all_reshaped_hidden_states += (recorded.permute(0, 3, 1, 2),)

            if output_attentions:
                all_self_attentions += stage_outputs[2:]

        if return_dict:
            return NatEncoderOutput(
                last_hidden_state=hidden_states,
                hidden_states=all_hidden_states,
                attentions=all_self_attentions,
                reshaped_hidden_states=all_reshaped_hidden_states,
            )

        return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
610
+
611
+
612
class NatPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = NatConfig
    base_model_prefix = "nat"
    main_input_name = "pixel_values"

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return

        if not isinstance(module, (nn.Linear, nn.Conv2d)):
            return

        # Slightly different from the TF version which uses truncated_normal for initialization
        # cf https://github.com/pytorch/pytorch/pull/5617
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if module.bias is not None:
            module.bias.data.zero_()
633
+
634
+
635
# Shared docstring fragment passed to `add_start_docstrings` on the Nat model classes in this file.
NAT_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`NatConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


# Shared description of the `forward` arguments, passed to `add_start_docstrings_to_model_forward`.
NAT_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
            for details.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
662
+
663
+
664
@add_start_docstrings(
    "The bare Nat Model transformer outputting raw hidden-states without any specific head on top.",
    NAT_START_DOCSTRING,
)
class NatModel(NatPreTrainedModel):
    # Embeddings -> NatEncoder -> final LayerNorm, with an optional average-pooling step on top.
    # (Kept as a comment: the decorator above prepends text to the class docstring.)

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)

        requires_backends(self, ["natten"])

        self.config = config
        self.num_levels = len(config.depths)
        # Channel count after the last level: the embedding size doubles once per additional level.
        self.num_features = int(config.embed_dim * 2 ** (self.num_levels - 1))

        self.embeddings = NatEmbeddings(config)
        self.encoder = NatEncoder(config)

        self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
        self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel

        `layer_num` indexes the `NatLayer` blocks in order across all encoder levels (level 0 first).
        """
        # `NatEncoder` exposes its stages as `levels`, each holding its blocks in `layers`; it has no
        # `layer` attribute, so the blocks are flattened here before indexing.
        # NOTE(review): `prune_heads` only prunes the linear projections; confirm whether the attention's
        # `rpb` bias also needs pruning for end-to-end support.
        all_layers = [layer for level in self.encoder.levels for layer in level.layers]
        for layer, heads in heads_to_prune.items():
            all_layers[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(NAT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=NatModelOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, NatModelOutput]:
        # Arguments left as None fall back to the corresponding config values.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedding_output = self.embeddings(pixel_values)

        encoder_outputs = self.encoder(
            embedding_output,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)

        pooled_output = None
        if self.pooler is not None:
            # Merge dims 1 and 2, move that merged axis last, average over it, then flatten to 2-D.
            pooled_output = self.pooler(sequence_output.flatten(1, 2).transpose(1, 2))
            pooled_output = torch.flatten(pooled_output, 1)

        if not return_dict:
            output = (sequence_output, pooled_output) + encoder_outputs[1:]

            return output

        return NatModelOutput(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
        )
751
+
752
+
753
@add_start_docstrings(
    """
    Nat Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
    the [CLS] token) e.g. for ImageNet.
    """,
    NAT_START_DOCSTRING,
)
class NatForImageClassification(NatPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        requires_backends(self, ["natten"])

        self.num_labels = config.num_labels
        self.nat = NatModel(config)

        # Classifier head
        if config.num_labels > 0:
            self.classifier = nn.Linear(self.nat.num_features, config.num_labels)
        else:
            self.classifier = nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(NAT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=NatImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, NatImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.nat(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 1 of the base model output is the pooled output.
        logits = self.classifier(outputs[1])

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                # Infer the problem type once from the label count and label dtype, and store it on the config.
                if self.num_labels == 1:
                    inferred_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    inferred_type = "single_label_classification"
                else:
                    inferred_type = "multi_label_classification"
                self.config.problem_type = inferred_type

            problem_type = self.config.problem_type
            if problem_type == "regression":
                criterion = MSELoss()
                if self.num_labels == 1:
                    loss = criterion(logits.squeeze(), labels.squeeze())
                else:
                    loss = criterion(logits, labels)
            elif problem_type == "single_label_classification":
                criterion = CrossEntropyLoss()
                loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                criterion = BCEWithLogitsLoss()
                loss = criterion(logits, labels)

        if return_dict:
            return NatImageClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
                reshaped_hidden_states=outputs.reshaped_hidden_states,
            )

        output = (logits,) + outputs[2:]
        return ((loss,) + output) if loss is not None else output
845
+
846
+
847
@add_start_docstrings(
    "NAT backbone, to be used with frameworks like DETR and MaskFormer.",
    NAT_START_DOCSTRING,
)
class NatBackbone(NatPreTrainedModel, BackboneMixin):
    def __init__(self, config):
        super().__init__(config)
        super()._init_backbone(config)

        requires_backends(self, ["natten"])

        self.embeddings = NatEmbeddings(config)
        self.encoder = NatEncoder(config)
        level_channels = [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
        self.num_features = [config.embed_dim] + level_channels

        # Add layer norms to hidden states of out_features
        self.hidden_states_norms = nn.ModuleDict(
            {stage: nn.LayerNorm(num_channels) for stage, num_channels in zip(self.out_features, self.channels)}
        )

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    @add_start_docstrings_to_model_forward(NAT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> BackboneOutput:
        """
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("shi-labs/nat-mini-in1k-224")
        >>> model = AutoBackbone.from_pretrained(
        ...     "shi-labs/nat-mini-in1k-224", out_features=["stage1", "stage2", "stage3", "stage4"]
        ... )

        >>> inputs = processor(image, return_tensors="pt")

        >>> outputs = model(**inputs)

        >>> feature_maps = outputs.feature_maps
        >>> list(feature_maps[-1].shape)
        [1, 512, 7, 7]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        embedding_output = self.embeddings(pixel_values)

        # Hidden states are always requested from the encoder because the feature maps are built from them.
        outputs = self.encoder(
            embedding_output,
            output_attentions=output_attentions,
            output_hidden_states=True,
            output_hidden_states_before_downsampling=True,
            return_dict=True,
        )

        feature_maps = ()
        for stage, hidden_state in zip(self.stage_names, outputs.reshaped_hidden_states):
            if stage not in self.out_features:
                continue
            # LayerNorm normalizes the last dimension only, so go channels-last, normalize, and come
            # back to channels-first; no flattening of the spatial dims is required.
            channels_last = hidden_state.permute(0, 2, 3, 1).contiguous()
            normalized = self.hidden_states_norms[stage](channels_last)
            feature_maps += (normalized.permute(0, 3, 1, 2).contiguous(),)

        if return_dict:
            return BackboneOutput(
                feature_maps=feature_maps,
                hidden_states=outputs.hidden_states if output_hidden_states else None,
                attentions=outputs.attentions,
            )

        output = (feature_maps,)
        if output_hidden_states:
            output += (outputs.hidden_states,)
        return output
951
+
952
+
953
# Names exported by `from <this module> import *`.
__all__ = ["NatForImageClassification", "NatModel", "NatPreTrainedModel", "NatBackbone"]
docs/transformers/build/lib/transformers/models/deprecated/nezha/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ....utils import _LazyModule
17
+ from ....utils.import_utils import define_import_structure
18
+
19
+
20
if TYPE_CHECKING:
    # Only evaluated by static type checkers: expose the real symbols of both submodules.
    from .configuration_nezha import *
    from .modeling_nezha import *
else:
    import sys

    # At runtime, replace this package's entry in `sys.modules` with a `_LazyModule` built from the
    # import structure that `define_import_structure` derives from this file's path.
    # (Presumably this defers importing the submodules until an attribute is accessed - see `_LazyModule`.)
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/build/lib/transformers/models/deprecated/nezha/configuration_nezha.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .... import PretrainedConfig
2
+
3
+
4
class NezhaConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of an [`NezhaModel`]. It is used to instantiate an Nezha
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the Nezha
    [sijunhe/nezha-cn-base](https://huggingface.co/sijunhe/nezha-cn-base) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, optional, defaults to 21128):
            Vocabulary size of the NEZHA model. Defines the different tokens that can be represented by the
            *inputs_ids* passed to the forward method of [`NezhaModel`].
        hidden_size (`int`, optional, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, optional, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, optional, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, optional, defaults to 3072):
            The dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `function`, optional, defaults to "gelu"):
            The non-linear activation function (function or string) in the encoder and pooler.
        hidden_dropout_prob (`float`, optional, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, optional, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (`int`, optional, defaults to 512):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            (e.g., 512 or 1024 or 2048).
        max_relative_position (`int`, optional, defaults to 64):
            Value stored on the configuration as `max_relative_position`.
        type_vocab_size (`int`, optional, defaults to 2):
            The vocabulary size of the *token_type_ids* passed into [`NezhaModel`].
        initializer_range (`float`, optional, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, optional, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        classifier_dropout (`float`, optional, defaults to 0.1):
            The dropout ratio for attached classifiers.
        pad_token_id (`int`, optional, defaults to 0):
            Padding token id, forwarded to [`PretrainedConfig`].
        bos_token_id (`int`, optional, defaults to 2):
            Beginning-of-sequence token id, forwarded to [`PretrainedConfig`].
        eos_token_id (`int`, optional, defaults to 3):
            End-of-sequence token id, forwarded to [`PretrainedConfig`].
        use_cache (`bool`, optional, defaults to `True`):
            Value stored on the configuration as `use_cache`.
        is_decoder (`bool`, *optional*, defaults to `False`):
            Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.

    Example:

    ```python
    >>> from transformers import NezhaConfig, NezhaModel

    >>> # Initializing an Nezha configuration
    >>> configuration = NezhaConfig()

    >>> # Initializing a model (with random weights) from the Nezha-base style configuration model
    >>> model = NezhaModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "nezha"

    def __init__(
        self,
        vocab_size=21128,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        max_relative_position=64,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        classifier_dropout=0.1,
        pad_token_id=0,
        bos_token_id=2,
        eos_token_id=3,
        use_cache=True,
        **kwargs,
    ):
        # The special token ids and any extra keyword arguments are handled by the base class.
        # NOTE(review): `is_decoder` is documented above but is not an explicit parameter; it can only
        # arrive through **kwargs - confirm that PretrainedConfig consumes it.
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        # Presumably the clipping distance of the relative position encoding - verify against the modeling code.
        self.max_relative_position = max_relative_position
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.classifier_dropout = classifier_dropout
        self.use_cache = use_cache
103
+
104
+
105
# Names exported by `from <this module> import *`.
__all__ = ["NezhaConfig"]
docs/transformers/build/lib/transformers/models/deprecated/nezha/modeling_nezha.py ADDED
@@ -0,0 +1,1697 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch Nezha model."""
16
+
17
+ import math
18
+ import os
19
+ import warnings
20
+ from dataclasses import dataclass
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
+
28
+ from ....activations import ACT2FN
29
+ from ....modeling_outputs import (
30
+ BaseModelOutputWithPastAndCrossAttentions,
31
+ BaseModelOutputWithPoolingAndCrossAttentions,
32
+ MaskedLMOutput,
33
+ MultipleChoiceModelOutput,
34
+ NextSentencePredictorOutput,
35
+ QuestionAnsweringModelOutput,
36
+ SequenceClassifierOutput,
37
+ TokenClassifierOutput,
38
+ )
39
+ from ....modeling_utils import PreTrainedModel
40
+ from ....pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
41
+ from ....utils import (
42
+ ModelOutput,
43
+ add_code_sample_docstrings,
44
+ add_start_docstrings,
45
+ add_start_docstrings_to_model_forward,
46
+ logging,
47
+ replace_return_docstrings,
48
+ )
49
+ from .configuration_nezha import NezhaConfig
50
+
51
+
52
# Module-level logger obtained from the library's own logging helper.
logger = logging.get_logger(__name__)

# Checkpoint and config identifiers for documentation purposes.
# NOTE(review): presumably consumed by the docstring decorators imported above; their use is outside this view.
_CHECKPOINT_FOR_DOC = "sijunhe/nezha-cn-base"
_CONFIG_FOR_DOC = "NezhaConfig"
56
+
57
+
58
def load_tf_weights_in_nezha(model, config, tf_checkpoint_path):
    """Load the variables of a TensorFlow checkpoint into a PyTorch model.

    Each TF variable name is split on "/" and used as a path of attribute lookups starting at `model`;
    the array read from the checkpoint then replaces the `.data` of the object reached.

    Args:
        model: PyTorch module that receives the weights.
        config: Not used by this function; kept so the signature callers rely on is unchanged.
        tf_checkpoint_path: Path of the TensorFlow checkpoint.

    Returns:
        `model`, updated in place.

    Raises:
        ImportError: If `re`, `numpy` or `tensorflow` cannot be imported (logged, then re-raised).
        ValueError: If the shape of the resolved parameter differs from the shape of the TF array.
    """
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")

    # Read every variable first so that all "Loading ..." messages precede the assignment phase.
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
    # which are not required for using pretrained model
    optimizer_scopes = frozenset(
        ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
    )

    for name, array in zip(names, arrays):
        name = name.split("/")
        if any(n in optimizer_scopes for n in name):
            logger.info(f"Skipping {'/'.join(name)}")
            continue
        pointer = model
        for m_name in name:
            # "layer_3" -> ["layer", "3", ""]: attribute name followed by an index into it.
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
            scope = scope_names[0]
            if scope in ("kernel", "gamma", "output_weights"):
                pointer = pointer.weight
            elif scope in ("output_bias", "beta"):
                pointer = pointer.bias
            elif scope == "squad":
                pointer = pointer.classifier
            else:
                try:
                    pointer = getattr(pointer, scope)
                except AttributeError:
                    logger.info(f"Skipping {'/'.join(name)}")
                    # NOTE: this only skips the current path component; the walk goes on from the
                    # same `pointer` with the next component.
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        if m_name[-11:] == "_embeddings":
            pointer = pointer.weight
        elif m_name == "kernel":
            # The kernel is transposed before being copied into the PyTorch weight.
            array = np.transpose(array)
        # The previous `try/except AssertionError` wrapper was dead code: only ValueError is raised here.
        if pointer.shape != array.shape:
            raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
        logger.info(f"Initialize PyTorch weight {name}")
        pointer.data = torch.from_numpy(array)
    return model
129
+
130
+
131
class NezhaRelativePositionsEncoding(nn.Module):
    """Fixed sinusoidal encoding of clipped relative positions.

    A table with one sinusoidal row per clipped relative distance is built once and expanded into a
    `(length, length, depth)` tensor whose entry `[i, j]` encodes the distance `j - i`. The tensor is
    stored as a non-persistent buffer, so it is not written to the state dict.

    Args:
        length: Maximum sequence length the precomputed tensor covers.
        depth: Size of each encoding vector (odd values are supported).
        max_relative_position: Distances are clamped to `[-max_relative_position, max_relative_position]`.
    """

    def __init__(self, length, depth, max_relative_position=127):
        super().__init__()
        vocab_size = max_relative_position * 2 + 1
        range_vec = torch.arange(length)
        # distance_mat[i, j] = j - i
        distance_mat = range_vec.unsqueeze(0) - range_vec.unsqueeze(1)
        distance_mat_clipped = torch.clamp(distance_mat, -max_relative_position, max_relative_position)
        # Shift into [0, vocab_size) so the clipped distances can index the table.
        final_mat = distance_mat_clipped + max_relative_position

        embeddings_table = torch.zeros(vocab_size, depth)
        position = torch.arange(0, vocab_size, dtype=torch.int64).float().unsqueeze(1)
        div_term = torch.exp(torch.arange(0, depth, 2).float() * (-math.log(10000.0) / depth))
        angles = position * div_term
        embeddings_table[:, 0::2] = torch.sin(angles)
        # For odd `depth` there is one cosine column fewer than sine columns; trim so shapes match.
        embeddings_table[:, 1::2] = torch.cos(angles[:, : depth // 2])

        # Row lookup by index selects the same rows as multiplying a one-hot matrix with the table,
        # without materialising a (length * length, vocab_size) float tensor.
        positions_encoding = embeddings_table[final_mat]
        self.register_buffer("positions_encoding", positions_encoding, persistent=False)

    def forward(self, length):
        """Return the top-left `(length, length, depth)` slice of the precomputed encoding."""
        return self.positions_encoding[:length, :length, :]
161
+
162
+
163
class NezhaEmbeddings(nn.Module):
    """Combine word and token-type embeddings, then apply LayerNorm and dropout."""

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        self.word_embeddings = nn.Embedding(config.vocab_size, hidden, padding_idx=config.pad_token_id)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, hidden)

        # The attribute is deliberately named `LayerNorm` (not snake_case) so that it lines up with the
        # TensorFlow variable name when loading TensorFlow checkpoints.
        self.LayerNorm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # All-zero default segment ids, kept out of the state dict.
        default_segments = torch.zeros((1, config.max_position_embeddings), dtype=torch.long)
        self.register_buffer("token_type_ids", default_segments, persistent=False)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        """Embed `input_ids` (or take `inputs_embeds` as given) and add the token-type embeddings."""
        input_shape = input_ids.size() if input_ids is not None else inputs_embeds.size()[:-1]
        seq_length = input_shape[1]

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        # Without explicit segment ids, fall back to the all-zero buffer registered in the constructor
        # (this helps tracing the model without passing token_type_ids, see issue #5664).
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                token_type_ids = self.token_type_ids[:, :seq_length].expand(input_shape[0], seq_length)
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=inputs_embeds.device)

        summed = inputs_embeds + self.token_type_embeddings(token_type_ids)
        return self.dropout(self.LayerNorm(summed))
212
+
213
+
214
class NezhaSelfAttention(nn.Module):
    """Multi-head attention whose scores and context both receive a relative-position term."""

    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.relative_positions_encoding = NezhaRelativePositionsEncoding(
            length=config.max_position_embeddings,
            depth=self.attention_head_size,
            max_relative_position=config.max_relative_position,
        )
        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Split the last dimension into (heads, head_size) and move the heads axis to position 1."""
        split = x.view(*x.size()[:-1], self.num_attention_heads, self.attention_head_size)
        return split.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Return `(context,)`, plus the attention probabilities when `output_attentions` is set and,
        for decoders, the `(key, value)` pair as the last element."""
        query_layer = self.transpose_for_scores(self.query(hidden_states))

        if encoder_hidden_states is not None:
            # Cross-attention: keys/values come from the encoder, and the encoder mask replaces the
            # self-attention mask. Cached keys/values are reused as-is when present.
            attention_mask = encoder_attention_mask
            if past_key_value is not None:
                key_layer, value_layer = past_key_value[0], past_key_value[1]
            else:
                key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
                value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            if past_key_value is not None:
                # Prepend the cached keys/values along the sequence axis.
                key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
                value_layer = torch.cat([past_key_value[1], value_layer], dim=2)

        if self.is_decoder:
            # Decoders hand the keys/values used in this call back to the caller as the new cache.
            past_key_value = (key_layer, value_layer)

        # Content term: dot product between queries and keys.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        batch_size, num_attention_heads, from_seq_length, to_seq_length = attention_scores.size()

        # Position term for the scores: queries against the relative-position encodings.
        relations_keys = self.relative_positions_encoding(to_seq_length)
        queries_by_position = (
            query_layer.permute(2, 0, 1, 3)
            .contiguous()
            .view(from_seq_length, batch_size * num_attention_heads, self.attention_head_size)
        )
        key_position_scores = torch.matmul(queries_by_position, relations_keys.permute(0, 2, 1))
        key_position_scores = key_position_scores.view(
            from_seq_length, batch_size, num_attention_heads, from_seq_length
        ).permute(1, 2, 0, 3)

        attention_scores = (attention_scores + key_position_scores) / math.sqrt(self.attention_head_size)

        if attention_mask is not None:
            # The mask is added to the scores before the softmax.
            attention_scores = attention_scores + attention_mask

        # Scores -> probabilities, then dropout on the probabilities.
        attention_probs = self.dropout(nn.functional.softmax(attention_scores, dim=-1))

        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        # Content term of the context.
        context_layer = torch.matmul(attention_probs, value_layer)

        # Position term of the context: probabilities against the relative-position encodings.
        relations_values = self.relative_positions_encoding(to_seq_length)
        probs_by_position = (
            attention_probs.permute(2, 0, 1, 3)
            .contiguous()
            .view(from_seq_length, batch_size * num_attention_heads, to_seq_length)
        )
        value_position_scores = torch.matmul(probs_by_position, relations_values)
        value_position_scores = value_position_scores.view(
            from_seq_length, batch_size, num_attention_heads, self.attention_head_size
        ).permute(1, 2, 0, 3)
        context_layer = context_layer + value_position_scores

        # Merge the heads back into a single hidden dimension.
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        context_layer = context_layer.view(context_layer.size()[:-2] + (self.all_head_size,))

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs
347
+
348
+
349
class NezhaSelfOutput(nn.Module):
    """Project the attention output, apply dropout, add the residual input and layer-normalise."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Return `LayerNorm(dropout(dense(hidden_states)) + input_tensor)`."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
361
+
362
+
363
class NezhaAttention(nn.Module):
    """Pair a `NezhaSelfAttention` with its `NezhaSelfOutput` and support head pruning."""

    def __init__(self, config):
        super().__init__()
        self.self = NezhaSelfAttention(config)
        self.output = NezhaSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given heads from the query/key/value projections and the output projection."""
        if len(heads) == 0:
            return
        attn = self.self
        heads, index = find_pruneable_heads_and_indices(
            heads, attn.num_attention_heads, attn.attention_head_size, self.pruned_heads
        )

        # Shrink the three input projections, then the output projection along its input dimension.
        for proj_name in ("query", "key", "value"):
            setattr(attn, proj_name, prune_linear_layer(getattr(attn, proj_name), index))
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Keep the bookkeeping attributes in step with the pruned layers.
        attn.num_attention_heads = attn.num_attention_heads - len(heads)
        attn.all_head_size = attn.attention_head_size * attn.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Run attention, then the output sub-layer with `hidden_states` as residual input."""
        context, *extras = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        attention_output = self.output(context, hidden_states)
        # Everything after the context (attention probabilities, cache) is passed through untouched.
        return (attention_output, *extras)
410
+
411
+
412
class NezhaIntermediate(nn.Module):
    """Feed-forward expansion: linear projection to `intermediate_size` followed by the activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        activation = config.hidden_act
        # A string is looked up in ACT2FN; anything else is used directly as the activation callable.
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `act(dense(hidden_states))`."""
        return self.intermediate_act_fn(self.dense(hidden_states))
425
+
426
+
427
class NezhaOutput(nn.Module):
    """Feed-forward contraction: project back to `hidden_size`, dropout, residual add, LayerNorm."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(config.intermediate_size, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Return `LayerNorm(dropout(dense(hidden_states)) + input_tensor)`."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
439
+
440
+
441
class NezhaLayer(nn.Module):
    """One transformer layer: self-attention, optional cross-attention, then a chunked feed-forward."""

    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = NezhaAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            # Cross-attention is only allowed on decoder layers.
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            self.crossattention = NezhaAttention(config)
        self.intermediate = NezhaIntermediate(config)
        self.output = NezhaOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Return `(layer_output, *attention_probs)`, with the key/value cache appended for decoders."""
        has_cache = past_key_value is not None

        # The first two cache entries belong to self-attention.
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=past_key_value[:2] if has_cache else None,
        )
        attention_output = self_attention_outputs[0]

        if self.is_decoder:
            # For decoders the last element is the self-attention cache, not an attention map.
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            outputs = self_attention_outputs[1:]

        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )

            # The last two cache entries belong to cross-attention.
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                past_key_value[-2:] if has_cache else None,
                output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]

            # Append the cross-attention cache after the self-attention cache.
            present_key_value = present_key_value + cross_attention_outputs[-1]

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs

        if self.is_decoder:
            outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        """Apply the intermediate and output sub-layers to one chunk of `attention_output`."""
        return self.output(self.intermediate(attention_output), attention_output)
525
+
526
+
527
class NezhaEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` `NezhaLayer` modules applied in sequence."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([NezhaLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        """Run every layer, collecting hidden states, attention maps and caches as requested.

        Returns a `BaseModelOutputWithPastAndCrossAttentions`, or, when `return_dict` is false, a tuple
        holding only the collected values that are not `None`.
        """
        collect_cross = output_attentions and self.config.add_cross_attention
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if collect_cross else None

        checkpointing = self.gradient_checkpointing and self.training
        if checkpointing and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                # Record the input of each layer (the embeddings for the first one).
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_args = (
                hidden_states,
                attention_mask,
                head_mask[i] if head_mask is not None else None,
                encoder_hidden_states,
                encoder_attention_mask,
                past_key_values[i] if past_key_values is not None else None,
                output_attentions,
            )
            if checkpointing:
                layer_outputs = self._gradient_checkpointing_func(layer_module.__call__, *layer_args)
            else:
                layer_outputs = layer_module(*layer_args)

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            # Finally record the output of the last layer.
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            collected = (
                hidden_states,
                next_decoder_cache,
                all_hidden_states,
                all_self_attentions,
                all_cross_attentions,
            )
            return tuple(value for value in collected if value is not None)
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
618
+
619
+
620
class NezhaPooler(nn.Module):
    """Pool a sequence by passing the hidden state of its first token through a linear layer and tanh."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `tanh(dense(hidden_states[:, 0]))`."""
        # "Pooling" here simply means selecting the hidden state at sequence position 0.
        return self.activation(self.dense(hidden_states[:, 0]))
633
+
634
+
635
class NezhaPredictionHeadTransform(nn.Module):
    """Dense layer, activation and LayerNorm applied before the language-modelling decoder."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        activation = config.hidden_act
        # A string is looked up in ACT2FN; anything else is used directly as the activation callable.
        self.transform_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `LayerNorm(act(dense(hidden_states)))`."""
        activated = self.transform_act_fn(self.dense(hidden_states))
        return self.LayerNorm(activated)
650
+
651
+
652
class NezhaLMPredictionHead(nn.Module):
    """Transform hidden states and project them onto the vocabulary."""

    def __init__(self, config):
        super().__init__()
        self.transform = NezhaPredictionHeadTransform(config)

        # The decoder is created without its own bias; a separate per-token bias parameter is
        # attached to it below. (The output weights are the same as the input embeddings.)
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Link the two variables so that the bias is correctly resized with `resize_token_embeddings`.
        self.decoder.bias = self.bias

    def _tie_weights(self):
        """Point the decoder's bias at `self.bias` again."""
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        """Return the vocabulary scores `decoder(transform(hidden_states))`."""
        return self.decoder(self.transform(hidden_states))
673
+
674
+
675
class NezhaOnlyMLMHead(nn.Module):
    """Head that holds only the language-modelling prediction layer."""

    def __init__(self, config):
        super().__init__()
        self.predictions = NezhaLMPredictionHead(config)

    def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
        """Return the prediction scores computed from `sequence_output`."""
        return self.predictions(sequence_output)
683
+
684
+
685
class NezhaOnlyNSPHead(nn.Module):
    """Head that maps the pooled output to two sequence-relationship scores."""

    def __init__(self, config):
        super().__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        """Return the two-way scores for `pooled_output`."""
        return self.seq_relationship(pooled_output)
693
+
694
+
695
class NezhaPreTrainingHeads(nn.Module):
    """Both pre-training heads: language-modelling predictions and sequence relationship."""

    def __init__(self, config):
        super().__init__()
        self.predictions = NezhaLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        """Return `(prediction_scores, seq_relationship_score)`."""
        lm_scores = self.predictions(sequence_output)
        relationship_scores = self.seq_relationship(pooled_output)
        return lm_scores, relationship_scores
705
+
706
+
707
class NezhaPreTrainedModel(PreTrainedModel):
    """
    Abstract base class that takes care of weight initialization and provides a simple interface for downloading
    and loading pretrained models.
    """

    config_class = NezhaConfig
    load_tf_weights = load_tf_weights_in_nezha
    base_model_prefix = "nezha"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights of a single `module`."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Plain normal initialization; slightly different from the TF version, which uses
            # truncated_normal (cf https://github.com/pytorch/pytorch/pull/5617).
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Linear):
                if module.bias is not None:
                    module.bias.data.zero_()
            elif module.padding_idx is not None:
                # The embedding row of the padding token is reset to zeros.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
733
+
734
+
735
@dataclass
class NezhaForPreTrainingOutput(ModelOutput):
    """
    Output type of [`NezhaForPreTraining`].

    Args:
        loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
            Total loss as the sum of the masked language modeling loss and the next sequence prediction
            (classification) loss.
        prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
            before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # Every field defaults to None, so an instance can be built with any subset of them.
    loss: Optional[torch.FloatTensor] = None
    prediction_logits: Optional[torch.FloatTensor] = None
    seq_relationship_logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
767
+
768
+
769
# Shared introductory documentation for the Nezha model classes. It is passed to
# `add_start_docstrings(...)` on each model class in this file, so the text below is
# consumed at runtime and must be edited as documentation content, not as a comment.
NEZHA_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`NezhaConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
784
+
785
# Shared documentation template for the `forward` inputs of the Nezha model classes.
# The `{0}` placeholders are filled through `NEZHA_INPUTS_DOCSTRING.format(...)` with the
# input shape text (e.g. "batch_size, sequence_length") before the result is handed to
# `add_start_docstrings_to_model_forward`, so literal braces must not be added to this text.
NEZHA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
828
+
829
+
830
@add_start_docstrings(
    "The bare Nezha Model transformer outputting raw hidden-states without any specific head on top.",
    NEZHA_START_DOCSTRING,
)
class NezhaModel(NezhaPreTrainedModel):
    """

    The model can run as an encoder (self-attention only) or as a decoder, in which case a cross-attention layer is
    inserted between the self-attention layers, following the architecture described in [Attention is all you
    need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones,
    Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
    to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument and
    `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
    """

    def __init__(self, config, add_pooling_layer=True):
        """Build the embeddings, the encoder stack and, unless `add_pooling_layer` is false, the pooler."""
        super().__init__(config)
        self.config = config

        self.embeddings = NezhaEmbeddings(config)
        self.encoder = NezhaEncoder(config)

        if add_pooling_layer:
            self.pooler = NezhaPooler(config)
        else:
            self.pooler = None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the word-embedding module of the embeddings layer."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word-embedding module of the embeddings layer with `value`."""
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer_index in heads_to_prune:
            self.encoder.layer[layer_index].attention.prune_heads(heads_to_prune[layer_index])

    @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPoolingAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        r"""
        encoder_hidden_states  (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        """
        config = self.config

        # Any flag left as None falls back to the value stored on the configuration.
        if output_attentions is None:
            output_attentions = config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = config.output_hidden_states
        if return_dict is None:
            return_dict = config.use_return_dict

        # The cache is forced off unless the model is configured as a decoder.
        if not config.is_decoder:
            use_cache = False
        elif use_cache is None:
            use_cache = config.use_cache

        # Exactly one of `input_ids` / `inputs_embeds` must be supplied.
        ids_given = input_ids is not None
        embeds_given = inputs_embeds is not None
        if ids_given and embeds_given:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if not ids_given and not embeds_given:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if ids_given:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        else:
            # Drop the trailing hidden dimension to recover the (batch, sequence) shape.
            input_shape = inputs_embeds.size()[:-1]

        batch_size, seq_length = input_shape
        device = input_ids.device if ids_given else inputs_embeds.device

        # Length already covered by the cached keys/values (0 when there is no cache).
        past_key_values_length = 0
        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]

        # Without an explicit mask, attend to every position (cached ones included).
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)

        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                # Reuse the buffer held by the embeddings module, cropped and expanded to the batch.
                buffered_ids = self.embeddings.token_type_ids[:, :seq_length]
                token_type_ids = buffered_ids.expand(batch_size, seq_length)
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        encoder_extended_attention_mask = None
        if config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones((encoder_batch_size, encoder_sequence_length), device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]

        pooled_output = None
        if self.pooler is not None:
            pooled_output = self.pooler(sequence_output)

        if return_dict:
            return BaseModelOutputWithPoolingAndCrossAttentions(
                last_hidden_state=sequence_output,
                pooler_output=pooled_output,
                past_key_values=encoder_outputs.past_key_values,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
                cross_attentions=encoder_outputs.cross_attentions,
            )

        # Tuple form: sequence output, pooled output, then whatever else the encoder returned.
        return (sequence_output, pooled_output) + encoder_outputs[1:]
1005
+
1006
+
1007
@add_start_docstrings(
    """
    Nezha Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
    sentence prediction (classification)` head.
    """,
    NEZHA_START_DOCSTRING,
)
class NezhaForPreTraining(NezhaPreTrainedModel):
    _tied_weights_keys = ["cls.predictions.decoder"]

    def __init__(self, config):
        """Build the Nezha backbone and the combined pre-training heads."""
        super().__init__(config)

        self.nezha = NezhaModel(config)
        self.cls = NezhaPreTrainingHeads(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        """Return the decoder module of the prediction head."""
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        """Install `new_embeddings` as the prediction-head decoder and point the head's bias at its bias."""
        prediction_head = self.cls.predictions
        prediction_head.decoder = new_embeddings
        prediction_head.bias = new_embeddings.bias

    @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=NezhaForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        next_sentence_label: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], NezhaForPreTrainingOutput]:
        r"""
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
                config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked),
                the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
            next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                Labels for computing the next sequence prediction (classification) loss. Input should be a sequence
                pair (see `input_ids` docstring) Indices should be in `[0, 1]`:

                - 0 indicates sequence B is a continuation of sequence A,
                - 1 indicates sequence B is a random sequence.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, NezhaForPreTraining
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("sijunhe/nezha-cn-base")
        >>> model = NezhaForPreTraining.from_pretrained("sijunhe/nezha-cn-base")

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> prediction_logits = outputs.prediction_logits
        >>> seq_relationship_logits = outputs.seq_relationship_logits
        ```
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.nezha(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output, pooled_output = outputs[:2]
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        # The combined loss is only produced when both kinds of labels are supplied.
        total_loss = None
        if not (labels is None or next_sentence_label is None):
            loss_fct = CrossEntropyLoss()
            flat_token_scores = prediction_scores.view(-1, self.config.vocab_size)
            masked_lm_loss = loss_fct(flat_token_scores, labels.view(-1))
            flat_pair_scores = seq_relationship_score.view(-1, 2)
            next_sentence_loss = loss_fct(flat_pair_scores, next_sentence_label.view(-1))
            total_loss = masked_lm_loss + next_sentence_loss

        if return_dict:
            return NezhaForPreTrainingOutput(
                loss=total_loss,
                prediction_logits=prediction_scores,
                seq_relationship_logits=seq_relationship_score,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        # Tuple form: both sets of logits, then the backbone extras, with the loss in front when present.
        output = (prediction_scores, seq_relationship_score) + outputs[2:]
        if total_loss is None:
            return output
        return (total_loss,) + output
1114
+
1115
+
1116
@add_start_docstrings("""Nezha Model with a `language modeling` head on top.""", NEZHA_START_DOCSTRING)
class NezhaForMaskedLM(NezhaPreTrainedModel):
    _tied_weights_keys = ["cls.predictions.decoder"]

    def __init__(self, config):
        """Build the Nezha backbone (without pooler) and the masked-LM head.

        A warning is logged when `config.is_decoder` is set, because masked language modeling relies on
        bi-directional self-attention.
        """
        super().__init__(config)

        if config.is_decoder:
            logger.warning(
                "If you want to use `NezhaForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        self.nezha = NezhaModel(config, add_pooling_layer=False)
        self.cls = NezhaOnlyMLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        """Return the decoder module of the prediction head."""
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        """Install `new_embeddings` as the prediction-head decoder and point the head's bias at its bias."""
        self.cls.predictions.decoder = new_embeddings
        self.cls.predictions.bias = new_embeddings.bias

    @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.nezha(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        masked_lm_loss = None
        if labels is not None:
            # Labels equal to -100 are skipped: that is CrossEntropyLoss's default `ignore_index`.
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            # Skip index 1 (the pooled-output slot of the backbone tuple) and keep the remaining extras.
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
        """Append one masked-out PAD token to every sequence and return the model inputs.

        Args:
            input_ids: Token ids whose first dimension is the batch size; a PAD column is appended along dim 1.
            attention_mask: Optional mask aligned with `input_ids`. When omitted, an all-ones mask is used,
                which mirrors the default applied by `NezhaModel.forward`.
            **model_kwargs: Accepted for interface compatibility; not used here.

        Returns:
            A dict with the extended `input_ids` and `attention_mask`.

        Raises:
            ValueError: If `config.pad_token_id` is not defined.
        """
        input_shape = input_ids.shape
        effective_batch_size = input_shape[0]

        # add a dummy token
        if self.config.pad_token_id is None:
            raise ValueError("The PAD token should be defined for generation")

        # The signature allows `attention_mask=None`; without this fallback the concatenation below
        # would fail with an AttributeError on `None`.
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        # The appended dummy position gets a 0 in the mask.
        attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
        dummy_token = torch.full(
            (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
        )
        input_ids = torch.cat([input_ids, dummy_token], dim=1)

        return {"input_ids": input_ids, "attention_mask": attention_mask}
1218
+
1219
+
1220
@add_start_docstrings(
    """Nezha Model with a `next sentence prediction (classification)` head on top.""",
    NEZHA_START_DOCSTRING,
)
class NezhaForNextSentencePrediction(NezhaPreTrainedModel):
    def __init__(self, config):
        """Build the Nezha backbone and the next-sentence-prediction head."""
        super().__init__(config)

        self.nezha = NezhaModel(config)
        self.cls = NezhaOnlyNSPHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
            (see `input_ids` docstring). Indices should be in `[0, 1]`:

            - 0 indicates sequence B is a continuation of sequence A,
            - 1 indicates sequence B is a random sequence.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, NezhaForNextSentencePrediction
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("sijunhe/nezha-cn-base")
        >>> model = NezhaForNextSentencePrediction.from_pretrained("sijunhe/nezha-cn-base")

        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
        >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
        >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")

        >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
        >>> logits = outputs.logits
        >>> assert logits[0, 0] < logits[0, 1]  # next sentence was random
        ```
        """

        # Legacy keyword: still honoured (it overrides `labels`), but flagged as deprecated.
        legacy_key = "next_sentence_label"
        if legacy_key in kwargs:
            warnings.warn(
                "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
                " `labels` instead.",
                FutureWarning,
            )
            labels = kwargs.pop(legacy_key)

        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.nezha(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # The head consumes the pooled output (index 1 of the backbone result).
        seq_relationship_scores = self.cls(outputs[1])

        next_sentence_loss = None
        if labels is not None:
            criterion = CrossEntropyLoss()
            next_sentence_loss = criterion(seq_relationship_scores.view(-1, 2), labels.view(-1))

        if return_dict:
            return NextSentencePredictorOutput(
                loss=next_sentence_loss,
                logits=seq_relationship_scores,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        # Tuple form: logits, then the backbone extras, with the loss in front when present.
        output = (seq_relationship_scores,) + outputs[2:]
        if next_sentence_loss is None:
            return output
        return (next_sentence_loss,) + output
1318
+
1319
+
1320
@add_start_docstrings(
    """
    Nezha Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    NEZHA_START_DOCSTRING,
)
class NezhaForSequenceClassification(NezhaPreTrainedModel):
    def __init__(self, config):
        """Build the Nezha backbone, a dropout layer and a linear classifier with `config.num_labels` outputs."""
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.nezha = NezhaModel(config)

        # `classifier_dropout` takes precedence; the hidden dropout rate is the fallback when it is None.
        dropout_rate = config.classifier_dropout
        if dropout_rate is None:
            dropout_rate = config.hidden_dropout_prob
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.nezha(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Classify from the pooled output (index 1 of the backbone result) after dropout.
        logits = self.classifier(self.dropout(outputs[1]))

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                # Infer the problem type once from the label count and label dtype, and store it on the config.
                if self.num_labels == 1:
                    inferred_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    inferred_type = "single_label_classification"
                else:
                    inferred_type = "multi_label_classification"
                self.config.problem_type = inferred_type

            problem_type = self.config.problem_type
            if problem_type == "regression":
                criterion = MSELoss()
                if self.num_labels == 1:
                    loss = criterion(logits.squeeze(), labels.squeeze())
                else:
                    loss = criterion(logits, labels)
            elif problem_type == "single_label_classification":
                criterion = CrossEntropyLoss()
                loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                criterion = BCEWithLogitsLoss()
                loss = criterion(logits, labels)

        if return_dict:
            return SequenceClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        # Tuple form: logits, then the backbone extras, with the loss in front when present.
        output = (logits,) + outputs[2:]
        if loss is None:
            return output
        return (loss,) + output
1417
+
1418
+
1419
@add_start_docstrings(
    """
    Nezha Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    NEZHA_START_DOCSTRING,
)
class NezhaForMultipleChoice(NezhaPreTrainedModel):
    def __init__(self, config):
        """Build the Nezha backbone, a dropout layer and a linear layer producing one score per choice."""
        super().__init__(config)

        self.nezha = NezhaModel(config)
        # `classifier_dropout` takes precedence; the hidden dropout rate is the fallback when it is None.
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        # Fold the choices dimension into the batch dimension so the backbone sees ordinary sequences.
        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        outputs = self.nezha(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        # Unfold the per-sequence scores back into one row of `num_choices` scores per example.
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
1509
+
1510
+
1511
@add_start_docstrings(
    """
    Nezha Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    NEZHA_START_DOCSTRING,
)
class NezhaForTokenClassification(NezhaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # The head classifies every token, so the encoder is built without its pooling layer.
        self.nezha = NezhaModel(config, add_pooling_layer=False)

        # Prefer the classifier-specific dropout rate; fall back to the generic hidden dropout.
        dropout_rate = config.classifier_dropout
        if dropout_rate is None:
            dropout_rate = config.hidden_dropout_prob
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.nezha(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Per-token hidden states -> dropout -> per-token label scores.
        token_states = self.dropout(outputs[0])
        logits = self.classifier(token_states)

        loss = None
        if labels is not None:
            # Flatten tokens across the batch so each token is one classification example.
            flat_logits = logits.view(-1, self.num_labels)
            flat_labels = labels.view(-1)
            loss = CrossEntropyLoss()(flat_logits, flat_labels)

        if return_dict:
            return TokenClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        # Tuple output: (loss?, logits, *remaining encoder outputs after the first two entries).
        output = (logits,) + outputs[2:]
        return output if loss is None else (loss,) + output
1588
+
1589
+
1590
@add_start_docstrings(
    """
    Nezha Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    NEZHA_START_DOCSTRING,
)
class NezhaForQuestionAnswering(NezhaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Span prediction works on per-token hidden states, so the encoder is built without its pooling layer.
        self.nezha = NezhaModel(config, add_pooling_layer=False)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.nezha(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Project every token to per-label scores, then peel the last axis into start and end scores.
        logits = self.qa_outputs(outputs[0])
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # Targets outside the model inputs are clamped onto this index, which the loss then ignores.
            ignored_index = start_logits.size(1)

            targets = []
            for positions in (start_positions, end_positions):
                # On multi-GPU the split may add a dimension; drop it.
                if positions.dim() > 1:
                    positions = positions.squeeze(-1)
                targets.append(positions.clamp(0, ignored_index))
            start_targets, end_targets = targets

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_targets)
            end_loss = loss_fct(end_logits, end_targets)
            total_loss = (start_loss + end_loss) / 2

        if return_dict:
            return QuestionAnsweringModelOutput(
                loss=total_loss,
                start_logits=start_logits,
                end_logits=end_logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        # Tuple output: (loss?, start_logits, end_logits, *remaining encoder outputs after the first two).
        output = (start_logits, end_logits) + outputs[2:]
        return output if total_loss is None else (total_loss,) + output
1685
+
1686
+
1687
# Names exported by this module on `from ... import *`.
__all__ = [
    "NezhaForNextSentencePrediction",
    "NezhaForMaskedLM",
    "NezhaForPreTraining",
    "NezhaForMultipleChoice",
    "NezhaForQuestionAnswering",
    "NezhaForSequenceClassification",
    "NezhaForTokenClassification",
    "NezhaModel",
    "NezhaPreTrainedModel",
]
docs/transformers/build/lib/transformers/models/deprecated/open_llama/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 EleutherAI and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ....utils import _LazyModule
17
+ from ....utils.import_utils import define_import_structure
18
+
19
+
20
if TYPE_CHECKING:
    # Only static type checkers take this branch: they see the real symbols through eager star-imports.
    from .configuration_open_llama import *
    from .modeling_open_llama import *
else:
    import sys

    # At runtime, replace this package's entry in `sys.modules` with a `_LazyModule` built from the import
    # structure derived from this file (presumably so submodules are only imported on first attribute
    # access -- confirm against `_LazyModule`).
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/build/lib/transformers/models/deprecated/open_llama/configuration_open_llama.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """Open-Llama model configuration"""
21
+
22
+ from ....configuration_utils import PretrainedConfig
23
+ from ....utils import logging
24
+
25
+
26
# Module-level logger named after this module.
logger = logging.get_logger(__name__)
27
+
28
+
29
class OpenLlamaConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`OpenLlamaModel`]. It is used to instantiate an
    Open-Llama model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the
    [s-JoL/Open-Llama-V1](https://huggingface.co/s-JoL/Open-Llama-V1).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 100000):
            Vocabulary size of the Open-Llama model. Defines the number of different tokens that can be represented by
            the `inputs_ids` passed when calling [`OpenLlamaModel`]
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 11008):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer encoder.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 0):
            Padding token id, forwarded to [`PretrainedConfig`].
        bos_token_id (`int`, *optional*, defaults to 1):
            Beginning-of-sequence token id, forwarded to [`PretrainedConfig`].
        eos_token_id (`int`, *optional*, defaults to 2):
            End-of-sequence token id, forwarded to [`PretrainedConfig`].
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings
        use_memory_efficient_attention (`bool`, *optional*, defaults to `True`):
            Whether the attention layers may use the xformers memory-efficient attention kernel. The legacy,
            misspelled keyword `use_memorry_efficient_attention` is still accepted and takes precedence over this
            argument when it is passed.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            Dropout probability stored on the configuration for the hidden layers of the model.
        attention_dropout_prob (`float`, *optional*, defaults to 0.1):
            Dropout probability used by the attention layers.
        use_stable_embedding (`bool`, *optional*, defaults to `True`):
            Flag stored on the configuration and read by the modeling code.
        shared_input_output_embedding (`bool`, *optional*, defaults to `True`):
            Flag stored on the configuration and read by the modeling code (presumably controls whether the input
            and output embeddings share weights).
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
            `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
            `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
            these scaling strategies behave:
            https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
            experimental feature, subject to breaking API changes in future versions.

    Raises:
        ValueError: If `rope_scaling` is given but is not a two-field dict with a `type` of `"linear"` or
            `"dynamic"` and a float `factor` greater than 1.

    Example:

    ```python
    >>> from transformers import OpenLlamaModel, OpenLlamaConfig

    >>> # Initializing a Open-Llama open_llama-7b style configuration
    >>> configuration = OpenLlamaConfig()

    >>> # Initializing a model from the open_llama-7b style configuration
    >>> model = OpenLlamaModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "open-llama"

    def __init__(
        self,
        vocab_size=100000,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        use_memory_efficient_attention=True,
        hidden_dropout_prob=0.1,
        attention_dropout_prob=0.1,
        use_stable_embedding=True,
        shared_input_output_embedding=True,
        rope_theta=10000.0,
        rope_scaling=None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        # The misspelled key is intentional: it is a legacy alias that overrides the correctly spelled
        # argument when present. It is popped so it is not forwarded to `PretrainedConfig.__init__`.
        self.use_memory_efficient_attention = kwargs.pop(
            "use_memorry_efficient_attention", use_memory_efficient_attention
        )
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_dropout_prob = attention_dropout_prob
        self.use_stable_embedding = use_stable_embedding
        self.shared_input_output_embedding = shared_input_output_embedding
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        # Fail fast on a malformed `rope_scaling` before the base class processes the remaining kwargs.
        self._rope_scaling_validation()

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    def _rope_scaling_validation(self):
        """
        Validate the `rope_scaling` configuration.

        Raises:
            ValueError: If `rope_scaling` is not `None` and is not a dict with exactly two entries, if its `type`
                is not `"linear"` or `"dynamic"`, or if its `factor` is not a float strictly greater than 1.
        """
        if self.rope_scaling is None:
            return

        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
            raise ValueError(
                f"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, got {self.rope_scaling}"
            )
        rope_scaling_type = self.rope_scaling.get("type", None)
        rope_scaling_factor = self.rope_scaling.get("factor", None)
        if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
            raise ValueError(
                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
            )
        # Note: integers are rejected here; the factor has to be an actual `float`.
        if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
            raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
167
+
168
+
169
# Names exported by this module on `from ... import *`.
__all__ = ["OpenLlamaConfig"]
docs/transformers/build/lib/transformers/models/deprecated/open_llama/modeling_open_llama.py ADDED
@@ -0,0 +1,975 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """PyTorch Open-Llama model."""
21
+
22
+ import math
23
+ from typing import List, Optional, Tuple, Union
24
+
25
+ import torch
26
+ import torch.utils.checkpoint
27
+ from torch import nn
28
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
+
30
+ from ....activations import ACT2FN
31
+ from ....modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
32
+ from ....modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
33
+ from ....modeling_utils import PreTrainedModel
34
+ from ....utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
35
+ from .configuration_open_llama import OpenLlamaConfig
36
+
37
+
38
# Module-level logger named after this module.
logger = logging.get_logger(__name__)

# xformers is an optional dependency: when it cannot be imported, `xops` is left as None and the attention
# layer takes its non-xformers branch.
try:
    from xformers import ops as xops
except ImportError:
    xops = None


# Name of the configuration class (presumably consumed by the docstring decorators -- confirm at the use sites).
_CONFIG_FOR_DOC = "OpenLlamaConfig"
47
+
48
+
49
class OpenLlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        OpenLlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.variance_epsilon = eps
        # Learnable per-feature scale, initialised to ones.
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states):
        """Scale `hidden_states` by the reciprocal root-mean-square of its last dimension, then by `weight`."""
        original_dtype = hidden_states.dtype
        # The statistics are computed in float32 regardless of the input dtype.
        states_fp32 = hidden_states.to(torch.float32)
        mean_square = states_fp32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalized = states_fp32 * inv_rms
        # Cast back to the caller's dtype before applying the learned scale.
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        """Return the weight shape and epsilon for the module's printed representation."""
        weight_shape = tuple(self.weight.shape)
        return f"{weight_shape}, eps={self.variance_epsilon}"
67
+
68
+
69
class OpenLlamaRotaryEmbedding(nn.Module):
    """Rotary position embedding that precomputes and caches cos/sin tables indexed by position."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        """
        Args:
            dim: Size of the last dimension of the cos/sin tables; half as many inverse frequencies are created.
            max_position_embeddings: Number of positions precomputed at construction time.
            base: Base of the geometric progression of inverse frequencies.
            device: Device on which the inverse frequencies (and thus the initial tables) are created.
        """
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (
            self.base
            ** (torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / self.dim)
        )
        # Non-persistent: the buffer is rebuilt from `dim`/`base` and is not written to the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cos/sin tables for positions `0 .. seq_len - 1` and record the cached length."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """
        Return the `(cos, sin)` tables for the first `seq_len` positions, cast to the dtype of `x`.

        Args:
            x: Tensor laid out as `[bs, num_attention_heads, seq_len, head_size]`; supplies the dtype and device.
            seq_len: Number of positions needed. When omitted, the sequence length is read from `x.shape[-2]`
                (previously the `None` default raised a `TypeError` in the comparison below).
        """
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is None:
            seq_len = x.shape[-2]

        # Grow the cache on demand when more positions are requested than have been precomputed.
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
106
+
107
+
108
class OpenLlamaLinearScalingRotaryEmbedding(OpenLlamaRotaryEmbedding):
    """OpenLlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Set before calling the parent constructor, which builds the cache via `_set_cos_sin_cache`.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Rebuild the cos/sin tables with position indices divided by `scaling_factor`."""
        self.max_seq_len_cached = seq_len
        positions = torch.arange(seq_len, device=device, dtype=torch.int64).type_as(self.inv_freq)
        scaled_positions = positions / self.scaling_factor

        angles = torch.outer(scaled_positions, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        table = torch.cat((angles, angles), dim=-1)
        for buffer_name, values in (("cos_cached", table.cos()), ("sin_cached", table.sin())):
            self.register_buffer(buffer_name, values.to(dtype), persistent=False)
125
+
126
+
127
class OpenLlamaDynamicNTKScalingRotaryEmbedding(OpenLlamaRotaryEmbedding):
    """OpenLlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Set before calling the parent constructor, which builds the cache via `_set_cos_sin_cache`.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Rebuild the cos/sin tables, enlarging the frequency base when `seq_len` exceeds the configured maximum."""
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            # Recompute the base from how far `seq_len` exceeds `max_position_embeddings`.
            growth = (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            base = self.base * growth ** (self.dim / (self.dim - 2))
            exponents = torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / self.dim
            inv_freq = 1.0 / (base**exponents)
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        positions = torch.arange(seq_len, device=device, dtype=torch.int64).type_as(self.inv_freq)

        angles = torch.outer(positions, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        table = torch.cat((angles, angles), dim=-1)
        for buffer_name, values in (("cos_cached", table.cos()), ("sin_cached", table.sin())):
            self.register_buffer(buffer_name, values.to(dtype), persistent=False)
154
+
155
+
156
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    midpoint = x.shape[-1] // 2
    front = x[..., :midpoint]
    back = x[..., midpoint:]
    # The negated back part moves to the front; the original front part moves to the back.
    return torch.cat((-back, front), dim=-1)
161
+
162
+
163
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The dimension along which `cos[position_ids]` and `sin[position_ids]` are unsqueezed so that they
            broadcast against q and k. With q and k shaped [batch_size, heads, seq_len, head_dim] and the indexed
            tables shaped [batch_size, seq_len, head_dim], use `unsqueeze_dim=1`; with q and k shaped
            [batch_size, seq_len, heads, head_dim], use `unsqueeze_dim=2`.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    # Pick the table rows for the requested positions and add the broadcast axis.
    cos_at_pos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin_at_pos = sin[position_ids].unsqueeze(unsqueeze_dim)

    # The same rotation is applied to the query first, then to the key.
    q_embed, k_embed = ((states * cos_at_pos) + (rotate_half(states) * sin_at_pos) for states in (q, k))
    return q_embed, k_embed
189
+
190
+
191
class OpenLlamaMLP(nn.Module):
    """Gated feed-forward block: `down_proj(act_fn(gate_proj(x)) * up_proj(x))` followed by dropout."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        dropout_prob: float,
    ):
        super().__init__()

        def _projection(in_features: int, out_features: int) -> nn.Linear:
            # All three projections are bias-free linear layers.
            return nn.Linear(in_features, out_features, bias=False)

        # Creation order (gate, down, up) is kept so parameter registration and initialisation order match.
        self.gate_proj = _projection(hidden_size, intermediate_size)
        self.down_proj = _projection(intermediate_size, hidden_size)
        self.up_proj = _projection(hidden_size, intermediate_size)
        self.act_fn = ACT2FN[hidden_act]
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        """Apply the gated projection to `x` and return the result after dropout."""
        gate = self.act_fn(self.gate_proj(x))
        projected = self.down_proj(gate * self.up_proj(x))
        return self.dropout(projected)
209
+
210
+
211
+ class OpenLlamaAttention(nn.Module):
212
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
213
+
214
+ def __init__(self, config: OpenLlamaConfig):
215
+ super().__init__()
216
+ self.config = config
217
+ self.hidden_size = config.hidden_size
218
+ self.num_heads = config.num_attention_heads
219
+ self.head_dim = self.hidden_size // self.num_heads
220
+ self.max_position_embeddings = config.max_position_embeddings
221
+ self.dropout_prob = config.attention_dropout_prob
222
+ self.rope_theta = config.rope_theta
223
+
224
+ if (self.head_dim * self.num_heads) != self.hidden_size:
225
+ raise ValueError(
226
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
227
+ f" and `num_heads`: {self.num_heads})."
228
+ )
229
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
230
+ self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
231
+ self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
232
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
233
+ self._init_rope()
234
+
235
+ def _init_rope(self):
236
+ if self.config.rope_scaling is None:
237
+ self.rotary_emb = OpenLlamaRotaryEmbedding(
238
+ self.head_dim,
239
+ max_position_embeddings=self.max_position_embeddings,
240
+ base=self.rope_theta,
241
+ )
242
+ else:
243
+ scaling_type = self.config.rope_scaling["type"]
244
+ scaling_factor = self.config.rope_scaling["factor"]
245
+ if scaling_type == "linear":
246
+ self.rotary_emb = OpenLlamaLinearScalingRotaryEmbedding(
247
+ self.head_dim,
248
+ max_position_embeddings=self.max_position_embeddings,
249
+ scaling_factor=scaling_factor,
250
+ base=self.rope_theta,
251
+ )
252
+ elif scaling_type == "dynamic":
253
+ self.rotary_emb = OpenLlamaDynamicNTKScalingRotaryEmbedding(
254
+ self.head_dim,
255
+ max_position_embeddings=self.max_position_embeddings,
256
+ scaling_factor=scaling_factor,
257
+ base=self.rope_theta,
258
+ )
259
+ else:
260
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
261
+
262
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
263
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
264
+
265
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """
    Self-attention with rotary position embeddings.

    Args:
        hidden_states: 3-D input whose size is unpacked as `(bsz, q_len, hidden)`.
        attention_mask: optional additive mask. On the non-xformers path it must have size
            `(bsz, 1, q_len, kv_seq_len)`; it is not read on the xformers path, which applies
            `xops.LowerTriangularMask()` instead.
        position_ids: forwarded to `apply_rotary_pos_emb`.
        past_key_value: optional `(key, value)` pair that is concatenated in front of the
            freshly projected keys/values along dim 2.
        output_attentions: when false, the returned attention weights are `None`.
        use_cache: when true, the (concatenated) key/value pair is returned for reuse.

    Returns:
        `(attn_output, attn_weights, past_key_value)`. `attn_weights` is `None` unless
        `output_attentions` is set and the non-xformers path ran; `past_key_value` is `None`
        unless `use_cache` is set.

    Raises:
        ValueError: if the attention weights, the attention mask or the attention output do
            not have the expected size (non-xformers path only).
    """
    bsz, q_len, _ = hidden_states.size()

    # Project and reshape to [bsz, num_heads, q_len, head_dim].
    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

    # Total key/value length includes whatever is already cached.
    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
    # [bsz, nh, t, hd]

    if past_key_value is not None:
        # reuse k, v, self_attention
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)

    past_key_value = (key_states, value_states) if use_cache else None

    if self.config.use_memory_efficient_attention and xops is not None and self.training:
        # xformers path: no explicit weights are produced, so none can be returned.
        attn_weights = None
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        attn_output = xops.memory_efficient_attention(
            query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask(), p=self.dropout_prob
        )
    else:
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        expected_weights_size = (bsz, self.num_heads, q_len, kv_seq_len)
        if attn_weights.size() != expected_weights_size:
            # Report the same shape the check compares against.
            raise ValueError(
                f"Attention weights should be of size {expected_weights_size}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask
            # Keep masked scores at the dtype's finite minimum rather than letting them overflow.
            attn_weights = torch.max(
                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
            )

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # Back to [bsz, q_len, num_heads, head_dim] (the layout the xformers inputs above use).
        attn_output = attn_output.transpose(1, 2)

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value
340
+
341
+
342
class OpenLlamaDecoderLayer(nn.Module):
    """Decoder block: RMSNorm -> self-attention -> residual add, then RMSNorm -> MLP -> residual add."""

    def __init__(self, config: OpenLlamaConfig):
        super().__init__()
        hidden = config.hidden_size
        norm_eps = config.rms_norm_eps

        self.hidden_size = hidden
        self.self_attn = OpenLlamaAttention(config=config)
        self.mlp = OpenLlamaMLP(
            hidden_size=hidden,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            dropout_prob=config.hidden_dropout_prob,
        )
        self.input_layernorm = OpenLlamaRMSNorm(hidden, eps=norm_eps)
        self.post_attention_layernorm = OpenLlamaRMSNorm(hidden, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            position_ids (`torch.LongTensor`, *optional*): passed through unchanged to the attention module.
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                When truthy, the attention weights are appended to the returned tuple.
            use_cache (`bool`, *optional*):
                When truthy, the present key/value states are appended to the returned tuple.

        Returns:
            A tuple starting with the block output, followed by the attention weights (if requested) and the
            present key/value states (if requested), in that order.
        """
        # Attention sub-block; the un-normalised input is the residual branch.
        attn_out, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        after_attn = hidden_states + attn_out

        # Feed-forward sub-block with its own residual connection.
        mlp_out = self.mlp(self.post_attention_layernorm(after_attn))
        block_out = after_attn + mlp_out

        extras = []
        if output_attentions:
            extras.append(self_attn_weights)
        if use_cache:
            extras.append(present_key_value)

        return (block_out, *extras)
409
+
410
+
411
# Class-level documentation preamble; it is passed to `add_start_docstrings` on the model classes in this file.
OPEN_LLAMA_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`OpenLlamaConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
426
+
427
+
428
@add_start_docstrings(
    "The bare Open-Llama Model outputting raw hidden-states without any specific head on top.",
    OPEN_LLAMA_START_DOCSTRING,
)
class OpenLlamaPreTrainedModel(PreTrainedModel):
    # Common base of the Open-Llama models: class-level settings plus weight initialisation.
    # (Kept as a comment: a class docstring would be merged into `__doc__` by the decorator.)
    config_class = OpenLlamaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["OpenLlamaDecoderLayer"]

    def _init_weights(self, module):
        """Initialise `module` in place; only `nn.Linear` and `nn.Embedding` are touched."""
        init_std = self.config.initializer_range

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=init_std)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if not isinstance(module, nn.Embedding):
            return

        # Embeddings: Xavier-normal when `use_stable_embedding` is set, plain normal otherwise.
        if self.config.use_stable_embedding:
            torch.nn.init.xavier_normal_(module.weight.data)
        else:
            module.weight.data.normal_(mean=0.0, std=init_std)

        # The padding row is zeroed after the random fill.
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
451
+
452
+
453
# Shared argument documentation; it is passed to `add_start_docstrings_to_model_forward` on the
# `forward` methods of the model classes in this file.
OPEN_LLAMA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
515
+
516
+
517
@add_start_docstrings(
    "The bare Open-Llama Model outputting raw hidden-states without any specific head on top.",
    OPEN_LLAMA_START_DOCSTRING,
)
class OpenLlamaModel(OpenLlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OpenLlamaDecoderLayer`]

    Args:
        config: OpenLlamaConfig
    """

    def __init__(self, config: OpenLlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        # Optional LayerNorm applied to the token embeddings ("stable embedding").
        if config.use_stable_embedding:
            self.embed_layer_norm = nn.LayerNorm(config.hidden_size)
        else:
            self.embed_layer_norm = None
        self.layers = nn.ModuleList([OpenLlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = OpenLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module with `value`."""
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(OPEN_LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # Per-call flags fall back to the config when left as `None`.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Exactly one of `input_ids` / `inputs_embeds` must be given; it fixes batch and sequence size.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if past_key_values is not None:
            # Cached length is read from dim 2 of the first layer's key tensor.
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            # Default positions continue right after the cached prefix.
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
            # Explicit `None` test instead of relying on module truthiness.
            if self.embed_layer_norm is not None:
                inputs_embeds = self.embed_layer_norm(inputs_embeds)

        # Drop the caller's mask only when the attention layers will really take the xformers
        # path, which additionally requires `xops` to be available; otherwise the fallback
        # attention would silently lose the padding information.
        if self.config.use_memory_efficient_attention and xops is not None and self.training:
            attention_mask = None
        elif attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )

        input_shape = (batch_size, seq_length)
        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
        )

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                # Positional call: (hidden_states, attention_mask, position_ids, past_key_value,
                # output_attentions, use_cache) with no cache in or out.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                    output_attentions,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # The cache entry comes after the attention weights when those are returned too.
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            # Tuple output omits every entry that was not requested.
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
677
+
678
+
679
class OpenLlamaForCausalLM(OpenLlamaPreTrainedModel):
    # Open-Llama decoder with a vocabulary projection on top. When
    # `config.shared_input_output_embedding` is set there is no separate `lm_head`; the input
    # embedding matrix is reused to produce the logits.

    def __init__(self, config):
        super().__init__(config)
        self.model = OpenLlamaModel(config)
        self.lm_head = (
            None
            if config.shared_input_output_embedding
            else nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        )

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the decoder's token embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the decoder's token embedding module with `value`."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return `lm_head` (which is `None` when embeddings are shared)."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace `lm_head` with `new_embeddings`."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the wrapped decoder model."""
        self.model = decoder

    def get_decoder(self):
        """Return the wrapped decoder model."""
        return self.model

    @add_start_docstrings_to_model_forward(OPEN_LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, OpenLlamaForCausalLM

        >>> model = OpenLlamaForCausalLM.from_pretrained("openlm-research/open_llama_7b")
        >>> tokenizer = AutoTokenizer.from_pretrained("openlm-research/open_llama_7b")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""

        # Unset flags inherit the config defaults.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        if self.config.shared_input_output_embedding:
            # Project with the input embedding matrix, on whichever device that matrix lives.
            embed_weight = self.model.embed_tokens.weight
            logits = torch.einsum("blh,vh->blv", hidden_states.to(embed_weight.device), embed_weight)
        else:
            logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            # Shift so that tokens < n predict n, then flatten the tokens
            flat_logits = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
            flat_labels = labels[..., 1:].contiguous().view(-1)
            # Enable model parallelism
            flat_labels = flat_labels.to(flat_logits.device)
            loss = CrossEntropyLoss()(flat_logits, flat_labels)

        if not return_dict:
            tuple_out = (logits,) + outputs[1:]
            if loss is None:
                return tuple_out
            return (loss,) + tuple_out

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Build the keyword arguments for one generation step.

        Trims `input_ids` to the tokens not yet covered by the cache, derives `position_ids`
        from `attention_mask` when none are supplied, and uses `inputs_embeds` only while no
        cache exists.
        """
        if past_key_values is not None:
            cached_len = past_key_values[0][0].shape[2]
            current_len = input_ids.shape[1]

            if current_len > cached_len:
                # Some generation methods already pass only the last input ID
                n_drop = cached_len
            else:
                # Default to old behavior: keep only final ID
                n_drop = current_len - 1

            input_ids = input_ids[:, n_drop:]

        position_ids = kwargs.get("position_ids", None)
        if position_ids is None and attention_mask is not None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        first_step_with_embeds = inputs_embeds is not None and past_key_values is None
        model_inputs = {"inputs_embeds": inputs_embeds} if first_step_with_embeds else {"input_ids": input_ids}

        model_inputs["position_ids"] = position_ids
        model_inputs["past_key_values"] = past_key_values
        model_inputs["use_cache"] = kwargs.get("use_cache")
        model_inputs["attention_mask"] = attention_mask
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Select the rows given by `beam_idx` along dim 0 of every cached tensor, layer by layer."""
        return tuple(
            tuple(state.index_select(0, beam_idx.to(state.device)) for state in layer_past)
            for layer_past in past_key_values
        )
850
+
851
+
852
@add_start_docstrings(
    """
    The LLaMa Model transformer with a sequence classification head on top (linear layer).

    [`OpenLlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal
    models (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    OPEN_LLAMA_START_DOCSTRING,
)
class OpenLlamaForSequenceClassification(OpenLlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = OpenLlamaModel(config)
        # Bias-free linear head applied to every position; pooling happens in `forward`.
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the decoder's token embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the decoder's token embedding module with `value`."""
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(OPEN_LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = self.score(transformer_outputs[0])

        batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]

        pad_id = self.config.pad_token_id
        if pad_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")

        if pad_id is not None and input_ids is not None:
            # Index just before the first pad token in each row.
            # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
            first_pad = torch.eq(input_ids, pad_id).int().argmax(-1)
            sequence_lengths = (first_pad - 1) % input_ids.shape[-1]
            sequence_lengths = sequence_lengths.to(logits.device)
        else:
            # No pad id, or only embeddings were given: pool the last position of every row.
            sequence_lengths = -1

        batch_rows = torch.arange(batch_size, device=logits.device)
        pooled_logits = logits[batch_rows, sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)

            # Infer the problem type once and store it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    inferred = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    inferred = "single_label_classification"
                else:
                    inferred = "multi_label_classification"
                self.config.problem_type = inferred

            problem_type = self.config.problem_type
            if problem_type == "regression":
                mse = MSELoss()
                if self.num_labels == 1:
                    loss = mse(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = mse(pooled_logits, labels)
            elif problem_type == "single_label_classification":
                loss = CrossEntropyLoss()(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                loss = BCEWithLogitsLoss()(pooled_logits, labels)

        if not return_dict:
            tuple_out = (pooled_logits,) + transformer_outputs[1:]
            if loss is None:
                return tuple_out
            return (loss,) + tuple_out

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
973
+
974
+
975
+ __all__ = ["OpenLlamaPreTrainedModel", "OpenLlamaModel", "OpenLlamaForCausalLM", "OpenLlamaForSequenceClassification"]
docs/transformers/build/lib/transformers/models/deprecated/qdqbert/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 NVIDIA Corporation and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ....utils import _LazyModule
17
+ from ....utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_qdqbert import *
22
+ from .modeling_qdqbert import *
23
+ else:
24
+ import sys
25
+
26
+ _file = globals()["__file__"]
27
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/build/lib/transformers/models/deprecated/qdqbert/configuration_qdqbert.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 NVIDIA Corporation and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """QDQBERT model configuration"""
16
+
17
+ from ....configuration_utils import PretrainedConfig
18
+ from ....utils import logging
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
class QDQBertConfig(PretrainedConfig):
    r"""
    Configuration class holding the hyper-parameters of a [`QDQBertModel`]. It is used to instantiate a QDQBERT
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with
    the defaults will yield a similar configuration to that of the BERT
    [google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 30522):
            Vocabulary size of the QDQBERT model. Defines the number of different tokens that can be represented by
            the `input_ids` passed when calling [`QDQBertModel`].
        hidden_size (`int`, *optional*, defaults to 768):
            Dimension of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (`int`, *optional*, defaults to 512):
            The maximum sequence length that this model might ever be used with. Typically set this to something
            large just in case (e.g., 512 or 1024 or 2048).
        type_vocab_size (`int`, *optional*, defaults to 2):
            The vocabulary size of the `token_type_ids` passed when calling [`QDQBertModel`].
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        is_decoder (`bool`, *optional*, defaults to `False`):
            Whether the model is used as a decoder or not. If `False`, the model is used as an encoder. This is not
            a named parameter of this class; it is accepted through `**kwargs` and forwarded to
            [`PretrainedConfig`].
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 1):
            Id of the padding token, forwarded to [`PretrainedConfig`].
        bos_token_id (`int`, *optional*, defaults to 0):
            Id of the beginning-of-sequence token, forwarded to [`PretrainedConfig`].
        eos_token_id (`int`, *optional*, defaults to 2):
            Id of the end-of-sequence token, forwarded to [`PretrainedConfig`].

    Examples:

    ```python
    >>> from transformers import QDQBertModel, QDQBertConfig

    >>> # Initializing a QDQBERT google-bert/bert-base-uncased style configuration
    >>> configuration = QDQBertConfig()

    >>> # Initializing a model from the google-bert/bert-base-uncased style configuration
    >>> model = QDQBertModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "qdqbert"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        use_cache=True,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        **kwargs,
    ):
        # The special-token ids are owned by the base class; hand them over together with any extra kwargs.
        special_token_ids = {
            "pad_token_id": pad_token_id,
            "bos_token_id": bos_token_id,
            "eos_token_id": eos_token_id,
        }
        super().__init__(**special_token_ids, **kwargs)

        # Sizes of the embedding tables and of the transformer stack.
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size

        # Activation and regularisation settings.
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob

        # Initialisation, segment vocabulary, normalisation and caching.
        self.initializer_range = initializer_range
        self.type_vocab_size = type_vocab_size
        self.layer_norm_eps = layer_norm_eps
        self.use_cache = use_cache
121
+
122
+
123
+ __all__ = ["QDQBertConfig"]
docs/transformers/build/lib/transformers/models/deprecated/qdqbert/modeling_qdqbert.py ADDED
@@ -0,0 +1,1749 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 NVIDIA Corporation and The HuggingFace Team.
3
+ # Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch QDQBERT model."""
17
+
18
+ import math
19
+ import os
20
+ import warnings
21
+ from typing import Dict, List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
+
28
+ from ....activations import ACT2FN
29
+ from ....modeling_outputs import (
30
+ BaseModelOutputWithPastAndCrossAttentions,
31
+ BaseModelOutputWithPoolingAndCrossAttentions,
32
+ CausalLMOutputWithCrossAttentions,
33
+ MaskedLMOutput,
34
+ MultipleChoiceModelOutput,
35
+ NextSentencePredictorOutput,
36
+ QuestionAnsweringModelOutput,
37
+ SequenceClassifierOutput,
38
+ TokenClassifierOutput,
39
+ )
40
+ from ....modeling_utils import PreTrainedModel
41
+ from ....pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
42
+ from ....utils import (
43
+ add_code_sample_docstrings,
44
+ add_start_docstrings,
45
+ add_start_docstrings_to_model_forward,
46
+ is_pytorch_quantization_available,
47
+ logging,
48
+ replace_return_docstrings,
49
+ requires_backends,
50
+ )
51
+ from .configuration_qdqbert import QDQBertConfig
52
+
53
+
54
+ logger = logging.get_logger(__name__)
55
+
56
+ # soft dependency
57
+ if is_pytorch_quantization_available():
58
+ try:
59
+ from pytorch_quantization import nn as quant_nn
60
+ from pytorch_quantization.nn.modules.tensor_quantizer import TensorQuantizer
61
+ except OSError:
62
+ logger.error(
63
+ "QDQBERT model are not usable since `pytorch_quantization` can't be loaded. Please try to reinstall it"
64
+ " following the instructions here:"
65
+ " https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization."
66
+ )
67
+
68
+ _CHECKPOINT_FOR_DOC = "google-bert/bert-base-uncased"
69
+ _CONFIG_FOR_DOC = "QDQBertConfig"
70
+
71
+
72
def load_tf_weights_in_qdqbert(model, tf_checkpoint_path):
    """Load tf checkpoints in a pytorch model.

    Every variable stored in the TensorFlow checkpoint is read, its `/`-separated name is walked attribute by
    attribute starting from `model`, and the resulting tensor's `.data` is replaced by the checkpoint array.

    Args:
        model: The PyTorch module whose parameters are overwritten in place.
        tf_checkpoint_path: Path to the TensorFlow checkpoint; it is made absolute before loading.

    Returns:
        The same `model` object, after its weights have been assigned.

    Raises:
        ImportError: If `re`, `numpy` or `tensorflow` cannot be imported (re-raised after logging).
        ValueError: If the resolved PyTorch tensor and the checkpoint array have different shapes.
    """
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")

    # Load weights from TF model
    checkpoint_weights = []
    for name, shape in tf.train.list_variables(tf_path):
        logger.info(f"Loading TF weight {name} with shape {shape}")
        checkpoint_weights.append((name, tf.train.load_variable(tf_path, name)))

    # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
    # which are not required for using pretrained model
    optimizer_scopes = frozenset(
        ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
    )
    # Loop-invariant patterns: a scope such as "layer_3" is split into the attribute name and an index.
    indexed_scope = re.compile(r"[A-Za-z]+_\d+")
    index_suffix = re.compile(r"_(\d+)")

    for name, array in checkpoint_weights:
        name = name.split("/")
        if any(n in optimizer_scopes for n in name):
            logger.info(f"Skipping {'/'.join(name)}")
            continue
        pointer = model
        for m_name in name:
            if indexed_scope.fullmatch(m_name):
                scope_names = index_suffix.split(m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "output_weights":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info(f"Skipping {'/'.join(name)}")
                    # NOTE(review): this `continue` only moves on to the next name component (the pointer
                    # is left unchanged); it does not skip the whole variable. Kept as in the original.
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        if m_name.endswith("_embeddings"):
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            array = np.transpose(array)
        # The original wrapped this check in `try/except AssertionError`, but a ValueError is raised here,
        # so that handler could never run; the check is now stated directly.
        if pointer.shape != array.shape:
            raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
        logger.info(f"Initialize PyTorch weight {name}")
        pointer.data = torch.from_numpy(array)
    return model
143
+
144
+
145
class QDQBertEmbeddings(nn.Module):
    """Build the input embeddings as the sum of word, token-type and (optionally) position embeddings."""

    def __init__(self, config):
        super().__init__()
        hidden_size = config.hidden_size
        self.word_embeddings = nn.Embedding(config.vocab_size, hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, hidden_size)

        # The attribute keeps TensorFlow's `LayerNorm` spelling (not snake case) so that TensorFlow
        # checkpoint variable names map onto it directly.
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

        # Non-persistent buffers: a (1, max_position_embeddings) row of positions and an all-zero
        # token-type row of the same size.
        all_positions = torch.arange(config.max_position_embeddings).expand((1, -1))
        self.register_buffer("position_ids", all_positions, persistent=False)
        zero_token_types = torch.zeros(self.position_ids.size(), dtype=torch.long)
        self.register_buffer("token_type_ids", zero_token_types, persistent=False)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        """Embed the inputs, apply layer normalization, then dropout.

        Either `input_ids` or `inputs_embeds` provides the input shape. Missing `position_ids` are sliced from
        the registered buffer starting at `past_key_values_length`; missing `token_type_ids` default to zeros.
        """
        input_shape = input_ids.size() if input_ids is not None else inputs_embeds.size()[:-1]
        seq_length = input_shape[1]

        if position_ids is None:
            window_end = seq_length + past_key_values_length
            position_ids = self.position_ids[:, past_key_values_length:window_end]

        # When token_type_ids are not supplied, fall back to the all-zero buffer registered in the
        # constructor; this helps users trace the model without passing token_type_ids (issue #5664).
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                zero_row = self.token_type_ids[:, :seq_length]
                token_type_ids = zero_row.expand(input_shape[0], seq_length)
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds + self.token_type_embeddings(token_type_ids)
        if self.position_embedding_type == "absolute":
            embeddings += self.position_embeddings(position_ids)

        return self.dropout(self.LayerNorm(embeddings))
207
+
208
+
209
class QDQBertSelfAttention(nn.Module):
    """Multi-head (self- or cross-) attention whose projections are `QuantLinear` layers and whose matmul
    operands each pass through a `TensorQuantizer`."""

    def __init__(self, config):
        super().__init__()
        heads_divide_hidden = config.hidden_size % config.num_attention_heads == 0
        if not heads_divide_hidden and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Query / key / value projections.
        self.query = quant_nn.QuantLinear(config.hidden_size, self.all_head_size)
        self.key = quant_nn.QuantLinear(config.hidden_size, self.all_head_size)
        self.value = quant_nn.QuantLinear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type in ("relative_key", "relative_key_query"):
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

        # One quantizer per matmul operand: q and k for the scores, a (probabilities) and v for the context.
        input_desc = quant_nn.QuantLinear.default_quant_desc_input
        self.matmul_q_input_quantizer = TensorQuantizer(input_desc)
        self.matmul_k_input_quantizer = TensorQuantizer(input_desc)
        self.matmul_v_input_quantizer = TensorQuantizer(input_desc)
        self.matmul_a_input_quantizer = TensorQuantizer(input_desc)

    def transpose_for_scores(self, x):
        """Split the last dimension into (heads, head_size) and move the heads axis to position 1."""
        split_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        return x.view(*split_shape).permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Run attention over `hidden_states`.

        Returns a tuple starting with the context layer, followed by the attention probabilities when
        `output_attentions` is set, followed by the `(key, value)` pair when the module is a decoder.
        """
        mixed_query_layer = self.query(hidden_states)

        # With encoder states present this acts as cross-attention: keys and values come from the encoder
        # and the encoder's mask is used so that its padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            if past_key_value is not None:
                # Reuse the cached cross-attention keys and values.
                key_layer = past_key_value[0]
                value_layer = past_key_value[1]
            else:
                key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
                value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            if past_key_value is not None:
                # Prepend the cached keys/values along the sequence axis.
                key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
                value_layer = torch.cat([past_key_value[1], value_layer], dim=2)

        query_layer = self.transpose_for_scores(mixed_query_layer)

        if self.is_decoder:
            # Decoders hand the keys/values back to the caller so later calls can reuse them:
            # cross-attention reuses them as is, uni-directional self-attention concatenates onto them.
            # For an encoder, `past_key_value` is left as passed in.
            past_key_value = (key_layer, value_layer)

        # Raw attention scores: dot product between quantized "query" and quantized transposed "key".
        quantized_query = self.matmul_q_input_quantizer(query_layer)
        quantized_key_t = self.matmul_k_input_quantizer(key_layer.transpose(-1, -2))
        attention_scores = torch.matmul(quantized_query, quantized_key_t)

        if self.position_embedding_type in ("relative_key", "relative_key_query"):
            seq_length = hidden_states.size()[1]
            rows = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            cols = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = rows - cols
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # The mask is additive (precomputed for all layers in QDQBertModel forward() function).
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # Dropout here removes entire tokens to attend to, as in the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        quantized_probs = self.matmul_a_input_quantizer(attention_probs)
        quantized_value = self.matmul_v_input_quantizer(value_layer)
        context_layer = torch.matmul(quantized_probs, quantized_value)

        # Merge the heads back into a single hidden dimension.
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        merged_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*merged_shape)

        outputs = (context_layer,)
        if output_attentions:
            outputs = outputs + (attention_probs,)
        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs
341
+
342
+
343
class QDQBertSelfOutput(nn.Module):
    """Output projection of the attention block: quantized dense layer, dropout, then a residual add
    (both operands quantized) followed by layer normalization."""

    def __init__(self, config):
        super().__init__()
        # Quantized linear projection.
        self.dense = quant_nn.QuantLinear(config.hidden_size, config.hidden_size)

        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Separate quantizers for the two operands of the residual add.
        input_desc = quant_nn.QuantLinear.default_quant_desc_input
        self.add_local_input_quantizer = TensorQuantizer(input_desc)
        self.add_residual_input_quantizer = TensorQuantizer(input_desc)

    def forward(self, hidden_states, input_tensor):
        """Project `hidden_states`, add the quantized residual `input_tensor`, and normalize."""
        projected = self.dropout(self.dense(hidden_states))
        # Quantize both inputs of the residual add before summing them.
        local_branch = self.add_local_input_quantizer(projected)
        residual_branch = self.add_residual_input_quantizer(input_tensor)
        return self.LayerNorm(local_branch + residual_branch)
364
+
365
+
366
+ # Based on transformers.models.bert.modeling_bert.BertAttention with Bert -> QDQBert
367
class QDQBertAttention(nn.Module):
    """Couples `QDQBertSelfAttention` with its `QDQBertSelfOutput` projection and supports head pruning."""

    def __init__(self, config):
        super().__init__()
        self.self = QDQBertSelfAttention(config)
        self.output = QDQBertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads from the projections and update the head bookkeeping."""
        if len(heads) == 0:
            return

        attention = self.self
        heads, index = find_pruneable_heads_and_indices(
            heads, attention.num_attention_heads, attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        attention.query = prune_linear_layer(attention.query, index)
        attention.key = prune_linear_layer(attention.key, index)
        attention.value = prune_linear_layer(attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        attention.num_attention_heads = attention.num_attention_heads - len(heads)
        attention.all_head_size = attention.attention_head_size * attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Run attention, then the output projection with `hidden_states` as the residual input.

        Returns the projected attention output followed by every extra element the inner attention
        module returned.
        """
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        context, *extras = self_outputs
        attention_output = self.output(context, hidden_states)
        # Extras carry the attentions if we output them.
        return (attention_output,) + tuple(extras)
414
+
415
+
416
class QDQBertIntermediate(nn.Module):
    """First half of the feed-forward block: quantized dense layer followed by the configured activation."""

    def __init__(self, config):
        super().__init__()
        # Quantized linear projection from hidden_size to intermediate_size.
        self.dense = quant_nn.QuantLinear(config.hidden_size, config.intermediate_size)
        # `hidden_act` is either a key into ACT2FN or the activation callable itself.
        activation = config.hidden_act
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states):
        """Apply the dense projection and then the activation function."""
        return self.intermediate_act_fn(self.dense(hidden_states))
430
+
431
+
432
class QDQBertOutput(nn.Module):
    """Second half of the feed-forward block: quantized dense layer back to hidden_size, dropout, then a
    residual add (both operands quantized) followed by layer normalization."""

    def __init__(self, config):
        super().__init__()
        # Quantized linear projection.
        self.dense = quant_nn.QuantLinear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Separate quantizers for the two operands of the residual add.
        input_desc = quant_nn.QuantLinear.default_quant_desc_input
        self.add_local_input_quantizer = TensorQuantizer(input_desc)
        self.add_residual_input_quantizer = TensorQuantizer(input_desc)

    def forward(self, hidden_states, input_tensor):
        """Project `hidden_states`, add the quantized residual `input_tensor`, and normalize."""
        projected = self.dropout(self.dense(hidden_states))
        # Quantize both inputs of the residual add before summing them.
        local_branch = self.add_local_input_quantizer(projected)
        residual_branch = self.add_residual_input_quantizer(input_tensor)
        return self.LayerNorm(local_branch + residual_branch)
452
+
453
+
454
+ # Based on transformers.models.bert.modeling_bert.BertLayer with Bert -> QDQBert
455
+ class QDQBertLayer(nn.Module):
456
+ def __init__(self, config):
457
+ super().__init__()
458
+ self.seq_len_dim = 1
459
+ self.attention = QDQBertAttention(config)
460
+ self.is_decoder = config.is_decoder
461
+ self.add_cross_attention = config.add_cross_attention
462
+ if self.add_cross_attention:
463
+ if not self.is_decoder:
464
+ raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
465
+ self.crossattention = QDQBertAttention(config)
466
+ self.intermediate = QDQBertIntermediate(config)
467
+ self.output = QDQBertOutput(config)
468
+
469
def forward(
    self,
    hidden_states,
    attention_mask=None,
    head_mask=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    past_key_value=None,
    output_attentions=False,
):
    """
    Apply self-attention, optional cross-attention and the feed-forward block to `hidden_states`.

    Args:
        hidden_states: Input forwarded to `self.attention`.
        attention_mask: Passed unchanged to `self.attention` and, when used, to `self.crossattention`.
        head_mask: Passed unchanged to both attention modules.
        encoder_hidden_states: When given and `self.is_decoder` is set, cross-attention is run against it.
        encoder_attention_mask: Passed to `self.crossattention` together with `encoder_hidden_states`.
        past_key_value: Optional cached states; the first two entries are handed to self-attention and the
            last two to cross-attention.
        output_attentions: Forwarded to the attention modules.

    Returns:
        A tuple whose first element is the feed-forward output, followed by any extra items the attention
        modules returned, and - only when `self.is_decoder` is set - the key/value cache as the last element.

    Raises:
        ValueError: If `encoder_hidden_states` is passed to a decoder layer without a `crossattention` module.
    """
    has_past = past_key_value is not None

    # Self-attention consumes the first two cached tensors.
    self_attn_results = self.attention(
        hidden_states,
        attention_mask,
        head_mask,
        output_attentions=output_attentions,
        past_key_value=past_key_value[:2] if has_past else None,
    )
    attn_hidden = self_attn_results[0]

    if self.is_decoder:
        # A decoder's self-attention appends its key/value cache as the final item.
        extras = self_attn_results[1:-1]
        kv_cache = self_attn_results[-1]
    else:
        extras = self_attn_results[1:]

    if self.is_decoder and encoder_hidden_states is not None:
        if not hasattr(self, "crossattention"):
            raise ValueError(
                f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                " by setting `config.add_cross_attention=True`"
            )

        # Cross-attention consumes the last two cached tensors.
        cross_attn_results = self.crossattention(
            attn_hidden,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value[-2:] if has_past else None,
            output_attentions,
        )
        attn_hidden = cross_attn_results[0]
        extras = extras + cross_attn_results[1:-1]
        # The cross-attention cache is appended after the self-attention cache.
        kv_cache = kv_cache + cross_attn_results[-1]

    result = (self.feed_forward_chunk(attn_hidden),) + extras
    if self.is_decoder:
        result = result + (kv_cache,)
    return result
531
+
532
def feed_forward_chunk(self, attention_output):
    """Run `self.intermediate` on the attention output, then combine both through `self.output`."""
    # `self.output` receives the original attention output as its second argument.
    return self.output(self.intermediate(attention_output), attention_output)
536
+
537
+
538
+ # Based on transformers.models.bert.modeling_bert.BertEncoder with Bert -> QDQBert
539
class QDQBertEncoder(nn.Module):
    """Runs hidden states through `config.num_hidden_layers` stacked `QDQBertLayer` modules."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([QDQBertLayer(config) for _ in range(config.num_hidden_layers)])
        # NOTE(review): presumably switched on by the PreTrainedModel gradient-checkpointing machinery,
        # which also has to provide `self._gradient_checkpointing_func` - it is not defined in this class.
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        """
        Apply every layer in order and collect the requested intermediate outputs.

        Args:
            hidden_states: Input to the first layer.
            attention_mask: Passed unchanged to every layer.
            head_mask: Optional per-layer container; `head_mask[i]` is handed to layer `i`.
            encoder_hidden_states: Passed unchanged to every layer (used for cross-attention).
            encoder_attention_mask: Passed unchanged to every layer.
            past_key_values: Optional per-layer container; `past_key_values[i]` is handed to layer `i`.
            use_cache: When truthy, the last output of every layer is collected as the new cache. Forced to
                `False` when gradient checkpointing is active in training mode.
            output_attentions: When truthy, collect `layer_outputs[1]` of every layer (and `layer_outputs[2]`
                when `config.add_cross_attention` is set).
            output_hidden_states: When truthy, collect the input of every layer plus the final output.
            return_dict: Return a `BaseModelOutputWithPastAndCrossAttentions` instead of a plain tuple.

        Returns:
            Either a `BaseModelOutputWithPastAndCrossAttentions`, or a tuple holding the non-`None` values of
            (last hidden state, cache, all hidden states, self-attentions, cross-attentions).
        """
        # Resolve the cache / checkpointing conflict *before* the cache tuple is allocated. Doing it inside
        # the loop left `next_decoder_cache` as an empty tuple, which was then returned as `past_key_values`
        # (and leaked into the `return_dict=False` tuple) instead of `None`.
        use_checkpointing = self.gradient_checkpointing and self.training
        if use_checkpointing and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
        next_decoder_cache = () if use_cache else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if use_checkpointing:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                # The layer puts its key/value cache last.
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    # NOTE(review): index 2 only holds cross-attention weights when the layer actually ran
                    # cross-attention (i.e. `encoder_hidden_states` was provided) - confirm with callers.
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
628
+
629
+
630
class QDQBertPooler(nn.Module):
    """Pools a sequence by projecting the hidden state at position 0 and applying tanh."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `tanh(dense(hidden_states[:, 0]))`."""
        # "Pooling" here simply means keeping the entry at index 0 of the second dimension
        # (the first token of each sequence).
        return self.activation(self.dense(hidden_states[:, 0]))
643
+
644
+
645
class QDQBertPredictionHeadTransform(nn.Module):
    """Dense projection, activation and layer normalisation applied before the LM decoder."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # `hidden_act` may be a key into `ACT2FN` or an already-resolved callable.
        activation = config.hidden_act
        self.transform_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `LayerNorm(act(dense(hidden_states)))`."""
        projected = self.dense(hidden_states)
        activated = self.transform_act_fn(projected)
        return self.LayerNorm(activated)
660
+
661
+
662
+ # Based on transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert -> QDQBert
663
class QDQBertLMPredictionHead(nn.Module):
    """Transforms hidden states and projects them onto the vocabulary."""

    def __init__(self, config):
        super().__init__()
        self.transform = QDQBertPredictionHeadTransform(config)

        # The projection is built without its own bias; a separate per-token bias parameter is
        # created below and attached to it. (The original comment notes that the projection
        # weights are the same as the input embeddings - that tying happens outside this class.)
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Link the two attributes so the bias is resized correctly by `resize_token_embeddings`.
        self._tie_weights()

    def _tie_weights(self):
        """Point the decoder's bias at `self.bias`."""
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        """Return the decoder output computed from the transformed hidden states."""
        return self.decoder(self.transform(hidden_states))
684
+
685
+
686
+ # Based on transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert -> QDQBert
687
class QDQBertOnlyMLMHead(nn.Module):
    """Thin wrapper exposing a `QDQBertLMPredictionHead` under the `predictions` attribute."""

    def __init__(self, config):
        super().__init__()
        self.predictions = QDQBertLMPredictionHead(config)

    def forward(self, sequence_output):
        """Return the prediction scores produced by `self.predictions`."""
        return self.predictions(sequence_output)
695
+
696
+
697
class QDQBertOnlyNSPHead(nn.Module):
    """Linear two-way classifier applied to the pooled output."""

    def __init__(self, config):
        super().__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        """Return the two scores produced by `self.seq_relationship`."""
        return self.seq_relationship(pooled_output)
705
+
706
+
707
+ # Based on transformers.models.bert.modeling_bert.BertPreTrainingHeads with Bert -> QDQBert
708
class QDQBertPreTrainingHeads(nn.Module):
    """Bundles the LM prediction head with the two-way sequence-relationship classifier."""

    def __init__(self, config):
        super().__init__()
        self.predictions = QDQBertLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        """Return `(prediction_scores, seq_relationship_score)` for the two inputs."""
        lm_scores = self.predictions(sequence_output)
        relationship_scores = self.seq_relationship(pooled_output)
        return lm_scores, relationship_scores
718
+
719
+
720
+ # Based on transformers.models.bert.modeling_bert.BertPreTrainedModel with Bert -> QDQBert
721
class QDQBertPreTrainedModel(PreTrainedModel):
    """
    Abstract base class that takes care of weight initialization and offers a simple interface for downloading and
    loading pretrained models.
    """

    config_class = QDQBertConfig
    load_tf_weights = load_tf_weights_in_qdqbert
    base_model_prefix = "bert"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights of `module` in place according to its type."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Plain normal initialization; the TF version uses truncated_normal instead,
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Linear):
                if module.bias is not None:
                    module.bias.data.zero_()
            elif module.padding_idx is not None:
                # Zero the embedding row of the padding index.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()
747
+
748
+
749
+ QDQBERT_START_DOCSTRING = r"""
750
+
751
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
752
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
753
+ etc.)
754
+
755
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
756
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
757
+ and behavior.
758
+
759
+ Parameters:
760
+ config ([`QDQBertConfig`]): Model configuration class with all the parameters of the model.
761
+ Initializing with a config file does not load the weights associated with the model, only the
762
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
763
+ """
764
+
765
+ QDQBERT_INPUTS_DOCSTRING = r"""
766
+ Args:
767
+ input_ids (`torch.LongTensor` of shape `({0})`):
768
+ Indices of input sequence tokens in the vocabulary.
769
+
770
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
771
+ [`PreTrainedTokenizer.__call__`] for details.
772
+
773
+ [What are input IDs?](../glossary#input-ids)
774
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
775
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
776
+
777
+ - 1 for tokens that are **not masked**,
778
+ - 0 for tokens that are **masked**.
779
+
780
+ [What are attention masks?](../glossary#attention-mask)
781
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
782
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
783
+ 1]`:
784
+
785
+ - 0 corresponds to a *sentence A* token,
786
+ - 1 corresponds to a *sentence B* token.
787
+
788
+ [What are token type IDs?](../glossary#token-type-ids)
789
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
790
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
791
+ config.max_position_embeddings - 1]`.
792
+
793
+ [What are position IDs?](../glossary#position-ids)
794
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
795
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
796
+
797
+ - 1 indicates the head is **not masked**,
798
+ - 0 indicates the head is **masked**.
799
+
800
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
801
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
802
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
803
+ model's internal embedding lookup matrix.
804
+ output_attentions (`bool`, *optional*):
805
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
806
+ tensors for more detail.
807
+ output_hidden_states (`bool`, *optional*):
808
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
809
+ more detail.
810
+ return_dict (`bool`, *optional*):
811
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
812
+ """
813
+
814
+
815
@add_start_docstrings(
    "The bare QDQBERT Model transformer outputting raw hidden-states without any specific head on top.",
    QDQBERT_START_DOCSTRING,
)
class QDQBertModel(QDQBertPreTrainedModel):
    """

    The model can act as an encoder (self-attention only) or as a decoder, in which case a cross-attention layer is
    inserted between the self-attention layers, following the architecture described in [Attention is all you
    need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion
    Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
    to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument
    and `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward
    pass.
    """

    def __init__(self, config, add_pooling_layer: bool = True):
        # Checked before anything else in the constructor runs.
        requires_backends(self, "pytorch_quantization")
        super().__init__(config)
        self.config = config

        self.embeddings = QDQBertEmbeddings(config)
        self.encoder = QDQBertEncoder(config)

        self.pooler = QDQBertPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the word-embedding module."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word-embedding module with `value`."""
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]):
        """
        Prune attention heads. `heads_to_prune` maps a layer index to the list of heads to prune in that layer. See
        the base class `PreTrainedModel`.
        """
        for layer_index, head_indices in heads_to_prune.items():
            self.encoder.layer[layer_index].attention.prune_heads(head_indices)

    @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPoolingAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
        r"""
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        """
        cfg = self.config

        # Fall back to the configuration for every flag the caller left unset.
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict

        # Caching is only ever enabled for decoders.
        if cfg.is_decoder:
            if use_cache is None:
                use_cache = cfg.use_cache
        else:
            use_cache = False

        # Exactly one of `input_ids` / `inputs_embeds` must be supplied.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            device = input_ids.device
        else:
            # Drop the trailing embedding dimension.
            input_shape = inputs_embeds.size()[:-1]
            device = inputs_embeds.device
        batch_size, seq_length = input_shape

        # Length already covered by the cache, read from dimension 2 of the first cached tensor.
        past_key_values_length = 0 if past_key_values is None else past_key_values[0][0].shape[2]

        if attention_mask is None:
            # Attend to everything, cached positions included.
            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)

        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                # Reuse the buffer held by the embeddings, cut to the current length and expanded per batch.
                token_type_ids = self.embeddings.token_type_ids[:, :seq_length].expand(batch_size, seq_length)
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        encoder_extended_attention_mask = None
        if cfg.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones((encoder_batch_size, encoder_sequence_length), device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, cfg.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = None if self.pooler is None else self.pooler(sequence_output)

        if return_dict:
            return BaseModelOutputWithPoolingAndCrossAttentions(
                last_hidden_state=sequence_output,
                pooler_output=pooled_output,
                past_key_values=encoder_outputs.past_key_values,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
                cross_attentions=encoder_outputs.cross_attentions,
            )

        return (sequence_output, pooled_output) + encoder_outputs[1:]
995
+
996
+
997
@add_start_docstrings(
    """QDQBERT Model with a `language modeling` head on top for CLM fine-tuning.""", QDQBERT_START_DOCSTRING
)
class QDQBertLMHeadModel(QDQBertPreTrainedModel):
    _tied_weights_keys = ["predictions.decoder.weight", "predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        if not config.is_decoder:
            logger.warning("If you want to use `QDQBertLMHeadModel` as a standalone, add `is_decoder=True.`")

        # The pooled output is not needed for language modeling.
        self.bert = QDQBertModel(config, add_pooling_layer=False)
        self.cls = QDQBertOnlyMLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        """Return the vocabulary projection of the LM head."""
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        """Swap in `new_embeddings` as vocabulary projection and point the head's bias at its bias."""
        predictions = self.cls.predictions
        predictions.decoder = new_embeddings
        predictions.bias = new_embeddings.bias

    @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.LongTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        r"""
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
            ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, QDQBertLMHeadModel, QDQBertConfig
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")
        >>> config = QDQBertConfig.from_pretrained("google-bert/bert-base-cased")
        >>> config.is_decoder = True
        >>> model = QDQBertLMHeadModel.from_pretrained("google-bert/bert-base-cased", config=config)

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> prediction_logits = outputs.logits
        ```"""
        if return_dict is None:
            return_dict = self.config.use_return_dict
        if labels is not None:
            # Caching is switched off whenever a loss is requested.
            use_cache = False

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        prediction_scores = self.cls(outputs[0])

        lm_loss = None
        if labels is not None:
            # Next-token prediction: the score at position t is compared with the label at position t + 1.
            scores_without_last = prediction_scores[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            lm_loss = CrossEntropyLoss()(
                scores_without_last.view(-1, self.config.vocab_size),
                labels.view(-1),
            )

        if return_dict:
            return CausalLMOutputWithCrossAttentions(
                loss=lm_loss,
                logits=prediction_scores,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
                cross_attentions=outputs.cross_attentions,
            )

        # Tuple output: skip the first two entries of the base model output.
        output = (prediction_scores,) + outputs[2:]
        return output if lm_loss is None else (lm_loss,) + output

    def prepare_inputs_for_generation(
        self,
        input_ids: Optional[torch.LongTensor],
        past_key_values=None,
        attention_mask: Optional[torch.Tensor] = None,
        **model_kwargs,
    ):
        """Build the keyword arguments for one generation step, trimming `input_ids` already covered by the cache."""
        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)

        # cut decoder_input_ids if past_key_values is used
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]
            current_length = input_ids.shape[1]

            # Some generation methods already pass only the last input ID; otherwise fall back to the old
            # behavior of keeping only the final ID.
            prefix_to_drop = past_length if current_length > past_length else current_length - 1
            input_ids = input_ids[:, prefix_to_drop:]

        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

    def _reorder_cache(self, past_key_values, beam_idx):
        """Re-index every cached tensor along dimension 0 with `beam_idx`."""
        return tuple(
            tuple(state.index_select(0, beam_idx.to(state.device)) for state in layer_past)
            for layer_past in past_key_values
        )
1161
+
1162
+
1163
+ @add_start_docstrings("""QDQBERT Model with a `language modeling` head on top.""", QDQBERT_START_DOCSTRING)
1164
+ class QDQBertForMaskedLM(QDQBertPreTrainedModel):
1165
+ _tied_weights_keys = ["predictions.decoder.weight", "predictions.decoder.bias"]
1166
+
1167
+ def __init__(self, config):
1168
+ super().__init__(config)
1169
+
1170
+ if config.is_decoder:
1171
+ logger.warning(
1172
+ "If you want to use `QDQBertForMaskedLM` make sure `config.is_decoder=False` for "
1173
+ "bi-directional self-attention."
1174
+ )
1175
+
1176
+ self.bert = QDQBertModel(config, add_pooling_layer=False)
1177
+ self.cls = QDQBertOnlyMLMHead(config)
1178
+
1179
+ # Initialize weights and apply final processing
1180
+ self.post_init()
1181
+
1182
+ def get_output_embeddings(self):
1183
+ return self.cls.predictions.decoder
1184
+
1185
+ def set_output_embeddings(self, new_embeddings):
1186
+ self.cls.predictions.decoder = new_embeddings
1187
+ self.cls.predictions.bias = new_embeddings.bias
1188
+
1189
+ @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1190
+ @add_code_sample_docstrings(
1191
+ checkpoint=_CHECKPOINT_FOR_DOC,
1192
+ output_type=MaskedLMOutput,
1193
+ config_class=_CONFIG_FOR_DOC,
1194
+ )
1195
+ def forward(
1196
+ self,
1197
+ input_ids: Optional[torch.LongTensor] = None,
1198
+ attention_mask: Optional[torch.FloatTensor] = None,
1199
+ token_type_ids: Optional[torch.LongTensor] = None,
1200
+ position_ids: Optional[torch.LongTensor] = None,
1201
+ head_mask: Optional[torch.FloatTensor] = None,
1202
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1203
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1204
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1205
+ labels: Optional[torch.LongTensor] = None,
1206
+ output_attentions: Optional[bool] = None,
1207
+ output_hidden_states: Optional[bool] = None,
1208
+ return_dict: Optional[bool] = None,
1209
+ ) -> Union[Tuple, MaskedLMOutput]:
1210
+ r"""
1211
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1212
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
1213
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
1214
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
1215
+ """
1216
+
1217
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1218
+
1219
+ outputs = self.bert(
1220
+ input_ids,
1221
+ attention_mask=attention_mask,
1222
+ token_type_ids=token_type_ids,
1223
+ position_ids=position_ids,
1224
+ head_mask=head_mask,
1225
+ inputs_embeds=inputs_embeds,
1226
+ encoder_hidden_states=encoder_hidden_states,
1227
+ encoder_attention_mask=encoder_attention_mask,
1228
+ output_attentions=output_attentions,
1229
+ output_hidden_states=output_hidden_states,
1230
+ return_dict=return_dict,
1231
+ )
1232
+
1233
+ sequence_output = outputs[0]
1234
+ prediction_scores = self.cls(sequence_output)
1235
+
1236
+ masked_lm_loss = None
1237
+ if labels is not None:
1238
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1239
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1240
+
1241
+ if not return_dict:
1242
+ output = (prediction_scores,) + outputs[2:]
1243
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1244
+
1245
+ return MaskedLMOutput(
1246
+ loss=masked_lm_loss,
1247
+ logits=prediction_scores,
1248
+ hidden_states=outputs.hidden_states,
1249
+ attentions=outputs.attentions,
1250
+ )
1251
+
1252
def prepare_inputs_for_generation(
    self, input_ids: torch.LongTensor, attention_mask: Optional[torch.FloatTensor] = None, **model_kwargs
):
    """
    Extend `input_ids` with one trailing PAD token and `attention_mask` with one trailing zero.

    Args:
        input_ids (`torch.LongTensor`):
            Token ids; the first dimension is used as the batch size and the dummy token is concatenated along
            dimension 1.
        attention_mask (`torch.FloatTensor`, *optional*):
            Mask matching `input_ids`. When omitted, an all-ones mask with the shape of `input_ids` is used so
            that the default value no longer fails on `None.new_zeros`.
        **model_kwargs:
            Accepted for call-site compatibility; not read by this method.

    Returns:
        `dict` with the keys `"input_ids"` and `"attention_mask"`, each one position longer than the inputs.

    Raises:
        ValueError: If `self.config.pad_token_id` is `None`.
    """
    # The dummy token is the PAD token, so generation cannot proceed without one.
    if self.config.pad_token_id is None:
        raise ValueError("The PAD token should be defined for generation")

    effective_batch_size = input_ids.shape[0]

    # Treat a missing mask as "attend to every provided token".
    if attention_mask is None:
        attention_mask = input_ids.new_ones(input_ids.shape)

    # The appended position is masked out (zero) in the attention mask.
    attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
    dummy_token = torch.full(
        (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
    )
    input_ids = torch.cat([input_ids, dummy_token], dim=1)

    return {"input_ids": input_ids, "attention_mask": attention_mask}
1269
+
1270
+
1271
@add_start_docstrings(
    """Bert Model with a `next sentence prediction (classification)` head on top.""",
    QDQBERT_START_DOCSTRING,
)
class QDQBertForNextSentencePrediction(QDQBertPreTrainedModel):
    # NOTE: no class docstring on purpose -- `add_start_docstrings` builds the class documentation.

    def __init__(self, config):
        super().__init__(config)

        # Encoder (with pooler) followed by the two-way next-sentence head.
        self.bert = QDQBertModel(config)
        self.cls = QDQBertOnlyNSPHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, NextSentencePredictorOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
            (see `input_ids` docstring). Indices should be in `[0, 1]`:

            - 0 indicates sequence B is a continuation of sequence A,
            - 1 indicates sequence B is a random sequence.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, QDQBertForNextSentencePrediction
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
        >>> model = QDQBertForNextSentencePrediction.from_pretrained("google-bert/bert-base-uncased")

        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
        >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
        >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")

        >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
        >>> logits = outputs.logits
        >>> assert logits[0, 0] < logits[0, 1]  # next sentence was random
        ```"""

        # Legacy keyword: when present it overrides `labels`, after emitting a deprecation warning.
        if "next_sentence_label" in kwargs:
            warnings.warn(
                "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
                " `labels` instead.",
                FutureWarning,
            )
            labels = kwargs.pop("next_sentence_label")

        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 1 of the encoder output is fed to the NSP head.
        nsp_logits = self.cls(encoder_outputs[1])

        if labels is None:
            nsp_loss = None
        else:
            criterion = CrossEntropyLoss()
            nsp_loss = criterion(nsp_logits.view(-1, 2), labels.view(-1))

        if return_dict:
            return NextSentencePredictorOutput(
                loss=nsp_loss,
                logits=nsp_logits,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        # Tuple form: logits followed by everything after the first two encoder outputs, loss first if computed.
        tuple_output = (nsp_logits,) + encoder_outputs[2:]
        return tuple_output if nsp_loss is None else (nsp_loss,) + tuple_output
1370
+
1371
+
1372
@add_start_docstrings(
    """
    Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    QDQBERT_START_DOCSTRING,
)
class QDQBertForSequenceClassification(QDQBertPreTrainedModel):
    # NOTE: no class docstring on purpose -- `add_start_docstrings` builds the class documentation.

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        # Encoder, then dropout + a linear layer applied to the pooled output.
        self.bert = QDQBertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Pooled output (index 1) -> dropout -> linear classifier.
        logits = self.classifier(self.dropout(encoder_outputs[1]))

        loss = None
        if labels is not None:
            cfg = self.config
            # Infer the problem type once and persist it on the config for subsequent calls.
            if cfg.problem_type is None:
                if self.num_labels == 1:
                    cfg.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    cfg.problem_type = "single_label_classification"
                else:
                    cfg.problem_type = "multi_label_classification"

            problem_type = cfg.problem_type
            if problem_type == "regression":
                criterion = MSELoss()
                if self.num_labels == 1:
                    loss = criterion(logits.squeeze(), labels.squeeze())
                else:
                    loss = criterion(logits, labels)
            elif problem_type == "single_label_classification":
                criterion = CrossEntropyLoss()
                loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                criterion = BCEWithLogitsLoss()
                loss = criterion(logits, labels)

        if return_dict:
            return SequenceClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        # Tuple form: logits followed by everything after the first two encoder outputs, loss first if computed.
        tuple_output = (logits,) + encoder_outputs[2:]
        return tuple_output if loss is None else (loss,) + tuple_output
1467
+
1468
+
1469
@add_start_docstrings(
    """
    Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    QDQBERT_START_DOCSTRING,
)
class QDQBertForMultipleChoice(QDQBertPreTrainedModel):
    # NOTE: no class docstring on purpose -- `add_start_docstrings` builds the class documentation.

    def __init__(self, config):
        super().__init__(config)

        # Encoder, then dropout + a single-logit linear layer scoring each choice.
        self.bert = QDQBertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MultipleChoiceModelOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """

        def _merge_leading_dims(tensor):
            # Collapse every dimension except the last into one; `None` passes through.
            return tensor.view(-1, tensor.size(-1)) if tensor is not None else None

        if return_dict is None:
            return_dict = self.config.use_return_dict

        # The number of choices is read from dimension 1 of whichever input was supplied.
        if input_ids is not None:
            num_choices = input_ids.shape[1]
        else:
            num_choices = inputs_embeds.shape[1]

        input_ids = _merge_leading_dims(input_ids)
        attention_mask = _merge_leading_dims(attention_mask)
        token_type_ids = _merge_leading_dims(token_type_ids)
        position_ids = _merge_leading_dims(position_ids)
        if inputs_embeds is not None:
            # Embeddings keep their last two dimensions.
            inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))

        encoder_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Pooled output (index 1) -> dropout -> one logit per flattened row, regrouped per choice set.
        choice_logits = self.classifier(self.dropout(encoder_outputs[1]))
        reshaped_logits = choice_logits.view(-1, num_choices)

        if labels is None:
            loss = None
        else:
            criterion = CrossEntropyLoss()
            loss = criterion(reshaped_logits, labels)

        if return_dict:
            return MultipleChoiceModelOutput(
                loss=loss,
                logits=reshaped_logits,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        # Tuple form: logits followed by everything after the first two encoder outputs, loss first if computed.
        tuple_output = (reshaped_logits,) + encoder_outputs[2:]
        return tuple_output if loss is None else (loss,) + tuple_output
1558
+
1559
+
1560
@add_start_docstrings(
    """
    QDQBERT Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    QDQBERT_START_DOCSTRING,
)
class QDQBertForTokenClassification(QDQBertPreTrainedModel):
    # NOTE: no class docstring on purpose -- `add_start_docstrings` builds the class documentation.

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Encoder without a pooling layer; the head is applied to the per-token output.
        self.bert = QDQBertModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # First encoder output -> dropout -> linear classifier.
        logits = self.classifier(self.dropout(encoder_outputs[0]))

        if labels is None:
            loss = None
        else:
            criterion = CrossEntropyLoss()
            loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))

        if return_dict:
            return TokenClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        # Tuple form: logits followed by everything after the first two encoder outputs, loss first if computed.
        tuple_output = (logits,) + encoder_outputs[2:]
        return tuple_output if loss is None else (loss,) + tuple_output
1635
+ )
1636
+
1637
+
1638
@add_start_docstrings(
    """
    QDQBERT Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    QDQBERT_START_DOCSTRING,
)
class QDQBertForQuestionAnswering(QDQBertPreTrainedModel):
    # NOTE: no class docstring on purpose -- `add_start_docstrings` builds the class documentation.

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Encoder without a pooling layer; one linear layer produces the span logits.
        self.bert = QDQBertModel(config, add_pooling_layer=False)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Split the last dimension into width-1 slices: the first is the start logits, the second the end logits.
        span_logits = self.qa_outputs(encoder_outputs[0])
        start_slice, end_slice = span_logits.split(1, dim=-1)
        start_logits = start_slice.squeeze(-1).contiguous()
        end_logits = end_slice.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # Targets may arrive with an extra trailing dimension (e.g. from multi-GPU gathering); drop it.
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)

            # Positions clamped to `sequence_length` land exactly on the ignored index and add nothing to the loss.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            criterion = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = criterion(start_logits, start_positions)
            end_loss = criterion(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if return_dict:
            return QuestionAnsweringModelOutput(
                loss=total_loss,
                start_logits=start_logits,
                end_logits=end_logits,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        # Tuple form: both logits followed by everything after the first two encoder outputs, loss first if computed.
        tuple_output = (start_logits, end_logits) + encoder_outputs[2:]
        return tuple_output if total_loss is None else (total_loss,) + tuple_output
1735
+
1736
+
1737
# Names exported by `from <module> import *`; order is kept exactly as originally declared.
__all__ = [
    "QDQBertForMaskedLM",
    "QDQBertForMultipleChoice",
    "QDQBertForNextSentencePrediction",
    "QDQBertForQuestionAnswering",
    "QDQBertForSequenceClassification",
    "QDQBertForTokenClassification",
    "QDQBertLayer",
    "QDQBertLMHeadModel",
    "QDQBertModel",
    "QDQBertPreTrainedModel",
    "load_tf_weights_in_qdqbert",
]